diff --git a/.codex b/.codex
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/README.md b/README.md
index 705fbcb9150b20d12ea933ca72787a58bd62454c..28b12d4812f6fec090c49b2f35e9395ce6e6bc8b 100644
--- a/README.md
+++ b/README.md
@@ -21,59 +21,149 @@ For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
## About
-vLLM is a fast and easy-to-use library for LLM inference and serving.
+This fork adds KV cache pruning, a model compression feature, on top of the official vLLM.
-Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
+Supported KV cache pruning methods:
-vLLM is fast with:
+- [**SNAPKV**](https://arxiv.org/pdf/2404.14469)
+- [**COMPACTOR**](https://arxiv.org/pdf/2507.08143)
+- [**CRITICALADAKV**](https://arxiv.org/pdf/2502.03805)
-- State-of-the-art serving throughput
-- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
-- Continuous batching of incoming requests
-- Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
-- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
-- Speculative decoding
-- Chunked prefill
-
-vLLM is flexible and easy to use with:
-
-- Seamless integration with popular Hugging Face models
-- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
-- Tensor, pipeline, data and expert parallelism support for distributed inference
-- Streaming outputs
-- OpenAI-compatible API server
-- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
-- Prefix caching support
-- Multi-LoRA support
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
-- Transformer-like LLMs (e.g., Llama)
-- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
-- Embedding Models (e.g., E5-Mistral)
-- Multi-modal LLMs (e.g., LLaVA)
-
-Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
+- Transformer-like LLMs (e.g., Qwen3/Llama)
-## Getting Started
-
-Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
+## Env
```bash
-pip install vllm
+cd vllm
+python use_existing_torch.py
+# then add "torch" to the build `requires` list in pyproject.toml
+export SETUPTOOLS_SCM_PRETEND_VERSION_FOR_VLLM="0.6.0"
+pip install -e . --no-build-isolation -v -i https://mirrors.aliyun.com/pypi/simple/
+pip install numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/
```
-Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
+Other related libraries, listed by wheel file name:
+
+- flash_attn-2.8.3+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
+- torchvision-0.24.0+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
+- triton-3.5.1+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
+
+## Quick Start
+Basic Chat Generation with Compression:
+```bash
+python test.py --schedule pdtriton
+```
+test.py:
+
+```python
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# PYTHONPATH=/home/vllm-project/vllm python test.py --schedule pdtriton
+
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from multiprocessing import freeze_support
+
+
+def _apply_kvprune_attention_env(schedule: str | None) -> None:
+ """Export the CLI choice as VLLM_KVPRUNE_ATTENTION_SCHEDULE (fa_triton | pdtriton | pdfa); no-op if empty or None."""
+ if not schedule:
+ return
+ os.environ["VLLM_KVPRUNE_ATTENTION_SCHEDULE"] = schedule
+
+
+def main() -> None:
+    """Run one chat prompt through Qwen3-8B with KV cache pruning enabled.
+
+    Parses ``--schedule``, exports it via ``_apply_kvprune_attention_env``,
+    renders the prompt with the model's chat template, generates with SnapKV
+    compression at ratio 0.5, and prints the generated text.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--schedule",
+        type=str,
+        default="pdtriton",
+        choices=("fa_triton", "pdtriton", "pdfa"),
+        # Adjacent literals are concatenated, so each piece carries its own
+        # "; " separator to keep the rendered --help text readable.
+        help=(
+            "fa_triton=FA prefill + Triton decode; "
+            "pdtriton=Triton prefill + Triton decode; "
+            "pdfa=FA prefill + FA decode (page KV writing is Triton)"
+        ),
+    )
+    # Unknown arguments are tolerated on purpose and ignored.
+    args, _unknown = parser.parse_known_args()
+    _apply_kvprune_attention_env(args.schedule)
+
+    # Imported after the environment variable is set; presumably vLLM reads
+    # VLLM_KVPRUNE_ATTENTION_SCHEDULE during import/initialization - confirm.
+    from transformers import AutoTokenizer
+
+    from vllm import CompressionParams, LLM, SamplingParams
+
+    model_id = "Qwen/Qwen3-8B"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    sampling_params = SamplingParams(
+        temperature=0.7,
+        top_p=0.8,
+        repetition_penalty=1.05,
+        max_tokens=512,
+    )
+
+    # NOTE: tensor_parallel_size=4 is hard-coded; lower it if fewer devices
+    # are available.
+    llm = LLM(
+        model=model_id,
+        tensor_parallel_size=4,
+        max_model_len=8192,
+        gpu_memory_utilization=0.85,
+        kvprune_compression=True,
+    )
+
+    prompt = (
+        "Write a 200-word English prompt for a creative writing task. The prompt should be "
+        "a single coherent paragraph without any bullet points, numbered lists, or markdown "
+        "formatting. It should describe a specific scenario, character, or conflict, and end "
+        "with a clear question that invites the writer to continue the story. Do not use any "
+        "special symbols or line breaks. The tone can be mysterious, tense, or reflective. "
+        "After the paragraph, include the question on the same line directly following the "
+        "period, without hitting enter."
+    )
+
+    messages = [{"role": "user", "content": prompt}]
+
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        enable_thinking=True,  # chat-template flag; presumably toggles Qwen3 "thinking" output
+    )
+
+    # One CompressionParams entry is passed for the single prompt below.
+    compression = [
+        CompressionParams(
+            compression_ratio=0.5,
+            compression_method="snapkv",
+        ),
+    ]
+
+    outputs = llm.generate(
+        [text],
+        sampling_params=sampling_params,
+        compression=compression,
+    )
+
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print(f"Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+ freeze_support()
+ main()
+```
-- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
-- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
-- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
## Contributing
We welcome and value any contributions and collaborations.
-Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
## Citation
diff --git a/README_vllm.md b/README_vllm.md
new file mode 100644
index 0000000000000000000000000000000000000000..705fbcb9150b20d12ea933ca72787a58bd62454c
--- /dev/null
+++ b/README_vllm.md
@@ -0,0 +1,103 @@
+
+
+
+
+
+
+
+
+
+Easy, fast, and cheap LLM serving for everyone
+
+
+
+| Documentation | Blog | Paper | Twitter/X | User Forum | Developer Slack |
+
+
+🔥 We have built a vllm website to help you get started with vllm. Please visit [vllm.ai](https://vllm.ai) to learn more.
+For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
+
+---
+
+## About
+
+vLLM is a fast and easy-to-use library for LLM inference and serving.
+
+Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
+
+vLLM is fast with:
+
+- State-of-the-art serving throughput
+- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
+- Continuous batching of incoming requests
+- Fast model execution with CUDA/HIP graph
+- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
+- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
+- Speculative decoding
+- Chunked prefill
+
+vLLM is flexible and easy to use with:
+
+- Seamless integration with popular Hugging Face models
+- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
+- Tensor, pipeline, data and expert parallelism support for distributed inference
+- Streaming outputs
+- OpenAI-compatible API server
+- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
+- Prefix caching support
+- Multi-LoRA support
+
+vLLM seamlessly supports most popular open-source models on HuggingFace, including:
+
+- Transformer-like LLMs (e.g., Llama)
+- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
+- Embedding Models (e.g., E5-Mistral)
+- Multi-modal LLMs (e.g., LLaVA)
+
+Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
+
+## Getting Started
+
+Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
+
+```bash
+pip install vllm
+```
+
+Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
+
+- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
+- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
+- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
+
+## Contributing
+
+We welcome and value any contributions and collaborations.
+Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
+
+## Citation
+
+If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+
+```bibtex
+@inproceedings{kwon2023efficient,
+ title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+ author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+ booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+ year={2023}
+}
+```
+
+## Contact Us
+
+
+- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
+- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
+- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
+- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
+- For collaborations and partnerships, please contact us at [collaboration@vllm.ai](mailto:collaboration@vllm.ai)
+
+
+## Media Kit
+
+- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 758a777955535e0a948f63c810a5fdef4c1b1e11..a5197425878521bd27c1d94860781d8594e92e24 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -8,6 +8,16 @@
#include "cuda_vec_utils.cuh"
#include "dispatch_utils.h"
+// ROCm/HIP often assumes at most 256 threads per block unless the kernel
+// declares otherwise; launching more triggers runtime warnings / UB. NVIDIA
+// CUDA builds keep the original 1024 cap. No __launch_bounds__ on templated
+// kernels here — HIP/clang can fail to compile those (see act_and_mul_kernel).
+#ifdef USE_ROCM
+#define VLLM_ACTIVATION_GATE_MAX_THREADS 256
+#else
+#define VLLM_ACTIVATION_GATE_MAX_THREADS 1024
+#endif
+
namespace vllm {
template = 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
@@ -191,7 +201,7 @@ packed_gelu_tanh_kernel(const packed_t& val) {
}); \
} \
} else { \
- dim3 block(std::min(d, 1024)); \
+ dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter::Type, \
@@ -387,7 +397,7 @@ __global__ void swigluoai_and_mul_kernel(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
- dim3 block(std::min(d / vec_size, 1024)); \
+ dim3 block(std::min(d / vec_size, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
@@ -414,7 +424,7 @@ __global__ void swigluoai_and_mul_kernel(
}); \
} \
} else { \
- dim3 block(std::min(d, 1024)); \
+ dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter::Type, \
@@ -429,7 +439,7 @@ __global__ void swigluoai_and_mul_kernel(
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
- dim3 block(std::min(d, 1024)); \
+ dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
@@ -520,7 +530,7 @@ __global__ void activation_kernel(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
- dim3 block(std::min(d / vec_size, 1024)); \
+ dim3 block(std::min(d / vec_size, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel, true, true> \
@@ -535,7 +545,7 @@ __global__ void activation_kernel(
}); \
} \
} else { \
- dim3 block(std::min(d, 1024)); \
+ dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel, false> \
<<>>(out.data_ptr(), \
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index b5645b33b9073f7265527e3d1a98cd572ad6234c..3d8e845e469c0e4c7560afcde54e720713ea3c5b 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -74,7 +74,10 @@ inline __device__ void apply_rotary_embedding(
}
template
-__global__ void rotary_embedding_kernel(
+// HIP/ROCm commonly enforces max 256 threads/block unless explicitly raised.
+// Keep blockDim.x <= 256 here so launch always matches compiler bounds (avoids
+// UB when __launch_bounds__ and host code get out of sync across rebuilds).
+__global__ __launch_bounds__(256, 1) void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
@@ -162,7 +165,7 @@ void rotary_embedding(
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens);
- dim3 block(std::min(num_heads * rot_dim / 2, 512));
+ dim3 block(std::min(num_heads * rot_dim / 2, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
diff --git a/examples/offline_inference/qwen3_kvprune_chat_inference.py b/examples/offline_inference/qwen3_kvprune_chat_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d845109d0c5dda5622cbdeb9ebd5b7a01c53a3f
--- /dev/null
+++ b/examples/offline_inference/qwen3_kvprune_chat_inference.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+python -P examples/offline_inference/qwen3_kvprune_chat_inference.py
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+from multiprocessing import freeze_support
+
+from transformers import AutoTokenizer
+
+
+from vllm import CompressionParams, LLM, SamplingParams
+
+
+def main() -> None:
+ """Generate one chat completion from Qwen3-8B with compactor KV cache pruning and print it."""
+ model_id = "Qwen/Qwen3-8B"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ sampling_params = SamplingParams(
+ temperature=0.7,
+ top_p=0.8,
+ repetition_penalty=1.05,
+ max_tokens=512,
+ )
+
+ # TP=1: the compactor shares weights in-process. For TP>1, use ``qwen3_kvprune_tp4_inference.py`` or set
+ # ``tensor_parallel_size>=2`` yourself and pass ``compression`` (goes through collective_rpc + a per-GPU ModelRunner).
+ llm = LLM(
+ model=model_id,
+ tensor_parallel_size=1,
+ max_model_len=8192,
+ gpu_memory_utilization=0.85,
+ kvprune_compression=True,
+ )
+
+ prompt = "Give me a short introduction to large language models."
+ messages = [
+ {"role": "user", "content": prompt},
+ ]
+
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True,
+ )
+
+ # Pruning: compression_ratio < 1 takes the compactor path; each prompt gets one CompressionParams entry.
+ compression = [
+ CompressionParams(
+ compression_ratio=0.5,
+ compression_method="compactor",
+ ),
+ ]
+
+ outputs = llm.generate(
+ [text],
+ sampling_params=sampling_params,
+ compression=compression,
+ )
+
+ for output in outputs:
+ generated_text = output.outputs[0].text
+ print(f"Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+ freeze_support()
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index 64a6de30e225c5e0cb973215ca58298cba4f5a3b..3822cd12bc2a48018ad55401f43d93a1daee5539 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ requires = [
"packaging>=24.2",
"setuptools>=77.0.3,<81.0.0",
"setuptools-scm>=8.0",
- "torch == 2.10.0",
+ "torch",
"wheel",
"jinja2",
]
@@ -171,4 +171,3 @@ ser = "ser"
ure = "ure"
[tool.uv]
-no-build-isolation-package = ["torch"]
diff --git a/requirements/build.txt b/requirements/build.txt
index c46880a05ebb0201477a0ee365b3cf42370fb645..639de69bb86ea9f0f284b40e974bf341528aa244 100644
--- a/requirements/build.txt
+++ b/requirements/build.txt
@@ -4,7 +4,6 @@ ninja
packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools-scm>=8
-torch==2.10.0
wheel
jinja2>=3.1.6
regex
diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt
index 3893b0026978bc24c04100ae00730bac9c308e8c..095351c863ce11b28d22a984b6137a80706c5e3b 100644
--- a/requirements/cpu-build.txt
+++ b/requirements/cpu-build.txt
@@ -3,8 +3,6 @@ ninja
packaging>=24.2
setuptools==77.0.3 # this version can reuse CMake build dir
setuptools-scm>=8
-torch==2.10.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
-torch==2.10.0; platform_machine == "aarch64" or platform_system == "Darwin" or platform_machine == "ppc64le"
wheel
jinja2>=3.1.6
regex
diff --git a/requirements/cpu.txt b/requirements/cpu.txt
index 378f61ba868620bd102ae0cc9786b05faaa7d4ab..d82b6cd2b3db2e7c45ccbc2d85f7f8afc03d3096 100644
--- a/requirements/cpu.txt
+++ b/requirements/cpu.txt
@@ -6,16 +6,9 @@ setuptools==77.0.3 # this version can reuse CMake build dir
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
# Dependencies for CPUs
-torch==2.10.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
-torch==2.10.0; platform_machine == "aarch64" or platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "riscv64"
-# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
-torchaudio; platform_machine != "s390x" and platform_machine != "riscv64"
-# required for the image processor of phi3v, this must be updated alongside torch
-torchvision; platform_machine != "s390x" and platform_machine != "riscv64"
-# Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64"
# Use this to gather CPU info and optimize based on ARM Neoverse cores
diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index 44b7c38093d2692ba59eeaf99698d56ae35df6e3..57204f5d55f26525ec5bdfa2cef5a3de938d38c9 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -4,10 +4,6 @@
numba == 0.61.2 # Required for N-gram speculative decoding
# Dependencies for NVIDIA GPUs
-torch==2.10.0
-torchaudio==2.10.0
-# These must be updated alongside torch
-torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.6
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt
index 6f96c7d55742b054fcd480d3b162d53f38850bb1..e5f5a22f41bffac46aa57239210632fa5dc6d66a 100644
--- a/requirements/rocm-build.txt
+++ b/requirements/rocm-build.txt
@@ -1,10 +1,6 @@
# Common dependencies
-r common.txt
---extra-index-url https://download.pytorch.org/whl/rocm7.1
-torch==2.10.0
-torchvision==0.25.0
-torchaudio==2.10.0
triton==3.6.0
cmake>=3.26.1,<4
packaging>=24.2
diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt
index 9014ab1eaf899dce38edc4e9bbfa63b7cb46134b..730971608b43d28b0ef49e56c884811208a4e225 100644
--- a/requirements/rocm-test.txt
+++ b/requirements/rocm-test.txt
@@ -69,8 +69,6 @@ multiprocess==0.70.16
# Required for v1/metrics/test_engine_logger_apis.py
ray[cgraph,default]>=2.48.0
-torchgeo==0.7.0
- # via terratorch
# MTEB Benchmark Test
mteb[bm25s]>=2, <3
@@ -85,7 +83,6 @@ fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.
# Required for suffix decoding test
arctic-inference == 0.1.1
# Required for Nemotron test
-open-clip-torch==2.32.0
# Required for isaac Multi-Modal generation test
perceptron==0.1.4
# Required for the multi-modal models test
@@ -99,9 +96,7 @@ huggingface-hub==0.36.2
# Pin Mistral Common
mistral-common[image,audio]==1.10.0
# Required for Prithvi tests
-terratorch==1.2.2
# Required for Prithvi tests
-segmentation-models-pytorch==0.5.0
# Required for Prithvi tests
imagehash==4.3.2
# Required for bitsandbytes quantization test
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 7bbd70550ebf0ade8eaee84f3c7f0faf7a4a5198..acdef5c348e93fccd3f967dd5c2e6de5c8a166d8 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -23,6 +23,4 @@ timm>=1.0.17
amd-quark>=0.8.99
# Other necessary dependencies
-torch == 2.10.0
-torchvision == 0.25.0
flash_attn == 2.8.3
\ No newline at end of file
diff --git a/requirements/test.in b/requirements/test.in
index 8bd00514435b4f2e8dba260e23216be9d808a4b0..72cb4eb85b8153936d8c4b08fd04eccb1e4d2cb3 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -16,7 +16,6 @@ blobfile # required for kimi-vl test
einops # required for MPT, qwen-vl
httpx
librosa # required for audio tests
-vector_quantize_pytorch # required for minicpmo_26 test
vocos # required for minicpmo_26 test
peft>=0.15.0 # required for phi-4-mm test
pqdm
@@ -26,14 +25,10 @@ soundfile # required for audio tests
jiwer # required for audio tests
tblib # for pickling test exceptions
timm >=1.0.17 # required for internvl and gemma3n-mm test
-torch==2.10.0
-torchaudio==2.10.0
-torchvision==0.25.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.1 # required for voxtral test
num2words # required for smolvlm test
-open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]>=0.4.11 # required for model evaluation test
@@ -61,17 +56,13 @@ fastsafetensors>=0.2.2 # 0.2.2 contains important fixes for multi-GPU mem usage
instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0
-terratorch >= 1.2.2 # Required for Prithvi tests
imagehash # Required for Prithvi tests
-segmentation-models-pytorch > 0.4.0 # Required for Prithvi tests
gpt-oss >= 0.0.7; python_version > '3.11'
perceptron # required for isaac test
kaldi-native-fbank >= 1.18.7 # required for fireredasr2 test
-# Newer versions of datasets require torchcoded, that makes the tests fail in CI because of a missing library.
-# Older versions are in conflict with teerratorch requirements.
datasets>=3.3.0,<=3.6.0
openpyxl # required for perf comparison excel report
diff --git a/requirements/test.txt b/requirements/test.txt
index e2f9040beecc099958a1a07d9d4e31f085fec010..7c94a62f9acfc8f6aaeeb72f9d355dbe56a2e491 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,5 +1,4 @@
# This file was autogenerated by uv via the following command:
-# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12
absl-py==2.1.0
# via
# rouge-score
@@ -25,11 +24,9 @@ aiohttp-cors==0.8.1
aiosignal==1.4.0
# via aiohttp
albucore==0.0.16
- # via terratorch
albumentations==1.4.6
# via
# -r requirements/test.in
- # terratorch
alembic==1.16.4
# via optuna
annotated-doc==0.0.4
@@ -165,7 +162,6 @@ cryptography==46.0.5
# msal
# pyjwt
cuda-bindings==12.9.4
- # via torch
cuda-pathfinder==1.3.3
# via cuda-bindings
cupy-cuda12x==13.6.0
@@ -189,7 +185,6 @@ decorator==5.1.1
decord==0.6.0
# via -r requirements/test.in
diffusers==0.36.0
- # via terratorch
dill==0.3.8
# via
# datasets
@@ -210,12 +205,8 @@ einops==0.8.1
# via
# -r requirements/test.in
# encodec
- # terratorch
- # torchgeo
- # vector-quantize-pytorch
# vocos
einx==0.3.0
- # via vector-quantize-pytorch
email-validator==2.2.0
# via pydantic
encodec==0.1.1
@@ -239,11 +230,9 @@ filelock==3.16.1
# diffusers
# huggingface-hub
# ray
- # torch
# transformers
# virtualenv
fiona==1.10.1
- # via torchgeo
fonttools==4.55.0
# via matplotlib
fqdn==1.5.1
@@ -261,17 +250,13 @@ fsspec==2024.12.0
# fastparquet
# huggingface-hub
# lightning
- # pytorch-lightning
# tacoreader
- # torch
ftfy==6.3.1
- # via open-clip-torch
genai-perf==0.0.16
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
geopandas==1.0.1
- # via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
@@ -320,7 +305,6 @@ h11==0.14.0
h2==4.3.0
# via httpx
h5py==3.13.0
- # via terratorch
harfile==0.3.0
# via schemathesis
hf-xet==1.1.7
@@ -345,11 +329,8 @@ huggingface-hub==0.36.2
# datasets
# diffusers
# evaluate
- # open-clip-torch
# peft
- # segmentation-models-pytorch
# sentence-transformers
- # terratorch
# timm
# tokenizers
# transformers
@@ -406,7 +387,6 @@ jinja2==3.1.6
# datamodel-code-generator
# genai-perf
# lm-eval
- # torch
jiwer==3.0.5
# via -r requirements/test.in
jmespath==1.0.1
@@ -421,7 +401,6 @@ joblib==1.4.2
jsonargparse==4.46.0
# via
# lightning
- # terratorch
jsonlines==4.0.0
# via lm-eval
jsonnet==0.21.0
@@ -445,7 +424,6 @@ kaleido==0.2.1
kiwisolver==1.4.7
# via matplotlib
kornia==0.8.1
- # via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4
@@ -458,19 +436,13 @@ librosa==0.10.2.post1
# via -r requirements/test.in
lightly==1.5.22
# via
- # terratorch
- # torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.6.1
# via
- # terratorch
- # torchgeo
lightning-utilities==0.14.3
# via
# lightning
- # pytorch-lightning
- # torchmetrics
llvmlite==0.44.0
# via numba
lm-eval==0.4.11
@@ -496,7 +468,6 @@ matplotlib==3.9.2
# -r requirements/test.in
# lightning
# pycocotools
- # torchgeo
mbstrdecoder==1.1.3
# via
# dataproperty
@@ -535,7 +506,6 @@ mypy-extensions==1.0.0
networkx==3.2.1
# via
# scikit-image
- # torch
nltk==3.9.1
# via rouge-score
num2words==0.5.14
@@ -591,18 +561,13 @@ numpy==2.2.6
# scikit-image
# scikit-learn
# scipy
- # segmentation-models-pytorch
# shapely
# soxr
# statsmodels
# tensorboard
# tensorboardx
# tensorizer
- # terratorch
# tifffile
- # torchgeo
- # torchmetrics
- # torchvision
# transformers
# tritonclient
# vocos
@@ -611,46 +576,30 @@ nvidia-cublas-cu12==12.9.1.4
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
- # torch
nvidia-cuda-cupti-cu12==12.9.79
- # via torch
nvidia-cuda-nvrtc-cu12==12.9.86
- # via torch
nvidia-cuda-runtime-cu12==12.9.79
- # via torch
nvidia-cudnn-cu12==9.10.2.21
- # via torch
nvidia-cufft-cu12==11.4.1.4
- # via torch
nvidia-cufile-cu12==1.14.1.1
- # via torch
nvidia-curand-cu12==10.3.10.19
- # via torch
nvidia-cusolver-cu12==11.7.5.82
- # via torch
nvidia-cusparse-cu12==12.5.10.65
# via
# nvidia-cusolver-cu12
- # torch
nvidia-cusparselt-cu12==0.7.1
- # via torch
nvidia-nccl-cu12==2.27.5
- # via torch
nvidia-nvjitlink-cu12==12.9.86
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
- # torch
nvidia-nvshmem-cu12==3.4.5
- # via torch
nvidia-nvtx-cu12==12.9.79
- # via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
-open-clip-torch==2.32.0
# via -r requirements/test.in
openai-harmony==0.0.4
# via gpt-oss
@@ -709,14 +658,12 @@ packaging==24.2
# pyogrio
# pytest
# pytest-rerunfailures
- # pytorch-lightning
# ray
# rioxarray
# scikit-image
# statsmodels
# tensorboard
# tensorboardx
- # torchmetrics
# transformers
# typepy
# wandb
@@ -730,7 +677,6 @@ pandas==2.2.3
# geopandas
# statsmodels
# tacoreader
- # torchgeo
# xarray
pathspec==0.12.1
# via black
@@ -755,10 +701,7 @@ pillow==10.4.0
# mistral-common
# perceptron
# scikit-image
- # segmentation-models-pytorch
# tensorboard
- # torchgeo
- # torchvision
platformdirs==4.3.6
# via
# black
@@ -817,7 +760,6 @@ pyarrow==23.0.0
# datasets
# genai-perf
# tacoreader
- # terratorch
pyasn1==0.6.1
# via
# pyasn1-modules
@@ -825,7 +767,6 @@ pyasn1==0.6.1
pyasn1-modules==0.4.2
# via google-auth
pycocotools==2.0.8
- # via terratorch
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
@@ -864,7 +805,6 @@ pyproj==3.7.1
# via
# geopandas
# rioxarray
- # torchgeo
pyrate-limiter==3.7.0
# via schemathesis
pystemmer==3.0.0
@@ -902,7 +842,6 @@ pytest-subtests==0.14.1
pytest-timeout==2.3.1
# via -r requirements/test.in
python-box==7.3.2
- # via terratorch
python-dateutil==2.9.0.post0
# via
# arrow
@@ -913,7 +852,6 @@ python-dateutil==2.9.0.post0
# typepy
python-rapidjson==1.20
# via tritonclient
-pytorch-lightning==2.5.2
# via
# lightly
# lightning
@@ -938,7 +876,6 @@ pyyaml==6.0.2
# omegaconf
# optuna
# peft
- # pytorch-lightning
# ray
# responses
# schemathesis
@@ -951,8 +888,6 @@ rapidfuzz==3.12.1
rasterio==1.4.3
# via
# rioxarray
- # terratorch
- # torchgeo
ray==2.48.0
# via -r requirements/test.in
redis==5.2.0
@@ -965,7 +900,6 @@ regex==2024.9.11
# via
# diffusers
# nltk
- # open-clip-torch
# sacrebleu
# tiktoken
# transformers
@@ -1007,10 +941,8 @@ rich==13.9.4
# lightning
# mteb
# perceptron
- # terratorch
# typer
rioxarray==0.19.0
- # via terratorch
rouge-score==0.1.2
# via lm-eval
rpds-py==0.20.1
@@ -1020,7 +952,6 @@ rpds-py==0.20.1
rsa==4.9.1
# via google-auth
rtree==1.4.0
- # via torchgeo
runai-model-streamer==0.15.7
# via -r requirements/test.in
runai-model-streamer-azure==0.15.7
@@ -1037,9 +968,7 @@ safetensors==0.4.5
# via
# accelerate
# diffusers
- # open-clip-torch
# peft
- # segmentation-models-pytorch
# timm
# transformers
schemathesis==3.39.15
@@ -1047,7 +976,6 @@ schemathesis==3.39.15
scikit-image==0.25.2
# via
# albumentations
- # terratorch
scikit-learn==1.5.2
# via
# albumentations
@@ -1055,7 +983,6 @@ scikit-learn==1.5.2
# lm-eval
# mteb
# sentence-transformers
- # terratorch
scipy==1.13.1
# via
# albumentations
@@ -1068,11 +995,8 @@ scipy==1.13.1
# sentence-transformers
# statsmodels
# vocos
-segmentation-models-pytorch==0.5.0
# via
# -r requirements/test.in
- # terratorch
- # torchgeo
sentence-transformers==5.2.0
# via
# -r requirements/test.in
@@ -1084,11 +1008,9 @@ setuptools==77.0.3
# lightning-utilities
# pytablewriter
# tensorboard
- # torch
shapely==2.1.1
# via
# geopandas
- # torchgeo
shellingham==1.5.4
# via
# perceptron
@@ -1141,13 +1063,11 @@ structlog==25.4.0
sympy==1.13.3
# via
# einx
- # torch
tabledata==1.3.3
# via pytablewriter
tabulate==0.9.0
# via sacrebleu
tacoreader==0.5.6
- # via terratorch
tblib==3.1.0
# via -r requirements/test.in
tcolorpy==0.1.6
@@ -1158,7 +1078,6 @@ tenacity==9.1.2
# lm-eval
# plotly
tensorboard==2.20.0
- # via terratorch
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
@@ -1168,15 +1087,12 @@ tensorizer==2.10.1
termcolor==3.1.0
# via
# gpt-oss
- # terratorch
-terratorch==1.2.2
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
- # terratorch
tiktoken==0.12.0
# via
# gpt-oss
@@ -1185,10 +1101,6 @@ tiktoken==0.12.0
timm==1.0.17
# via
# -r requirements/test.in
- # open-clip-torch
- # segmentation-models-pytorch
- # terratorch
- # torchgeo
tokenizers==0.22.0
# via
# -r requirements/test.in
@@ -1197,7 +1109,6 @@ tomli==2.2.1
# via schemathesis
tomli-w==1.2.0
# via schemathesis
-torch==2.10.0+cu129
# via
# -r requirements/test.in
# accelerate
@@ -1208,43 +1119,22 @@ torch==2.10.0+cu129
# lightly
# lightning
# mteb
- # open-clip-torch
# peft
- # pytorch-lightning
# runai-model-streamer
- # segmentation-models-pytorch
# sentence-transformers
# tensorizer
- # terratorch
# timm
- # torchaudio
- # torchgeo
- # torchmetrics
- # torchvision
- # vector-quantize-pytorch
# vocos
-torchaudio==2.10.0+cu129
# via
# -r requirements/test.in
# encodec
# vocos
-torchgeo==0.7.0
- # via terratorch
-torchmetrics==1.7.4
# via
# lightning
- # pytorch-lightning
- # terratorch
- # torchgeo
-torchvision==0.25.0+cu129
# via
# -r requirements/test.in
# lightly
- # open-clip-torch
- # segmentation-models-pytorch
- # terratorch
# timm
- # torchgeo
tqdm==4.67.3
# via
# datasets
@@ -1255,15 +1145,11 @@ tqdm==4.67.3
# lm-eval
# mteb
# nltk
- # open-clip-torch
# optuna
# peft
# pqdm
- # pytorch-lightning
- # segmentation-models-pytorch
# sentence-transformers
# tacoreader
- # terratorch
# transformers
transformers==4.57.5
# via
@@ -1275,7 +1161,6 @@ transformers==4.57.5
transformers-stream-generator==0.0.5
# via -r requirements/test.in
triton==3.6.0
- # via torch
tritonclient==2.64.0
# via -r requirements/test.in
typepy==1.3.2
@@ -1316,12 +1201,9 @@ typing-extensions==4.15.0
# pydantic
# pydantic-core
# pydantic-extra-types
- # pytorch-lightning
# sentence-transformers
# sqlalchemy
# starlette
- # torch
- # torchgeo
# typer
# typeshed-client
# typing-inspection
@@ -1344,14 +1226,12 @@ urllib3==2.2.3
# tritonclient
uvicorn==0.35.0
# via gpt-oss
-vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
wandb==0.24.2
- # via terratorch
wcwidth==0.2.13
# via ftfy
webcolors==24.11.1
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 3271f9f392758ce6a1e51665c4574a55f2e2dc46..95c1b8d1b3952b58fc6e162b27893b6703d6320a 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -10,9 +10,5 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.61.2 # Required for N-gram speculative decoding
---extra-index-url=https://download.pytorch.org/whl/xpu
-torch==2.10.0+xpu
-torchaudio
-torchvision
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.3/vllm_xpu_kernels-0.1.3-cp38-abi3-linux_x86_64.whl
diff --git a/tests/conftest.py b/tests/conftest.py
index 719bfa5ed1f044cc7d2fb85c94382e72f64eeeb3..40350375d8b26c3fe9ce3c22281e3ca8ac804d3d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
+# Keep v1 CUDA-graph tests valid: fork default for ``LLM(kvprune_compression=None)``
+# is controlled by ``VLLM_KVPRUNE_COMPRESSION_DEFAULT`` (see ``env_override``).
+os.environ.setdefault("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "0")
+os.environ.setdefault("VLLM_KVPRUNE_RELEASE_V1_KV", "0")
import contextlib
import pathlib
from copy import deepcopy
diff --git a/tests/kvprune/__init__.py b/tests/kvprune/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..208f01a7cb5ee04c88d276fec2082cd4e830884b
--- /dev/null
+++ b/tests/kvprune/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/tests/kvprune/evaluate/eval_longbench.py b/tests/kvprune/evaluate/eval_longbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab39a62a24f07a0e0ad88b0a0304b06afe317e3b
--- /dev/null
+++ b/tests/kvprune/evaluate/eval_longbench.py
@@ -0,0 +1,153 @@
+# SPDX-License-Identifier: Apache-2.0
+"""LongBench evaluation via vLLM ``LLM.generate`` + kvprune compression (same folder layout as RULER)."""
+from __future__ import annotations
+
+import json
+import logging
+import os
+import sys
+from pathlib import Path
+
+from datasets import concatenate_datasets, load_dataset
+
+_SCRIPT_DIR = Path(__file__).resolve().parent
+if str(_SCRIPT_DIR) not in sys.path:
+ sys.path.insert(0, str(_SCRIPT_DIR))
+from longbench_metrics import dataset2metric # noqa: E402
+
+from vllm import LLM, SamplingParams # noqa: E402
+from vllm.kvprune.integration.compression_params import CompressionParams # noqa: E402
+
+
+def _hf_tokenizer(llm: LLM):
+ tok = llm.get_tokenizer()
+ return getattr(tok, "tokenizer", tok)
+
+
+def messages_to_prompts(
+ llm: LLM,
+ messages: list[list[dict]],
+ *,
+ add_generation_prompt: bool,
+ enable_thinking: bool,
+) -> list[str]:
+ inner = _hf_tokenizer(llm)
+ out: list[str] = []
+ kw: dict = {}
+ if enable_thinking:
+ kw["enable_thinking"] = True
+ for conv in messages:
+ text = inner.apply_chat_template(
+ conv,
+ tokenize=False,
+ add_generation_prompt=add_generation_prompt,
+ **kw,
+ )
+ out.append(text)
+ return out
+
+
+if __name__ == "__main__":
+ logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s"
+ )
+ cfg_dir = _SCRIPT_DIR / "longbench_config"
+    prompts = json.loads((cfg_dir / "dataset2prompt.json").read_text(encoding="utf-8"))  # read_text() closes the file
+    max_gen_lens = json.loads((cfg_dir / "dataset2maxlen.json").read_text(encoding="utf-8"))
+
+ datasets = [
+ "narrativeqa",
+ "qasper",
+ "multifieldqa_en",
+ "hotpotqa",
+ "2wikimqa",
+ "musique",
+ "gov_report",
+ "qmsum",
+ "multi_news",
+ "trec",
+ "triviaqa",
+ "samsum",
+ "passage_retrieval_en",
+ "passage_count",
+ "lcc",
+ "repobench-p",
+ ]
+ dataset = concatenate_datasets(
+ [
+ load_dataset("THUDM/LongBench", n, split="test", trust_remote_code=True)
+ for n in datasets
+ ]
+ ).shuffle(seed=42)
+
+ dset_names = [
+ item["dataset"] if item["dataset"][-2:] != "_e" else item["dataset"][:-2]
+ for item in dataset
+ ]
+ gen_lengths = [max_gen_lens[dset_name] for dset_name in dset_names]
+
+ messages = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompts[dset_name].format(**item)},
+ ]
+ for dset_name, item in zip(dset_names, dataset)
+ ]
+
+ model = os.environ.get("KVPRUNE_EVAL_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
+ tp = int(os.environ.get("KVPRUNE_EVAL_TP", "2"))
+ seq_ratio = float(os.environ.get("KVPRUNE_SEQ_COMPRESSION_RATIO", "0.25"))
+
+ llm = LLM(
+ model=model,
+ max_num_seqs=64,
+ gpu_memory_utilization=0.95,
+ tensor_parallel_size=tp,
+ max_model_len=128000,
+ kvprune_compression=True,
+ )
+ text_prompts = messages_to_prompts(
+ llm,
+ messages,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ sampling_params = [
+ SamplingParams(max_tokens=g, temperature=0.00001) for g in gen_lengths
+ ]
+ n = len(text_prompts)
+ compression = [
+ CompressionParams(
+ compression_ratio=seq_ratio,
+ compression_method="compactor",
+ protected_first_tokens=8,
+ protected_last_tokens=64,
+ )
+ ] * n
+
+ outputs = llm.generate(text_prompts, sampling_params, compression=compression)
+ responses = [o.outputs[0].text for o in outputs]
+
+ results: dict = {}
+ for dset_name, prediction, item in zip(dset_names, responses, dataset):
+ results.setdefault(dset_name, [])
+ pred = prediction
+ if dset_name in ["trec", "triviaqa", "samsum", "lsht"]:
+ pred = pred.lstrip("\n").split("\n")[0]
+ score = 0.0
+ for ground_truth in item["answers"]:
+ score = max(
+ score,
+ dataset2metric[dset_name](
+ pred, ground_truth, all_classes=item["all_classes"]
+ ),
+ )
+ results[dset_name].append(score)
+
+ all_sum, all_count = 0, 0
+ for task, scores in results.items():
+ avg = sum(scores) / len(scores)
+ print(task, f"{avg:.2f}")
+ all_sum += sum(scores)
+ all_count += len(scores)
+ print(f"ALL: {all_sum / all_count:.2f}")
diff --git a/tests/kvprune/evaluate/eval_ruler.py b/tests/kvprune/evaluate/eval_ruler.py
new file mode 100644
index 0000000000000000000000000000000000000000..322c5bae497880b5507c5b5af4befbbb4afab9d9
--- /dev/null
+++ b/tests/kvprune/evaluate/eval_ruler.py
@@ -0,0 +1,385 @@
+# SPDX-License-Identifier: Apache-2.0
+"""RULER evaluation using vLLM ``LLM.generate`` + integrated kvprune (compactor) compression.
+
+Run from the **repository root** (or any cwd if ``vllm`` is installed), e.g.::
+
+ python tests/kvprune/evaluate/eval_ruler.py \\
+ --dataset-parquet tests/kvprune/evaluate/test-00000-of-00001.parquet \\
+ --dataset-split train \\
+ --model Qwen/Qwen3-8B \\
+ --compression-method compactor \\
+ --seq-compression-ratio 0.5
+
+Set ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` (``fa_triton`` | ``pdtriton`` | ``pdfa``) **before**
+starting Python if you need a specific attention schedule (also supported via ``--attention-schedule``).
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import os
+import sys
+from datetime import datetime
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+
+# Local metrics (same directory as this script)
+_SCRIPT_DIR = Path(__file__).resolve().parent
+if str(_SCRIPT_DIR) not in sys.path:
+ sys.path.insert(0, str(_SCRIPT_DIR))
+from ruler_metrics import score_function # noqa: E402
+
+from vllm import LLM, SamplingParams # noqa: E402
+from vllm.kvprune.integration.compression_params import CompressionParams # noqa: E402
+
+
+def _hf_tokenizer(llm: LLM):
+ tok = llm.get_tokenizer()
+ return getattr(tok, "tokenizer", tok)
+
+
+def messages_to_prompts(
+    llm: LLM,
+    messages: list[list[dict]],
+    *,
+    add_generation_prompt: bool,
+    continue_final_message: bool,
+    enable_thinking: bool,
+) -> list[str]:
+    """Render each conversation to one untokenized prompt string via the tokenizer's chat template."""
+    inner = _hf_tokenizer(llm)
+    kw: dict = {}
+    if enable_thinking:
+        kw["enable_thinking"] = True
+    out: list[str] = []
+    for conv in messages:
+        text = inner.apply_chat_template(
+            conv,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=continue_final_message,
+            **kw,
+        )
+        out.append(text)
+    return out
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="RULER evaluation with vLLM kvprune (integrated compactor) compression."
+ )
+ parser.add_argument(
+ "--log-level",
+ type=str,
+ default="INFO",
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+ help="Logging level.",
+ )
+    parser.add_argument(
+        "--dataset-length",
+        type=str,
+        default="4096",
+        help="Dataset configuration name for simonjegou/ruler (metadata / output filenames only when using --dataset-parquet).",
+    )
+ parser.add_argument(
+ "--dataset-parquet",
+ type=str,
+ default=None,
+ help=(
+ "Local Parquet path (single file or glob). If set, loads via datasets parquet "
+ "instead of simonjegou/ruler."
+ ),
+ )
+ parser.add_argument(
+ "--dataset-split",
+ type=str,
+ default="test",
+ help="Split name (local parquet often uses 'train').",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Shuffle seed.")
+ parser.add_argument(
+ "--fraction",
+ type=float,
+ default=1.0,
+ help="Fraction of dataset to use in (0, 1].",
+ )
+ parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B", help="HF model id or path.")
+ parser.add_argument("--max-num-seqs", type=int, default=32, help="vLLM max_num_seqs.")
+ parser.add_argument(
+ "--gpu-memory-utilization", type=float, default=0.95, help="GPU memory fraction."
+ )
+ parser.add_argument(
+ "--tensor-parallel-size",
+ type=int,
+ default=1,
+ help=(
+ "vLLM tensor parallel size. Default 1 uses the in-process shared-weight "
+ "compactor on one GPU. For multi-GPU (e.g. 4), set this to the number of "
+ "GPUs; compression then uses the TP collective_rpc path on workers."
+ ),
+ )
+ parser.add_argument("--max-model-len", type=int, default=40960, help="max_model_len.")
+ parser.add_argument(
+ "--enforce-eager",
+ action="store_true",
+ help="vLLM enforce_eager (on by default when --kvprune-compression).",
+ )
+ parser.add_argument(
+ "--kvprune-compression",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Enable kvprune_compression on LLM (skip v1 CUDA graphs, minimal v1 KV blocks). "
+ "Default: True.",
+ )
+ parser.add_argument(
+ "--attention-schedule",
+ type=str,
+ default=None,
+ help=(
+ "If set, assigns VLLM_KVPRUNE_ATTENTION_SCHEDULE before engine init, e.g. "
+ "fa_triton, pdtriton, pdfa (see vllm/kvprune/integration/config_adapter.py)."
+ ),
+ )
+ parser.add_argument("--max-tokens", type=int, default=256, help="max_tokens (generation).")
+ parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
+ parser.add_argument(
+ "--compression-method",
+ type=str,
+ default="compactor",
+ choices=["compactor", "criticaladakv", "snapkv"],
+ help="kvprune compression method alias.",
+ )
+ parser.add_argument(
+ "--seq-compression-ratio",
+ type=float,
+ default=0.5,
+ help="Per-sequence compression ratio in (0, 1].",
+ )
+ parser.add_argument(
+ "--protected-first-tokens",
+ type=int,
+ default=8,
+ help="Protected prefix token count for pruning.",
+ )
+ parser.add_argument(
+ "--extra-protected-last-tokens",
+ type=int,
+ default=16,
+ help="Added to tokenized(answer_prefix+question) length for protected_last_tokens.",
+ )
+ parser.add_argument(
+ "--tokenizer-add-generation-prompt",
+ action="store_true",
+ help="apply_chat_template add_generation_prompt=True.",
+ )
+ parser.add_argument(
+ "--tokenizer-enable-thinking",
+ action="store_true",
+ help="apply_chat_template enable_thinking=True (Qwen3).",
+ )
+ parser.add_argument(
+ "--no-tokenizer-continue-final-message",
+ dest="tokenizer_continue_final_message",
+ action="store_false",
+ help="continue_final_message=False (default True).",
+ )
+ parser.set_defaults(tokenizer_continue_final_message=True)
+ parser.add_argument(
+ "--results-dir",
+ type=str,
+ default="results",
+ help="Directory for JSON summary and JSONL details.",
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = parse_args()
+
+ if args.attention_schedule:
+ os.environ["VLLM_KVPRUNE_ATTENTION_SCHEDULE"] = args.attention_schedule.strip()
+
+ torch.manual_seed(args.seed)
+ logging.basicConfig(
+ level=getattr(logging, args.log_level.upper(), logging.INFO),
+ format="%(asctime)s %(levelname)s: %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ if args.dataset_parquet:
+ logger.info(
+ "Loading local parquet from %s (split=%s)",
+ args.dataset_parquet,
+ args.dataset_split,
+ )
+ dataset = load_dataset(
+ "parquet",
+ data_files=args.dataset_parquet,
+ split=args.dataset_split,
+ )
+ else:
+ logger.info(
+ "Loading simonjegou/ruler length=%s split=%s",
+ args.dataset_length,
+ args.dataset_split,
+ )
+ dataset = load_dataset(
+ "simonjegou/ruler",
+ args.dataset_length,
+ split=args.dataset_split,
+ )
+
+ if args.seed is not None and args.seed >= 0:
+ dataset = dataset.shuffle(seed=args.seed)
+ if not (0 < args.fraction <= 1.0):
+ raise ValueError("--fraction must be in (0, 1].")
+ if args.fraction < 1.0:
+ n_examples = max(1, int(len(dataset) * args.fraction))
+ dataset = dataset.select(range(n_examples))
+ logger.info("Examples: %d", len(dataset))
+
+ messages = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": example["context"] + " " + example["question"]},
+ {"role": "assistant", "content": example["answer_prefix"]},
+ ]
+ for example in dataset
+ ]
+
+ llm = LLM(
+ model=args.model,
+ tensor_parallel_size=args.tensor_parallel_size,
+ max_model_len=args.max_model_len,
+ max_num_seqs=args.max_num_seqs,
+ gpu_memory_utilization=args.gpu_memory_utilization,
+ enforce_eager=args.enforce_eager or args.kvprune_compression,
+ kvprune_compression=args.kvprune_compression,
+ )
+
+ tok = _hf_tokenizer(llm)
+ end_protected_lengths = [
+ args.extra_protected_last_tokens
+ + len(
+ tok.encode(
+ example["answer_prefix"] + example["question"],
+ add_special_tokens=False,
+ )
+ )
+ for example in dataset
+ ]
+
+ prompts = messages_to_prompts(
+ llm,
+ messages,
+ add_generation_prompt=args.tokenizer_add_generation_prompt,
+ continue_final_message=args.tokenizer_continue_final_message,
+ enable_thinking=args.tokenizer_enable_thinking,
+ )
+
+ sampling_params = SamplingParams(
+ max_tokens=args.max_tokens,
+ temperature=args.temperature,
+ )
+
+ compression_list = [
+ CompressionParams(
+ compression_ratio=args.seq_compression_ratio,
+ compression_method=args.compression_method,
+ protected_first_tokens=args.protected_first_tokens,
+ protected_last_tokens=end_protected_lengths[i],
+ )
+ for i in range(len(prompts))
+ ]
+
+ logger.info("Running LLM.generate with kvprune compression on %d prompts.", len(prompts))
+ outputs = llm.generate(
+ prompts,
+ sampling_params,
+ compression=compression_list,
+ )
+ responses = [o.outputs[0].text.strip() for o in outputs]
+
+ logger.info("Scoring responses.")
+ results: dict = {}
+ per_example: list = []
+ all_sum, all_count = 0.0, 0
+
+ for idx, (example, response) in enumerate(zip(dataset, responses)):
+ task = example["task"]
+ answer = example["answer"]
+ score = score_function(
+ generated=response,
+ ground_truth=answer,
+ task_category=task,
+ )
+ results.setdefault(task, []).append(score)
+ all_sum += score
+ all_count += 1
+ per_example.append(
+ {
+ "index": idx,
+ "task": task,
+ "context": example["context"],
+ "question": example["question"],
+ "answer_prefix": example["answer_prefix"],
+ "ground_truth": answer,
+ "generated": response,
+ "score": score,
+ "compression_params": {
+ "seq_compression_ratio": args.seq_compression_ratio,
+ "compression_method": args.compression_method,
+ "protected_first_tokens": args.protected_first_tokens,
+ "protected_last_tokens": end_protected_lengths[idx],
+ },
+ }
+ )
+
+ per_task_summary = {}
+ for task, scores in results.items():
+ avg = sum(scores) / len(scores)
+ print(task, f"{avg:.3f}")
+ per_task_summary[task] = {
+ "avg_score": avg,
+ "num_examples": len(scores),
+ "sum_scores": sum(scores),
+ }
+
+ overall_avg = all_sum / all_count if all_count > 0 else 0.0
+ print(f"ALL: {overall_avg:.3f}")
+
+ os.makedirs(args.results_dir, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ safe_model_name = args.model.replace("/", "_")
+ base_name = f"ruler_{args.dataset_length}_{safe_model_name}_{timestamp}"
+ summary_path = os.path.join(args.results_dir, base_name + "_summary.json")
+ details_path = os.path.join(args.results_dir, base_name + "_details.jsonl")
+
+ ds_name = args.dataset_parquet or "simonjegou/ruler"
+ with open(summary_path, "w", encoding="utf-8") as f:
+ json.dump(
+ {
+ "timestamp": timestamp,
+ "model": args.model,
+ "dataset": ds_name,
+ "dataset_length": args.dataset_length,
+ "num_examples": len(dataset),
+ "overall_avg_score": overall_avg,
+ "per_task": per_task_summary,
+ "arguments": vars(args),
+ },
+ f,
+ ensure_ascii=False,
+ indent=2,
+ )
+ with open(details_path, "w", encoding="utf-8") as f:
+ for row in per_example:
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+ logger.info("Wrote %s and %s", summary_path, details_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/kvprune/evaluate/longbench_config/dataset2maxlen.json b/tests/kvprune/evaluate/longbench_config/dataset2maxlen.json
new file mode 100644
index 0000000000000000000000000000000000000000..79d0d9990e5799c845ebcf839c1ee1a4ff14873e
--- /dev/null
+++ b/tests/kvprune/evaluate/longbench_config/dataset2maxlen.json
@@ -0,0 +1,23 @@
+{
+ "narrativeqa": 128,
+ "qasper": 128,
+ "multifieldqa_en": 64,
+ "multifieldqa_zh": 64,
+ "hotpotqa": 32,
+ "2wikimqa": 32,
+ "musique": 32,
+ "dureader": 128,
+ "gov_report": 512,
+ "qmsum": 512,
+ "multi_news": 512,
+ "vcsum": 512,
+ "trec": 64,
+ "triviaqa": 32,
+ "samsum": 128,
+ "lsht": 64,
+ "passage_count": 32,
+ "passage_retrieval_en": 32,
+ "passage_retrieval_zh": 32,
+ "lcc": 64,
+ "repobench-p": 64
+}
\ No newline at end of file
diff --git a/tests/kvprune/evaluate/longbench_config/dataset2prompt.json b/tests/kvprune/evaluate/longbench_config/dataset2prompt.json
new file mode 100644
index 0000000000000000000000000000000000000000..faf6cc0f847baadc42c5178c6f1f8e93bb0730b7
--- /dev/null
+++ b/tests/kvprune/evaluate/longbench_config/dataset2prompt.json
@@ -0,0 +1,23 @@
+{
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: \n\n\n{context}\n \n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: \n\n\n{context}\n \n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n\n{context}\n \n\n Now, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n\n{context}\n \n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:\n\n\n{context}\n \n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n\n\n{context}\n \n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n\n\n{context}\n \n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n\n\n{context}\n \n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n\n\n{context}\n \n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n\n{context}\n \n\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n\n{context}\n \n\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n\n{context}\n \n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n\n{context}\n \n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
+ "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n\n{context}\n \n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
+}
\ No newline at end of file
diff --git a/tests/kvprune/evaluate/longbench_config/dataset2prompt_taskagnostic.json b/tests/kvprune/evaluate/longbench_config/dataset2prompt_taskagnostic.json
new file mode 100644
index 0000000000000000000000000000000000000000..faf6cc0f847baadc42c5178c6f1f8e93bb0730b7
--- /dev/null
+++ b/tests/kvprune/evaluate/longbench_config/dataset2prompt_taskagnostic.json
@@ -0,0 +1,23 @@
+{
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: \n\n\n{context}\n \n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: \n\n\n{context}\n \n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n\n{context}\n \n\n Now, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n\n{context}\n \n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:\n\n\n{context}\n \n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n\n\n{context}\n \n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n\n\n{context}\n \n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n\n\n{context}\n \n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n\n\n{context}\n \n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n\n{context}\n \n\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n\n{context}\n \n\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n\n{context}\n \n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n\n{context}\n \n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
+ "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n\n{context}\n \n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
+}
\ No newline at end of file
diff --git a/tests/kvprune/evaluate/longbench_metrics.py b/tests/kvprune/evaluate/longbench_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfdac7a67267b6390aa2be1abba1b9de6002fa0d
--- /dev/null
+++ b/tests/kvprune/evaluate/longbench_metrics.py
@@ -0,0 +1,176 @@
+import re
+import string
+from collections import Counter
+
+import jieba
+from fuzzywuzzy import fuzz
+from rouge import Rouge
+
+
+def normalize_answer(s):
+    """Canonicalize an English answer string for token-level comparison.
+
+    The steps run in this order: lowercase, drop every ASCII punctuation
+    character, replace the standalone articles "a", "an" and "the" with a
+    space, then collapse whitespace runs into single spaces (which also trims
+    both ends).
+
+    Args:
+        s: Raw answer text.
+
+    Returns:
+        The normalized string.
+    """
+    punctuation = set(string.punctuation)
+    lowered = s.lower()
+    without_punc = "".join(filter(lambda ch: ch not in punctuation, lowered))
+    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punc)
+    return " ".join(without_articles.split())
+
+
+def normalize_zh_answer(s):
+    """Normalize a Chinese answer string for token-level comparison.
+
+    Lowercases the text, removes ASCII, fullwidth and CJK punctuation, and
+    deletes all whitespace (the pieces are joined with no separator).
+
+    Args:
+        s: Raw answer text.
+
+    Returns:
+        The normalized string; empty when the input held nothing but
+        punctuation and whitespace.
+    """
+    # The table is spelled with escapes so the literal cannot be cut short:
+    # with raw characters, an ASCII '"' inside the table ended the string after
+    # its first few characters and the following '#' turned the remainder into
+    # a comment, so fullwidth punctuation was never stripped.
+    cn_punctuation = (
+        "\uff01\uff1f\uff61\u3002"  # fullwidth ! and ?, halfwidth and ideographic full stop
+        "\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09"  # fullwidth " # $ % & ' ( )
+        "\uff0a\uff0b\uff0c\uff0d\uff0f"  # fullwidth * + , - /
+        "\uff1a\uff1b\uff1c\uff1d\uff1e\uff20"  # fullwidth : ; < = > @
+        "\uff3b\uff3c\uff3d\uff3e\uff3f\uff40"  # fullwidth [ \ ] ^ _ `
+        "\uff5b\uff5c\uff5d\uff5e"  # fullwidth { | } ~
+        "\uff5f\uff60\u2985\u2986"  # white parentheses (fullwidth and plain forms)
+        "\uff62\uff63\uff64\u3001\u3003"  # halfwidth corner brackets/comma, ideographic comma, ditto mark
+        "\u300b\u300c\u300d\u300e\u300f\u3010\u3011"  # CJK angle/corner/lenticular brackets
+        "\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b"  # tortoise-shell and white brackets
+        "\u301c\u301d\u301e\u301f\u3030\u303e\u303f"  # wave dashes, prime quotes, misc CJK marks
+        "\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f"  # dashes and curly quotes
+        "\u2026\u2027\ufe4f"  # ellipsis, hyphenation point, wavy low line
+    )
+    all_punctuation = set(string.punctuation + cn_punctuation)
+
+    def white_space_fix(text):
+        return "".join(text.split())
+
+    def remove_punc(text):
+        return "".join(ch for ch in text if ch not in all_punctuation)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_punc(lower(s)))
+
+
+def count_score(prediction, ground_truth, **kwargs):
+    """Score a counting answer by the share of predicted integers that are right.
+
+    Every run of digits in ``prediction`` is compared, as text, with
+    ``str(ground_truth)``.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Expected count (anything whose ``str()`` is the number).
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        Matching digit runs divided by all digit runs, or 0.0 when the
+        prediction contains no digits.
+    """
+    found = re.findall(r"\d+", prediction)
+    if not found:
+        return 0.0
+    target = str(ground_truth)
+    hits = sum(1 for candidate in found if candidate == target)
+    return float(hits / len(found))
+
+
+def retrieval_score(prediction, ground_truth, **kwargs):
+    """Score an English passage-retrieval answer.
+
+    The target id is the number following the first "Paragraph " in
+    ``ground_truth``; the score is the share of digit runs in ``prediction``
+    that equal that id.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Reference such as "Paragraph 7".
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        A float in [0, 1]; 0.0 when the prediction contains no digits.
+
+    Raises:
+        IndexError: If ``ground_truth`` contains no "Paragraph <n>" marker.
+    """
+    target_id = re.findall(r"Paragraph (\d+)", ground_truth)[0]
+    candidates = re.findall(r"\d+", prediction)
+    if not candidates:
+        return 0.0
+    return float(candidates.count(target_id) / len(candidates))
+
+
+def retrieval_zh_score(prediction, ground_truth, **kwargs):
+    """Score a Chinese passage-retrieval answer.
+
+    The target id is the number following the first "段落" marker in
+    ``ground_truth``; the score is the share of digit runs in ``prediction``
+    that equal that id.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Reference containing "段落<n>".
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        A float in [0, 1]; 0.0 when the prediction contains no digits.
+
+    Raises:
+        IndexError: If ``ground_truth`` contains no "段落<n>" marker.
+    """
+    target_id = re.findall(r"段落(\d+)", ground_truth)[0]
+    candidates = re.findall(r"\d+", prediction)
+    if not candidates:
+        return 0.0
+    return float(candidates.count(target_id) / len(candidates))
+
+
+def code_sim_score(prediction, ground_truth, **kwargs):
+    """Compare the first plain code line of a prediction with the reference.
+
+    Leading newlines are stripped, then the first line containing none of a
+    backtick, "#" or "//" is taken as the predicted line (the empty string if
+    every line contains one of them).
+
+    Args:
+        prediction: Model output, possibly several lines.
+        ground_truth: Reference line of code.
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        ``fuzz.ratio`` of the chosen line against the reference, divided by 100.
+    """
+    candidate_lines = prediction.lstrip("\n").split("\n")
+    first_code_line = next(
+        (
+            candidate
+            for candidate in candidate_lines
+            if "`" not in candidate and "#" not in candidate and "//" not in candidate
+        ),
+        "",
+    )
+    return fuzz.ratio(first_code_line, ground_truth) / 100
+
+
+def classification_score(prediction, ground_truth, **kwargs):
+    """Score a classification answer by substring match against known classes.
+
+    Every class name that occurs in ``prediction`` counts as a match, except
+    names that are strict substrings of ``ground_truth`` (they would match
+    whenever the correct, longer label is printed).
+
+    Args:
+        prediction: Model output text.
+        ground_truth: The correct class name.
+        **kwargs: Must contain ``all_classes``, an iterable of class names.
+
+    Returns:
+        ``1 / number_of_matches`` when the ground truth is among the matches,
+        otherwise 0.0.
+
+    Raises:
+        KeyError: If ``all_classes`` is not supplied.
+    """
+    all_classes = kwargs["all_classes"]
+    matched = [class_name for class_name in all_classes if class_name in prediction]
+    # Build a new list instead of calling remove() while iterating: removing
+    # in place skips the element after each removal and leaves spurious
+    # matches behind, which deflates the score.
+    matched = [
+        term
+        for term in matched
+        if not (term in ground_truth and term != ground_truth)
+    ]
+    if ground_truth in matched:
+        return 1.0 / len(matched)
+    return 0.0
+
+
+def rouge_score(prediction, ground_truth, **kwargs):
+    """Return the ROUGE-L F-measure of a prediction against one reference.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Reference text.
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        The ``rouge-l`` ``f`` value reported by ``Rouge.get_scores``, or 0.0
+        when the scorer raises for this pair.
+    """
+    rouge = Rouge()
+    try:
+        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
+    except Exception:
+        # Best effort: an unscorable pair counts as zero. Catch Exception
+        # rather than using a bare except so Ctrl-C and SystemExit still
+        # propagate during a long evaluation run.
+        return 0.0
+    return scores["rouge-l"]["f"]
+
+
+def rouge_zh_score(prediction, ground_truth, **kwargs):
+    """ROUGE-L for Chinese text.
+
+    Both strings are segmented with ``jieba.cut`` (``cut_all=False``) and the
+    segments are joined with single spaces before being passed to
+    ``rouge_score``.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Reference text.
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        Whatever ``rouge_score`` returns for the segmented pair.
+    """
+    segmented_prediction = " ".join(jieba.cut(prediction, cut_all=False))
+    segmented_reference = " ".join(jieba.cut(ground_truth, cut_all=False))
+    return rouge_score(segmented_prediction, segmented_reference)
+
+
+def f1_score(prediction, ground_truth, **kwargs):
+    """Compute bag-of-tokens F1 between two token sequences.
+
+    Overlap is the multiset intersection, so a token repeated in both
+    sequences is credited up to the smaller of its two counts.
+
+    Args:
+        prediction: Sequence of predicted tokens.
+        ground_truth: Sequence of reference tokens.
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        The harmonic mean of precision and recall, or 0 when the sequences
+        share no token.
+    """
+    overlap = sum((Counter(prediction) & Counter(ground_truth)).values())
+    if overlap == 0:
+        return 0
+    precision = 1.0 * overlap / len(prediction)
+    recall = 1.0 * overlap / len(ground_truth)
+    return (2 * precision * recall) / (precision + recall)
+
+
+def qa_f1_score(prediction, ground_truth, **kwargs):
+    """Token-level F1 for English question answering.
+
+    Both strings go through ``normalize_answer`` and are split on whitespace
+    before being compared with ``f1_score``.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Reference answer text.
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        The value returned by ``f1_score`` for the two token lists.
+    """
+    predicted_tokens, reference_tokens = (
+        normalize_answer(text).split() for text in (prediction, ground_truth)
+    )
+    return f1_score(predicted_tokens, reference_tokens)
+
+
+def qa_f1_zh_score(prediction, ground_truth, **kwargs):
+    """Token-level F1 for Chinese question answering.
+
+    Each string is segmented with ``jieba.cut`` (``cut_all=False``); every
+    segment is normalized with ``normalize_zh_answer`` and segments that
+    become empty are dropped before ``f1_score`` is applied.
+
+    Args:
+        prediction: Model output text.
+        ground_truth: Reference answer text.
+        **kwargs: Ignored; accepted so all metrics share one call signature.
+
+    Returns:
+        The value returned by ``f1_score`` for the two token lists.
+    """
+
+    def tokenize(text):
+        normalized = (
+            normalize_zh_answer(piece) for piece in jieba.cut(text, cut_all=False)
+        )
+        return [piece for piece in normalized if len(piece) > 0]
+
+    predicted_tokens = tokenize(prediction)
+    reference_tokens = tokenize(ground_truth)
+    return f1_score(predicted_tokens, reference_tokens)
+
+
+# Maps each LongBench dataset name to the scoring function used for it.
+# Every function is called as metric(prediction, ground_truth, **kwargs);
+# classification_score additionally requires an ``all_classes`` keyword.
+dataset2metric = {
+    "narrativeqa": qa_f1_score,
+    "qasper": qa_f1_score,
+    "multifieldqa_en": qa_f1_score,
+    "multifieldqa_zh": qa_f1_zh_score,
+    "hotpotqa": qa_f1_score,
+    "2wikimqa": qa_f1_score,
+    "musique": qa_f1_score,
+    "dureader": rouge_zh_score,
+    "gov_report": rouge_score,
+    "qmsum": rouge_score,
+    "multi_news": rouge_score,
+    "vcsum": rouge_zh_score,
+    "trec": classification_score,
+    "triviaqa": qa_f1_score,
+    "samsum": rouge_score,
+    "lsht": classification_score,
+    "passage_retrieval_en": retrieval_score,
+    "passage_count": count_score,
+    "passage_retrieval_zh": retrieval_zh_score,
+    "lcc": code_sim_score,
+    "repobench-p": code_sim_score,
+}
diff --git a/tests/kvprune/evaluate/ruler_metrics.py b/tests/kvprune/evaluate/ruler_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dbda53e7d30214d2e1f42a5e1a51baa2d32251f
--- /dev/null
+++ b/tests/kvprune/evaluate/ruler_metrics.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import re
+from typing import List
+
+import pandas as pd
+
+
+def string_match_part(preds, refs):
+    """Percentage of predictions that contain at least one of their references.
+
+    Matching is a case-insensitive substring test.
+
+    Args:
+        preds: List of predicted strings.
+        refs: List, parallel to ``preds``, of lists of acceptable strings.
+
+    Returns:
+        A percentage in [0, 100] rounded to two decimals. An empty ``preds``
+        yields 0.0, and a prediction with an empty reference list scores 0
+        (both previously raised).
+    """
+    if not preds:
+        return 0.0
+    hits = sum(
+        # any() is False for an empty reference list, where max([]) raised.
+        1.0 if any(r.lower() in pred.lower() for r in ref) else 0.0
+        for pred, ref in zip(preds, refs)
+    )
+    return round(hits / len(preds) * 100, 2)
+
+
+def string_match_all(preds, refs):
+    """Average share of references found in each prediction, as a percentage.
+
+    Matching is a case-insensitive substring test. Each prediction is scored
+    by the fraction of its reference strings it contains, and the fractions
+    are averaged over all predictions.
+
+    Args:
+        preds: List of predicted strings.
+        refs: List, parallel to ``preds``, of lists of required strings.
+
+    Returns:
+        A percentage in [0, 100] rounded to two decimals. An empty ``preds``
+        yields 0.0, and a prediction with an empty reference list contributes
+        0 (both previously raised ``ZeroDivisionError``).
+    """
+    if not preds:
+        return 0.0
+    total = 0.0
+    for pred, ref in zip(preds, refs):
+        if not ref:
+            # Nothing to find: count as a miss rather than dividing by zero.
+            continue
+        lowered = pred.lower()
+        total += sum(1.0 for r in ref if r.lower() in lowered) / len(ref)
+    return round(total / len(preds) * 100, 2)
+
+
+def calculate_metrics(df: pd.DataFrame) -> dict:
+    """Compute the string-match score of every task in a results frame.
+
+    The ``predicted_answer`` column is cleaned in place on ``df`` (stripped,
+    control characters ``\\x00``-``\\x1f`` removed, stripped again). Rows are
+    then grouped by ``task``; tasks whose name starts with "qa" before the
+    first underscore use ``string_match_part``, all others use
+    ``string_match_all``.
+
+    Args:
+        df: Frame with ``task``, ``predicted_answer`` and ``answer`` columns.
+
+    Returns:
+        ``{task: {"string_match": score}}`` for every task present.
+    """
+    control_chars = re.compile(r"[\x00-\x1f]")
+
+    def clean(text):
+        return control_chars.sub("", text.strip()).strip()
+
+    df["predicted_answer"] = df["predicted_answer"].apply(clean)
+
+    results = {}
+    for task_name, group in df.groupby("task"):
+        is_qa = task_name.split("_")[0] == "qa"
+        scorer = string_match_part if is_qa else string_match_all
+        predictions = group["predicted_answer"].tolist()
+        references = group["answer"].tolist()
+        results[task_name] = {"string_match": scorer(predictions, references)}
+    return results
+
+
+def score_function(*, generated, ground_truth: List[str], task_category: str):
+    """Score a single generation against its reference strings.
+
+    Args:
+        generated: Model output; it is stripped, has control characters
+            ``\\x00``-``\\x1f`` removed, and is stripped again before scoring.
+        ground_truth: Reference strings for this example.
+        task_category: Task name; only the part before the first underscore
+            is used to choose the metric.
+
+    Returns:
+        The ``string_match_part`` score when that prefix is "qa", otherwise
+        the ``string_match_all`` score, for this one example.
+    """
+    cleaned = re.sub(r"[\x00-\x1f]", "", generated.strip()).strip()
+    category_prefix = task_category.split("_")[0]
+    if category_prefix == "qa":
+        return string_match_part([cleaned], [ground_truth])
+    return string_match_all([cleaned], [ground_truth])
diff --git a/tests/kvprune/evaluate/test-00000-of-00001.parquet b/tests/kvprune/evaluate/test-00000-of-00001.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..a4899bd5767998cd65777b07f2d5c4fd7a4973c9
Binary files /dev/null and b/tests/kvprune/evaluate/test-00000-of-00001.parquet differ
diff --git a/tests/kvprune/evaluate/test.py b/tests/kvprune/evaluate/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfd4e2275190ed3f12b1f7f6fae6da04ce7e2349
--- /dev/null
+++ b/tests/kvprune/evaluate/test.py
@@ -0,0 +1,218 @@
+import argparse
+import inspect
+import logging
+import os
+import sys
+from pathlib import Path
+
+
+def _maybe_add_src_to_path() -> None:
+    """Put a sibling ``src`` directory at the front of ``sys.path`` if it exists.
+
+    The directory looked for is ``src`` under the parent of the directory that
+    contains this file. Nothing happens when it is missing or already listed.
+    """
+    # Lets the script run without `pip install -e .`.
+    # NOTE(review): presumably meant to resolve to `compactor-vllm/src`;
+    # confirm that parents[1] is the right level for this file's location.
+    candidate = Path(__file__).resolve().parents[1] / "src"
+    if not candidate.is_dir():
+        return
+    candidate_str = str(candidate)
+    if candidate_str not in sys.path:
+        sys.path.insert(0, candidate_str)
+
+
+_maybe_add_src_to_path()
+
+from compactor_vllm import LLM, LLMConfig, SamplingParams # noqa: E402
+from compactor_vllm.compression import ( # noqa: E402
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+)
+from compactor_vllm.config.engine_config import AttentionBackend # noqa: E402
+
+
+def _parse_args() -> argparse.Namespace:
+    """Build the smoke test's command-line interface and parse ``sys.argv``.
+
+    The defaults of ``--model``, ``--tp``, ``--nccl-port`` and
+    ``--gpu-memory-utilization`` are read from the environment variables
+    ``MODEL``, ``TP``, ``NCCL_PORT`` and ``GPU_MEMORY_UTILIZATION``.
+
+    Returns:
+        The parsed arguments.
+    """
+    env = os.environ.get
+    parser = argparse.ArgumentParser(
+        description="Minimal smoke test for compactor-vllm (no speculative decoding)."
+    )
+    add = parser.add_argument
+
+    def add_toggle(flag: str, default: bool, help_on: str, help_off: str) -> None:
+        # Registers `--<flag>` / `--no-<flag>` writing to one shared dest.
+        dest = flag.replace("-", "_")
+        add(f"--{flag}", dest=dest, action="store_true", help=help_on)
+        add(f"--no-{flag}", dest=dest, action="store_false", help=help_off)
+        parser.set_defaults(**{dest: default})
+
+    # Engine / runtime options.
+    add(
+        "--model",
+        type=str,
+        default=env("MODEL", "/mnt/data/llm-models/Qwen3-8B"),
+        help="Local model directory or HF id. In the container this is usually a local dir.",
+    )
+    add(
+        "--tp",
+        type=int,
+        default=int(env("TP", "1")),
+        help="Tensor parallel size (world size).",
+    )
+    add(
+        "--nccl-port",
+        type=int,
+        default=int(env("NCCL_PORT", "1218")),
+        help="TCP port for torch.distributed init (only used for NCCL init_method=tcp://localhost:).",
+    )
+    add("--max-model-len", type=int, default=2048)
+    add("--max-num-seqs", type=int, default=2)
+    add(
+        "--gpu-memory-utilization",
+        type=float,
+        default=float(env("GPU_MEMORY_UTILIZATION", "0.9")),
+        help="Fraction of total GPU memory used for KV cache + activations.",
+    )
+    add(
+        "--attention-backend",
+        type=str,
+        default="compactor_triton",
+        choices=[b.name.lower() for b in AttentionBackend],
+    )
+
+    # Compression options.
+    add(
+        "--compression-method",
+        type=str,
+        default="compactor",
+        choices=[m.name.lower() for m in CompressionMethod],
+    )
+    add(
+        "--compression-ratio",
+        type=float,
+        default=0.8,
+        help="Sequence-level compression ratio (e.g. 0.8 keeps 80%% of tokens).",
+    )
+    add("--chunk-size", type=int, default=512)
+    add(
+        "--no-chunked-compression",
+        dest="do_chunked_compression",
+        action="store_false",
+    )
+    parser.set_defaults(do_chunked_compression=True)
+
+    # Generation options.
+    add("--prompt", type=str, default="用一句话介绍你自己,给我讲一个故事,200字左右。")
+    add("--max-new-tokens", type=int, default=64)
+    add(
+        "--temperature",
+        type=float,
+        default=0.0,
+        help="0.0 = greedy decoding (recommended for smoke tests).",
+    )
+
+    # Paired on/off switches for the chat template and for decoding.
+    add_toggle(
+        "tokenizer-enable-thinking",
+        False,
+        "Pass enable_thinking=True to tokenizer.apply_chat_template (if supported).",
+        "Pass enable_thinking=False to tokenizer.apply_chat_template (if supported).",
+    )
+    add_toggle(
+        "tokenizer-add-generation-prompt",
+        True,
+        "Pass add_generation_prompt=True to tokenizer.apply_chat_template (if supported).",
+        "Pass add_generation_prompt=False to tokenizer.apply_chat_template (if supported).",
+    )
+    add_toggle(
+        "tokenizer-continue-final-message",
+        False,
+        "Pass continue_final_message=True to tokenizer.apply_chat_template (if supported).",
+        "Pass continue_final_message=False to tokenizer.apply_chat_template (if supported).",
+    )
+    add_toggle(
+        "skip-special-tokens",
+        True,
+        "Skip special tokens in output decoding (recommended).",
+        "Keep special tokens in output decoding (e.g. <|im_end|>).",
+    )
+
+    add(
+        "--log-level",
+        type=str,
+        default="INFO",
+        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    """Run one chat generation with KV-cache compression and print the result.
+
+    Parses the command line, configures logging, builds the engine, sends the
+    single ``--prompt`` as a user message, prints the first output and then
+    calls ``llm.exit()``.
+    """
+    args = _parse_args()
+    logging.basicConfig(
+        level=getattr(logging, args.log_level.upper()),
+        format="%(asctime)s - %(levelname)s - %(message)s",
+    )
+
+    # Resolve both enums before the engine is built so a bad name fails early.
+    attention_backend = AttentionBackend[args.attention_backend.upper()]
+    compression_method = CompressionMethod[args.compression_method.upper()]
+
+    model_path = args.model
+    llm = LLM(
+        LLMConfig(
+            model=model_path,
+            path=model_path,
+            tensor_parallel_size=int(args.tp),
+            nccl_port=int(args.nccl_port),
+            max_model_len=int(args.max_model_len),
+            max_num_seqs=int(args.max_num_seqs),
+            gpu_memory_utilization=float(args.gpu_memory_utilization),
+            enforce_eager=True,
+            attention_backend=attention_backend,
+            show_progress_bar=False,
+        )
+    )
+
+    add_generation_prompt = bool(args.tokenizer_add_generation_prompt)
+    # HF tokenizer API rejects these being simultaneously True, so
+    # add_generation_prompt wins over continue_final_message.
+    continue_final_message = (
+        bool(args.tokenizer_continue_final_message) and not add_generation_prompt
+    )
+    tokenizer_kwargs = {
+        "add_generation_prompt": add_generation_prompt,
+        "enable_thinking": bool(args.tokenizer_enable_thinking),
+        "continue_final_message": continue_final_message,
+    }
+    # Be defensive: only pass kwargs supported by this tokenizer build. If the
+    # signature cannot be inspected, all three are passed through unchanged.
+    try:
+        supported = set(inspect.signature(llm.tokenizer.apply_chat_template).parameters)
+    except (TypeError, ValueError):
+        pass
+    else:
+        tokenizer_kwargs = {
+            key: value for key, value in tokenizer_kwargs.items() if key in supported
+        }
+
+    sampling_params = SamplingParams(
+        temperature=float(args.temperature),
+        max_new_tokens=int(args.max_new_tokens),
+    )
+    batch_params = BatchCompressionParams(
+        compression_method=compression_method,
+        do_chunked_compression=bool(args.do_chunked_compression),
+        chunk_size=int(args.chunk_size),
+    )
+    sequence_params = SequenceCompressionParams(
+        compression_ratio=float(args.compression_ratio),
+    )
+    conversation = [{"role": "user", "content": args.prompt}]
+
+    outs = llm.generate_chat(
+        [conversation],
+        sampling_params=sampling_params,
+        batch_compression_params=batch_params,
+        per_sequence_compression_params=sequence_params,
+        tokenizer_kwargs=tokenizer_kwargs,
+        detokenizer_kwargs={"skip_special_tokens": bool(args.skip_special_tokens)},
+    )
+    print(outs[0])
+    llm.exit()
+
+if __name__ == "__main__":
+ main()
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 968d1a143b16f144d6eda57357c1db31758ac4c8..63814ac2176e2270fd12e990396d8ded74dbf3cb 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -25,6 +25,7 @@ MODULE_ATTRS = {
"TokensPrompt": ".inputs:TokensPrompt",
"ModelRegistry": ".model_executor.models:ModelRegistry",
"SamplingParams": ".sampling_params:SamplingParams",
+ "CompressionParams": ".kvprune.integration.compression_params:CompressionParams",
"PoolingParams": ".pooling_params:PoolingParams",
"ClassificationOutput": ".outputs:ClassificationOutput",
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
@@ -58,6 +59,7 @@ if typing.TYPE_CHECKING:
ScoringRequestOutput,
)
from vllm.pooling_params import PoolingParams
+ from vllm.kvprune.integration.compression_params import CompressionParams
from vllm.sampling_params import SamplingParams
from vllm.v1.executor.ray_utils import initialize_ray_cluster
else:
@@ -77,6 +79,7 @@ __all__ = [
"__version__",
"__version_tuple__",
"LLM",
+ "CompressionParams",
"ModelRegistry",
"PromptType",
"TextPrompt",
diff --git a/vllm/compactor-vllm/.gitignore b/vllm/compactor-vllm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7119bf941e33c3dd35a090ff44e8aca6f71e2b21
--- /dev/null
+++ b/vllm/compactor-vllm/.gitignore
@@ -0,0 +1,4 @@
+/.ruff_cache
+/.DS_Store
+/.idea
+*.pyc
\ No newline at end of file
diff --git a/vllm/compactor-vllm/AGENTS.md b/vllm/compactor-vllm/AGENTS.md
new file mode 100644
index 0000000000000000000000000000000000000000..6a615ce55229437b04cf8a288c483b02d63cedd2
--- /dev/null
+++ b/vllm/compactor-vllm/AGENTS.md
@@ -0,0 +1,44 @@
+# Repository Guidelines
+
+## Project Structure & Module Organization
+
+- `src/compactor_vllm/`: main Python package (minimal vLLM-style engine).
+ - `core/`: engine loop, scheduling, memory management.
+ - `attention/`: Triton attention backends and `compile_kernels.py` autotuning helper.
+ - `compression/`: compression methods (e.g., Compactor, SnapKV) and registries/config.
+ - `kv_cache/`: paged KV cache + store helpers.
+ - `models/`, `layers/`, `utils/`: model definitions and reusable building blocks.
+ - `triton_kernels/`: low-level kernels (treat as vendor-style code; avoid drive-by edits).
+- `tests/`: GPU correctness tests (`tests/test_*.py`).
+- `evaluate/`: evaluation scripts (RULER/LongBench) and configs (`evaluate/longbench_config/`).
+- Repo root: figures/plots used by `README.md`.
+
+## Build, Test, and Development Commands
+
+- `pip install -e .`: editable install for local development.
+- `pip install -e ".[evaluate]"`: install optional evaluation dependencies (you may also need `pip install datasets`).
+- `pytest tests/`: run unit/kernel tests (expects a CUDA-capable GPU and working `flash-attn`/Triton setup).
+- `python -m compactor_vllm.attention.compile_kernels --max-length 16384 --HKV 8 --HQ 32 --D 128 --page-size 128`: pre-autotune Triton kernels (results are cached on disk; avoids first-run autotuning latency).
+- `python evaluate/eval_ruler.py --help`: run RULER evaluation (downloads datasets).
+- `python evaluate/eval_longbench.py`: run LongBench evaluation (downloads datasets).
+
+## Coding Style & Naming Conventions
+
+- Use 4-space indentation and follow existing patterns in `src/compactor_vllm/` (type hints, `@dataclass` configs, `logging` over `print`).
+- Naming: `snake_case` (modules/functions), `PascalCase` (classes), `UPPER_SNAKE_CASE` (constants).
+- Lint/format: Ruff is configured in `pyproject.toml`. If installed, run `ruff check .` and `ruff format .` (cache is ignored via `.gitignore`).
+
+## Testing Guidelines
+
+- Framework: `pytest`. Prefer parameterized tests for kernels and keep GPU tests deterministic (seed RNGs; `torch.cuda.synchronize()` before assertions when needed).
+- When changing kernels or compression logic, add/extend a focused regression test and, when feasible, compare against a reference backend (e.g., FlashAttention).
+
+## Commit & Pull Request Guidelines
+
+- Commits in history are short and imperative (e.g., “fix plot”, “update package layout”); keep subjects concise and scoped.
+- PRs should include: a clear description, reproduction commands, expected correctness/perf impact, GPU/CUDA details for kernel changes, and new/updated tests. Add plots/screenshots when changing benchmarks or figures.
+
+## Environment & Configuration Tips
+
+- Requires an NVIDIA CUDA GPU; ensure compatible versions of PyTorch, Triton, and `flash-attn`.
+- Kernel constraint: `head_dim` (`D`) must be a power of two; new model configs may trigger autotuning on first use.
diff --git a/vllm/compactor-vllm/CLAUDE.md b/vllm/compactor-vllm/CLAUDE.md
new file mode 100644
index 0000000000000000000000000000000000000000..42d89ec6bdb56fe87ed2cacde81a24286b8677db
--- /dev/null
+++ b/vllm/compactor-vllm/CLAUDE.md
@@ -0,0 +1,158 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## 概述
+
+compactor-vllm 是一个用于长上下文 LLM 的极简推理引擎,支持无需训练的 KV 缓存压缩。它实现了分页 KV 缓存管理器、针对压缩缓存优化的自定义 Triton 注意力内核,以及多种压缩方法(Compactor、SnapKV)。
+
+## 安装
+
+```bash
+pip install -e .
+```
+
+依赖:Python 3.10+、带 CUDA 的 PyTorch、Triton、FlashAttention、Transformers。
+
+## 测试
+
+```bash
+pytest tests/
+```
+
+测试包括内核正确性测试(`test_triton_attention.py`)和 KV 缓存存储测试(`test_store_kv.py`)。
+
+## 内核自动调优
+
+内核在首次使用时会自动调优。建议在生产环境预先调优:
+
+```bash
+python compactor_vllm/attention/compile_kernels.py --max-length 16384 --HKV 8 --HQ 32 --D 128 --page-size 128
+```
+
+根据您的模型配置调整参数:
+- `--HKV`: KV 头数(模型的 `num_key_value_heads`)
+- `--HQ`: 查询头数(模型的 `num_attention_heads`)
+- `--D`: 头维度(必须是 2 的幂)
+- `--max-length`: 预期的最大序列长度
+
+## 核心架构
+
+### 执行流程
+
+1. **LLM (LLMEngine)**: 高层入口,为张量并行推理生成多个 ModelRunner 进程
+2. **ModelRunner**: 每个秩的执行循环,管理模型加载、预热和主推理循环
+3. **Scheduler**: 管理序列生命周期(待处理 → 运行中 → 已完成)和批处理
+4. **KVCacheManager**: 分配和跟踪分页 KV 缓存内存
+5. **Attention Layer**: 应用压缩并使用选定的后端计算注意力
+
+### 压缩流水线
+
+压缩在 **预填充(prefill)** 阶段分三步进行:
+
+1. **RoPE 前评分** (`apply_prerope_compression`):查询无关的重要性评分(例如 Compactor 的近似杠杆分数)
+2. **RoPE 后评分** (`apply_postrope_compression`):在旋转位置编码后可选的精化
+3. **KV 提取** (`extract_and_store_top_kv`):仅将评分最高的 KV 对存储到分页缓存中
+
+分数在 CUDA 流(`STORE_STREAM`)上异步计算,以与内存密集型操作重叠。
+
+### 注意力后端
+
+通过 `LLMConfig(attention_backend=...)` 选择:
+
+- **COMPACTOR_TRITON**(默认):为压缩缓存优化的自定义稀疏变长注意力内核
+- **FLASH_ATTENTION**:FlashAttention varlen 后端(备选方案,压缩时效率较低)
+
+内核位于 `attention/sparse_varlen_kernel.py`(预填充)和 `attention/sparse_decode_kernel.py`(解码)。
+
+### 分页 KV 缓存
+
+- **PagedKVCache** (`kv_cache/page_table.py`):由固定大小页面支持的全局 KV 缓存
+- 每层都有一个页表,将 `(batch, kv_head, logical_page)` 映射到物理页面 ID
+- 页面从每层的空闲列表(最小堆)中分配
+- 页面大小默认为 128 个 token(`kvcache_page_size` 配置)
+
+### 模型注册
+
+模型在 `models/__init__.py` 中通过 `MODEL_REGISTRY` 注册:
+
+```python
+MODEL_REGISTRY = {
+ "llama": LlamaForCausalLM,
+ "qwen3": Qwen3ForCausalLM,
+ "qwen3_moe": Qwen3MoeForCausalLM,
+}
+```
+
+添加新模型:
+1. 在 `models/` 中创建 `*ForCausalLM` 类,使用共享的 `layers/`(Attention、MoE 等)
+2. 使用 HuggingFace 配置中相应的 `model_type` 键进行注册
+
+### 添加压缩方法
+
+1. 在 `compression/` 中创建 `BaseCompressionMethod` 的子类:
+ - 实现 `pre_rope_scoring(q, k, v, context)` → 返回重要性分数
+ - 可选实现 `post_rope_scoring(q, k, v, prerope_scores, context)`
+2. 在 `compression/__init__.py` 中注册:
+ ```python
+ COMPRESSION_REGISTRY[CompressionMethod.MY_METHOD] = MyCompressionMethod
+ ```
+3. 在 `compression/compression_config.py` 的 `CompressionMethod` 中添加枚举值
+
+### 多 GPU 推理
+
+张量并行推理使用 `torch.distributed` (NCCL)。在 `LLMConfig` 中设置 `tensor_parallel_size`。world size 必须能整除 `num_key_value_heads`。
+
+## 目录结构
+
+```
+src/compactor_vllm/
+├── attention/ # Triton 注意力内核(预填充 + 解码)
+├── compression/ # 压缩方法实现
+├── config/ # LLMConfig、SamplingParams、枚举
+├── core/ # LLMEngine、ModelRunner、Scheduler、KVCacheManager
+├── kv_cache/ # PagedKVCache、页表、KV 存储工具
+├── layers/ # 可复用的模型层(Attention、MoE、Linear 等)
+├── models/ # 特定模型的实现(llama3、qwen3 等)
+├── utils/ # Sequence 数据类、上下文管理、辅助函数
+└── triton_kernels/ # 来自 Triton Lang 仓库的快速 MoE 内核
+```
+
+## 重要实现细节
+
+### 序列管理
+
+- **Sequence**:跟踪 prompt token、生成 token、采样参数和压缩参数的数据类
+- **SequenceStatus**:WAITING → RUNNING → FINISHED
+- 每个序列通过迭代器计数器获得唯一的 `seq_id`
+
+### 压缩参数
+
+- **BatchCompressionParams**:压缩方法和分块策略(应用于整个批次)
+- **SequenceCompressionParams**:每序列压缩比和受保护的 token 区域
+
+受保护的 token(首部/尾部)在压缩期间永远不会被丢弃。
+
+### 上下文管理
+
+`utils/context.py` 提供线程本地 `Context` 对象,存储:
+- 当前阶段(预填充 vs 解码)
+- 压缩上下文
+- 批次映射、序列长度、累积序列长度
+- 异步操作的 CUDA 流
+
+使用 `get_context()` 访问,`set_context()`/`reset_context()` 管理。
+
+### 内存分配
+
+KV 缓存内存在预热期间计算,并根据 `gpu_memory_utilization` 分配。如果可用内存不足,引擎将失败。
+
+### Triton 内核
+
+内核使用 Triton 的自动调优器并缓存到磁盘。由于内核要求,`head_dim` 必须是 2 的幂。
+
+
+
+## 容器登录
+
+使用以下脚本登录容器:
+
+`/public/home/lixh6/laibao/ssh/kvpress.sh`
diff --git a/vllm/compactor-vllm/README.md b/vllm/compactor-vllm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c505ad2b0c3180d4727d00be952092efaee2cac7
--- /dev/null
+++ b/vllm/compactor-vllm/README.md
@@ -0,0 +1,346 @@
+# compactor-vllm
+[arXiv:2507.08143](https://arxiv.org/abs/2507.08143)
+[License: MIT](https://opensource.org/licenses/MIT)
+
+**Nearly zero-overhead KV-cache compression in a minimal vLLM-style engine**
+
+Long-context LLMs quickly become bottlenecked by the key–value (KV) cache: memory usage and bandwidth both scale linearly with the number of tokens. **compactor-vllm** is a small, simple inference engine that makes long-context inference more practical by combining:
+
+- **Paged KV cache manager** – for efficient memory allocation and management
+- **Custom Triton kernels** – for sparse (and dense) variable-length attention and fast KV compression
+- **Training-free KV compression** – out-of-the-box, with the Compactor compression method.
+
+## Key Features
+
+### 🚀 Speed
+Custom Triton kernels for head-sparse attention that outperform FlashAttention2 by up to 45% on long-context tasks, for both compressed and uncompressed KV caches (benchmarked and tuned on H100, L40, A100, H100 NVL, H200). Over 15x faster than NVIDIA's KVPress library for KV cache compression.
+
+### 💾 Memory Efficiency
+Achieve up to 50% memory savings while maintaining strong task performance.
+
+### ⚡ Zero-Overhead Compression
+Carefully overlapped KV compression operations with memory-bound portions of the prefill process.
+
+### ❗ Use Cases
+- **Long-document QA** - Reduce memory for 100K+ token contexts
+- **Multi-turn conversations** - Compress chat history while maintaining quality
+- **RAG systems** - Handle large retrieved contexts efficiently
+- **Batch processing** - Increase batch sizes with compressed KV cache
+### ⏱️ Coming Soon
+- **Prefix Caching**
+- **Calibrated Compression - automatically determine how much compression your context can tolerate**
+- **More Models**
+- **More Compression Methods**
+- **Fine-grained Compression Policies** - Specify particular regions of the context to compress (e.g., don't compress the system prompt, but compress few-shot exemplars).
+
+---
+
+## Performance
+
+### Throughput Comparison (50% KV Retention)
+
+At 50% KV retention, compactor-vllm achieves comparable throughput to **uncompressed vLLM** while using significantly less memory (see the first image).
+
+### Memory Usage (60% KV Retention)
+
+On the RULER 4K dataset with an H100 GPU, compactor-vllm reduces peak KV cache memory from 60GB to 36GB – a 40% reduction, as expected.
+
+
+
+### Task Performance (RULER Benchmark, Compactor KV Compression, Query Agnostic)
+
+| KV Discarded | 0% | 25% | 50% | 75% | 95% |
+|--------------|-------|-------|-------|-------|-------|
+| Llama 3.1-8B | 95.39 | 95.63 | 94.75 | 83.07 | 64.79 |
+| Qwen3-8B | 95.01 | 94.57 | 92.29 | 76.48 | 44.69 |
+
+At 50% compression, both models maintain over **97%** of their full-cache performance. Most tasks can tolerate
+at least 50% KV compression, and some can tolerate even more! An example of a RULER question:
+> A special magic uuid is hidden within the following text. Make sure to memorize it. I will quiz you about the uuid afterwards.
+One of the special magic uuids for 3ce915e7-c9d6-463b-8a3c-6f5f5bb5c40c is: 2c9b662e-040a-4aae-92e2-afd996bf10ab. ...
+One of the special magic uuids for bde13c1b-2073-4f6d-8d6a-05b343ef2016 is: bee3eb79-1d18-4ee9-86ad-8a8c6bc4123e.
+What is the special magic uuid for a93b12cd-1c24-420e-acab-d7e7cc6b66e5 mentioned in the provided text?
+
+### Attention Kernel Performance
+
+Our Triton kernels outperform FlashAttention2 by up to 45% across different sequence lengths
+and KV cache sizes, **even for uncompressed caches**.
+
+---
+
+## Installation
+
+### From Source
+
+```bash
+git clone https://github.com/vnchari/compactor_vllm.git
+cd compactor_vllm
+pip install -e .
+```
+
+### Requirements
+
+- Python 3.10+
+- NVIDIA GPU with CUDA support
+- PyTorch with CUDA
+- Transformers (for model downloading)
+- Triton
+- FlashAttention
+
+### Autotuning kernels
+You can autotune kernels ahead of time instead of waiting for autotuning to occur at first use. Autotuning results are
+automatically cached to disk, so autotuning only needs to run once per attention configuration.
+```bash
+python3 compactor_vllm/attention/compile_kernels.py --max-length 16384 --HKV 8 --HQ 32 --D 128 --page-size 128
+```
+---
+
+## Quick Start
+
+### Basic Chat Generation with Compression
+
+```python
+from compactor_vllm import (
+ LLM,
+ LLMConfig,
+ SamplingParams,
+ CompressionMethod,
+)
+from compactor_vllm.compression import (
+ BatchCompressionParams,
+ SequenceCompressionParams
+)
+
+# Configure the model
+config = LLMConfig(
+ model="Qwen/Qwen3-8B",
+ max_model_len=40960,
+)
+
+llm = LLM(config)
+
+# Set up sampling parameters
+sampling = SamplingParams(temperature=0.7, max_new_tokens=256)
+
+# Configure compression
+compression = BatchCompressionParams(
+ compression_method=CompressionMethod.COMPACTOR, # or SNAPKV
+)
+
+# Create conversation
+messages_batch = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Summarize the main idea of KV cache compression."},
+ ],
+]
+
+# Generate with 50% KV retention
+sequence_compression = SequenceCompressionParams(compression_ratio=0.5)
+answers = llm.generate_chat(
+ messages_batch=messages_batch,
+ sampling_params=sampling,
+ batch_compression_params=compression,
+ per_sequence_compression_params=sequence_compression
+)
+print(answers[0])
+```
+
+---
+
+## Core Components
+
+### Compression Methods
+
+compactor-vllm supports multiple KV cache compression strategies:
+
+#### **COMPACTOR**
+- Query-agnostic compression based on approximate leverage scores
+- Training-free and parameter-free
+- Maintains strong performance with aggressive compression ratios
+
+#### **SnapKV**
+- Query-aware compression using recent-token attention statistics
+- Well-suited for scenarios where the question is known at inference-time
+
+#### **None**
+- Baseline with no compression
+- Standard paged KV cache behavior
+
+### Attention Backends
+
+Choose your attention implementation via `attention_backend`:
+
+```python
+from compactor_vllm import LLMConfig, AttentionBackend
+
+config = LLMConfig(
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+ attention_backend=AttentionBackend.COMPACTOR_TRITON, # Recommended
+)
+```
+
+**COMPACTOR_TRITON**: Custom sparse variable-length attention kernel optimized for long contexts and compressed KV caches. It was developed to support prefix caching (coming soon!).
+
+**FLASH_ATTENTION**: FlashAttention reference backend.
+
+### Supported Models
+
+Models are registered in `MODEL_REGISTRY` and include:
+
+- **Llama 3 family** – Full support for Meta's Llama 3 models
+- **Qwen3** – Dense Qwen3 models
+- **Qwen3 MoE** – Mixture-of-Experts Qwen3 variants
+
+Check supported architectures:
+```python
+from compactor_vllm.models import MODEL_REGISTRY
+print(list(MODEL_REGISTRY.keys()))
+# ['llama', 'qwen3', 'qwen3_moe']
+```
+
+---
+
+## Advanced Usage
+
+### Configuring Compression Ratios
+
+Control how aggressively to compress the KV cache:
+
+```python
+from compactor_vllm.compression import SequenceCompressionParams
+# Retain 50% of KV cache (discard 50%)
+sequence_compression = SequenceCompressionParams(compression_ratio=0.5)
+
+# More aggressive: retain only 25%
+sequence_compression = SequenceCompressionParams(compression_ratio=0.25)
+```
+
+### Multi-GPU Inference
+
+compactor-vllm supports tensor-parallel inference across multiple GPUs using `torch.distributed`. Specify `tensor_parallel_size` in `LLMConfig`.
+
+### Batch Processing
+
+Process multiple conversations efficiently:
+
+```python
+messages_batch = [
+ [{"role": "user", "content": "Question 1"}],
+ [{"role": "user", "content": "Question 2"}],
+ [{"role": "user", "content": "Question 3"}],
+]
+
+answers = llm.generate_chat(
+ messages_batch=messages_batch,
+ sampling_params=sampling,
+ batch_compression_params=compression,
+)
+```
+
+---
+
+## Extending compactor-vllm
+
+### Adding a New Compression Method
+
+1. Create a subclass of `BaseCompressionMethod`:
+
+```python
+# compression/my_method.py
+from compactor_vllm.compression import BaseCompressionMethod
+
+class MyCompressionMethod(BaseCompressionMethod):
+ def pre_rope_scoring(self, ...):
+ # Implement scoring logic
+ pass
+
+ def post_rope_scoring(self, ...):
+ # Optional refinement
+ pass
+```
+
+2. Register in `compression/__init__.py`:
+
+```python
+from compactor_vllm.compression import COMPRESSION_REGISTRY, CompressionMethod
+
+COMPRESSION_REGISTRY[CompressionMethod.MY_METHOD] = MyCompressionMethod
+```
+
+### Adding a New Model Architecture
+
+1. Implement `*ForCausalLM` under `models/` using shared `layers/`
+2. Register in `MODEL_REGISTRY` with the appropriate `model_type` key
+
+---
+
+## Testing
+
+Run kernel and component tests:
+
+```bash
+pytest tests/
+```
+---
+
+## Project Structure
+
+```
+compactor_vllm/
+├── core/ # Engine, scheduler, memory management
+│ ├── llm_engine.py
+│ ├── model_runner.py
+│ ├── scheduler.py
+│ └── memory_manager.py
+├── compression/ # Compression methods and configuration
+│ ├── compactor.py
+│ ├── snapkv.py
+│ └── compression_params.py
+├── attention/ # Attention kernels and backends
+│ ├── sparse_varlen_kernel.py
+│ └── sparse_decode_kernel.py
+├── kv_cache/ # Paged KV cache implementation
+│ ├── page_table.py
+│ └── store_kv_cache.py
+├── layers/ # Model layers
+│ ├── attention.py
+│ ├── moe.py
+│ └── ...
+├── models/ # Model implementations
+│ ├── llama.py
+│ ├── qwen3.py
+│ └── ...
+├── utils/ # Utilities and helpers
+└── triton_kernels/ # Fast MOE kernels from Triton Lang repo
+```
+
+---
+
+## Citation
+
+If you use compactor-vllm or the Compactor method in your research, please cite:
+
+```bibtex
+@article{chari2025compactor,
+ title = {Compactor: Calibrated Query-Agnostic KV Cache Compression with Approximate Leverage Scores},
+ author = {Vivek Chari and Benjamin Van Durme},
+ journal = {arXiv preprint arXiv:2507.08143},
+ year = {2025},
+ url = {https://arxiv.org/abs/2507.08143}
+}
+```
+
+---
+
+## Contributing
+
+Contributions are welcome! Please feel free to submit a Pull Request.
+
+## Acknowledgments
+
+* See https://github.com/NVIDIA/kvpress for additional compression methods in an easy-to-use format
+
+## MIT License
+
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
diff --git a/vllm/compactor-vllm/evaluate/eval_longbench.py b/vllm/compactor-vllm/evaluate/eval_longbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8d1979c102f11b7b640daf3d1b2f4c7af02ad90
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/eval_longbench.py
@@ -0,0 +1,124 @@
+import json
+import logging
+
+from datasets import concatenate_datasets, load_dataset
+
+from compactor_vllm import (
+ LLM,
+ LLMConfig,
+ SamplingParams,
+)
+from compactor_vllm.compression import (
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+)
+from compactor_vllm.config.engine_config import AttentionBackend
+from longbench_metrics import dataset2metric
+
+if __name__ == "__main__":
+ logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s"
+ )
+ datasets = [
+ "narrativeqa",
+ "qasper",
+ "multifieldqa_en",
+ "hotpotqa",
+ "2wikimqa",
+ "musique",
+ "gov_report",
+ "qmsum",
+ "multi_news",
+ "trec",
+ "triviaqa",
+ "samsum",
+ "passage_retrieval_en",
+ "passage_count",
+ "lcc",
+ "repobench-p",
+ ]
+ dataset = concatenate_datasets(
+ [
+ load_dataset("THUDM/LongBench", n, split="test", trust_remote_code=True)
+ for n in datasets
+ ]
+ ).shuffle(seed=42)
+
+ # dataset = dataset.take(200)
+ prompts = json.load(open("longbench_config/dataset2prompt.json", "r"))
+ max_gen_lens = json.load(open("longbench_config/dataset2maxlen.json", "r"))
+
+ tokenizer_kwargs = {"add_generation_prompt": True, "enable_thinking": False}
+ dset_names = [
+ item["dataset"] if item["dataset"][-2:] != "_e" else item["dataset"][:-2]
+ for item in dataset
+ ]
+ gen_lengths = [max_gen_lens[dset_name] for dset_name in dset_names]
+
+ messages = [
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {"role": "user", "content": prompts[dset_name].format(**item)},
+ ]
+ for dset_name, item in zip(dset_names, dataset)
+ ]
+ # model = "Qwen/Qwen3-8B"
+ model = "meta-llama/Llama-3.1-8B-Instruct"
+ # model = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+ config = LLMConfig(
+ model,
+ max_num_seqs=64,
+ gpu_memory_utilization=0.95,
+ tensor_parallel_size=2,
+ max_model_len=128000,
+ attention_backend=AttentionBackend.COMPACTOR_TRITON,
+ leverage_sketch_size=32,
+ )
+ llm = LLM(config)
+ responses = llm.generate_chat(
+ messages,
+ [SamplingParams(max_new_tokens=g, temperature=0.00001) for g in gen_lengths],
+ BatchCompressionParams(
+ compression_method=CompressionMethod.COMPACTOR,
+ do_chunked_compression=False,
+ chunk_size=4096,
+ ),
+ per_sequence_compression_params=[
+ SequenceCompressionParams(
+ 0.25, protected_first_tokens=8, protected_last_tokens=64
+ )
+ ]
+ * len(messages),
+ tokenizer_kwargs=tokenizer_kwargs,
+ return_sequences=False,
+ )
+ results = {}
+ for dset_name, prediction, item in zip(dset_names, responses, dataset):
+ if dset_name not in results:
+ results[dset_name] = []
+
+ score = 0.0
+ if dset_name in ["trec", "triviaqa", "samsum", "lsht"]:
+ prediction = prediction.lstrip("\n").split("\n")[0]
+
+ for ground_truth in item["answers"]:
+ score = max(
+ score,
+ dataset2metric[dset_name](
+ prediction, ground_truth, all_classes=item["all_classes"]
+ ),
+ )
+ results[dset_name].append(score)
+
+ all_sum, all_count = 0, 0
+ for task, scores in results.items():
+ this_task_sum = sum(scores)
+ this_task_count = len(scores)
+ print(task, f"{this_task_sum / this_task_count:.2f}")
+ all_sum += sum(scores)
+ all_count += this_task_count
+ print(f"ALL: {all_sum / all_count:.2f}")
diff --git a/vllm/compactor-vllm/evaluate/eval_ruler.py b/vllm/compactor-vllm/evaluate/eval_ruler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3532c85ebf64bcda32ea9eedb1b11c5ed60dbc99
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/eval_ruler.py
@@ -0,0 +1,435 @@
+import argparse
+import logging
+import os
+import sys
+import json
+from datetime import datetime
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+
+from ruler_metrics import score_function
+
+# Allow running without `pip install -e .` by pointing to `compactor-vllm/src`.
+here = Path(__file__).resolve()
+repo_root = here.parents[1]
+src_dir = repo_root / "src"
+if src_dir.is_dir() and str(src_dir) not in sys.path:
+ sys.path.insert(0, str(src_dir))
+
+from compactor_vllm import (
+ LLM,
+ LLMConfig,
+ SamplingParams,
+) # noqa: E402
+from compactor_vllm.compression import (
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+) # noqa: E402
+from compactor_vllm.config.engine_config import AttentionBackend # noqa: E402
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Run RULER evaluation with compactor_vllm."
+ )
+ parser.add_argument(
+ "--log-level",
+ type=str,
+ default="INFO",
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+ help="Logging level.",
+ )
+ parser.add_argument(
+ "--dataset-length",
+ type=str,
+ default="4096",
+ help="Dataset configuration name.",
+ )
+
+ parser.add_argument(
+ "--dataset-parquet",
+ type=str,
+ default=None,
+ help=(
+ "Optional local Parquet dataset path (single .parquet file or a glob). "
+ "If provided, the script will load the dataset from local Parquet instead of "
+ "downloading 'simonjegou/ruler'."
+ ),
+ )
+ parser.add_argument(
+ "--dataset-split",
+ type=str,
+ default="test",
+ help=(
+ "Dataset split to load. For local parquet, this is typically 'train'. "
+ "For the online ruler dataset, default is 'test'."
+ ),
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="Shuffle seed for the dataset.",
+ )
+ parser.add_argument(
+ "--fraction",
+ type=float,
+ default=1.0,
+ help=(
+ "Fraction of the dataset to use in (0, 1]. "
+ "E.g., 0.1 uses 10%% of the shuffled dataset."
+ ),
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="meta-llama/Llama-3.1-8B-Instruct",
+ help="Model name or path.",
+ )
+ parser.add_argument(
+ "--max-num-seqs",
+ type=int,
+ default=32,
+ help="Maximum number of sequences to batch.",
+ )
+ parser.add_argument(
+ "--gpu-memory-utilization",
+ type=float,
+ default=0.95,
+ help="Fraction of GPU memory to use.",
+ )
+ parser.add_argument(
+ "--tensor-parallel-size",
+ type=int,
+ default=1,
+ help="Tensor parallelism degree.",
+ )
+ parser.add_argument(
+ "--max-model-len",
+ type=int,
+ default=40960,
+ help="Maximum model context length.",
+ )
+ parser.add_argument(
+ "--enforce-eager",
+ action="store_true",
+ help="Disable CUDA graph capture and always run in eager mode.",
+ )
+ backend_choices = [backend.name.lower() for backend in AttentionBackend]
+ parser.add_argument(
+ "--attention-backend",
+ type=str,
+ default="compactor_triton",
+ choices=backend_choices,
+ help=f"Attention backend to use. Choices: {backend_choices}",
+ )
+ parser.add_argument(
+ "--leverage-sketch-size",
+ type=int,
+ default=48,
+ help="Leverage sketch size for compactor attention.",
+ )
+ parser.add_argument(
+ "--max-new-tokens",
+ type=int,
+ default=256,
+ help="Maximum number of new tokens to generate.",
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=0.0,
+ help="Sampling temperature (0 is greedy).",
+ )
+ method_choices = [m.name.lower() for m in CompressionMethod]
+ parser.add_argument(
+ "--compression-method",
+ type=str,
+ default="compactor",
+ choices=method_choices,
+ help=f"Compression method. Choices: {method_choices}",
+ )
+ parser.add_argument(
+ "--chunk-size",
+ type=int,
+ default=2048,
+ help="Chunk size for chunked compression.",
+ )
+ parser.add_argument(
+ "--no-chunked-compression",
+ dest="do_chunked_compression",
+ action="store_false",
+ help="Disable leverage chunked compression (enabled by default).",
+ )
+ parser.set_defaults(do_chunked_compression=True)
+ parser.add_argument(
+ "--seq-compression-ratio",
+ type=float,
+ default=0.5,
+ help="Compression ratio for SequenceCompressionParams.",
+ )
+ parser.add_argument(
+ "--protected-first-tokens",
+ type=int,
+ default=8,
+ help="Number of protected tokens at the beginning of each sequence.",
+ )
+ parser.add_argument(
+ "--extra-protected-last-tokens",
+ type=int,
+ default=16,
+ help=(
+ "Extra number of protected tokens at the end, in addition to the "
+ "tokenized length of answer_prefix+question."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer-add-generation-prompt",
+ action="store_true",
+ help="Set tokenizer_kwargs['add_generation_prompt']=True (default False).",
+ )
+ parser.add_argument(
+ "--tokenizer-enable-thinking",
+ action="store_true",
+ help="Set tokenizer_kwargs['enable_thinking']=True (default False).",
+ )
+ parser.add_argument(
+ "--no-tokenizer-continue-final-message",
+ dest="tokenizer_continue_final_message",
+ action="store_false",
+ help="Set tokenizer_kwargs['continue_final_message']=False (default True).",
+ )
+ parser.set_defaults(tokenizer_continue_final_message=True)
+
+ parser.add_argument(
+ "--results-dir",
+ type=str,
+ default="results",
+ help="Directory to save detailed evaluation results.",
+ )
+
+ return parser.parse_args()
+
+
+def main(args: argparse.Namespace) -> None:
+ torch.manual_seed(args.seed)
+ logging.basicConfig(
+ level=getattr(logging, args.log_level.upper(), logging.INFO),
+ format="%(asctime)s %(levelname)s: %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ if args.dataset_parquet:
+ logger.info(
+ "Loading local parquet dataset from %s (split=%s)",
+ args.dataset_parquet,
+ args.dataset_split,
+ )
+ # datasets supports a file path or glob pattern via data_files.
+ dataset = load_dataset(
+ "parquet",
+ data_files=args.dataset_parquet,
+ split=args.dataset_split,
+ )
+ else:
+ logger.info(
+ "Loading dataset %s (length=%s, split=%s)",
+ "simonjegou/ruler",
+ args.dataset_length,
+ args.dataset_split,
+ )
+ dataset = load_dataset(
+ "simonjegou/ruler",
+ args.dataset_length,
+ split=args.dataset_split,
+ )
+ if args.seed is not None and args.seed >= 0:
+ logger.info("Shuffling dataset with seed %d", args.seed)
+ dataset = dataset.shuffle(seed=args.seed)
+ if not (0 < args.fraction <= 1.0):
+ raise ValueError("--fraction must be in the interval (0, 1].")
+ if args.fraction < 1.0:
+ n_examples = max(1, int(len(dataset) * args.fraction))
+ logger.info(
+ "Using %.2f fraction of data: %d / %d examples",
+ args.fraction,
+ n_examples,
+ len(dataset),
+ )
+ dataset = dataset.select(range(n_examples))
+ else:
+ logger.info("Using full dataset: %d examples", len(dataset))
+ tokenizer_kwargs = {
+ "add_generation_prompt": args.tokenizer_add_generation_prompt,
+ "enable_thinking": args.tokenizer_enable_thinking,
+ "continue_final_message": args.tokenizer_continue_final_message,
+ }
+ messages = [
+ [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ },
+ {
+ "role": "user",
+ "content": example["context"] + " " + example["question"],
+ },
+ {
+ "role": "assistant",
+ "content": example["answer_prefix"],
+ },
+ ]
+ for example in dataset
+ ]
+ attention_backend = AttentionBackend[args.attention_backend.upper()]
+ compression_method = CompressionMethod[args.compression_method.upper()]
+ logger.info("Using model: %s", args.model)
+ model_path = args.model if os.path.isdir(args.model) else None
+ if model_path is not None:
+ logger.info("Detected local model path: %s", model_path)
+ config = LLMConfig(
+ args.model,
+ path=model_path,
+ max_num_seqs=args.max_num_seqs,
+ gpu_memory_utilization=args.gpu_memory_utilization,
+ tensor_parallel_size=args.tensor_parallel_size,
+ max_model_len=args.max_model_len,
+ enforce_eager=args.enforce_eager,
+ attention_backend=attention_backend,
+ leverage_sketch_size=args.leverage_sketch_size,
+ )
+ llm = LLM(config)
+
+ end_protected_lengths = [
+ args.extra_protected_last_tokens
+ + len(
+ llm.tokenizer(example["answer_prefix"] + example["question"])["input_ids"]
+ )
+ for example in dataset
+ ]
+
+ per_sequence_compression_params = [
+ SequenceCompressionParams(
+ args.seq_compression_ratio,
+ protected_first_tokens=args.protected_first_tokens,
+ protected_last_tokens=end_protected_length,
+ )
+ for end_protected_length in end_protected_lengths
+ ]
+
+ # Sampling params
+ sampling_params = SamplingParams(
+ max_new_tokens=args.max_new_tokens,
+ temperature=args.temperature,
+ )
+
+ # Batch compression params
+ batch_compression_params = BatchCompressionParams(
+ compression_method=compression_method,
+ do_chunked_compression=args.do_chunked_compression,
+ chunk_size=args.chunk_size,
+ )
+ logger.info("Running generate_chat on %d examples.", len(messages))
+ responses = llm.generate_chat(
+ messages,
+ sampling_params,
+ batch_compression_params,
+ per_sequence_compression_params=per_sequence_compression_params,
+ tokenizer_kwargs=tokenizer_kwargs,
+ return_sequences=False,
+ )
+ logger.info("Scoring responses.")
+ results = {}
+ per_example = []
+
+ all_sum, all_count = 0.0, 0
+
+ for idx, (example, response) in enumerate(zip(dataset, responses)):
+ task = example["task"]
+ answer = example["answer"]
+ score = score_function(
+ generated=response,
+ ground_truth=answer,
+ task_category=task,
+ )
+ if task not in results:
+ results[task] = []
+ results[task].append(score)
+
+ all_sum += score
+ all_count += 1
+
+ per_example.append(
+ {
+ "index": idx,
+ "task": task,
+ "context": example["context"],
+ "question": example["question"],
+ "answer_prefix": example["answer_prefix"],
+ "ground_truth": answer,
+ "generated": response,
+ "score": score,
+ "compression_params": {
+ "seq_compression_ratio": args.seq_compression_ratio,
+ "protected_first_tokens": args.protected_first_tokens,
+ "protected_last_tokens": end_protected_lengths[idx],
+ },
+ }
+ )
+
+ per_task_summary = {}
+ for task, scores in results.items():
+ this_task_sum = sum(scores)
+ this_task_count = len(scores)
+ avg = this_task_sum / this_task_count
+ print(task, f"{avg:.3f}")
+ per_task_summary[task] = {
+ "avg_score": avg,
+ "num_examples": this_task_count,
+ "sum_scores": this_task_sum,
+ }
+
+ overall_avg = all_sum / all_count if all_count > 0 else 0.0
+ print(f"ALL: {overall_avg:.3f}")
+
+ os.makedirs(args.results_dir, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ safe_model_name = args.model.replace("/", "_")
+ base_name = f"ruler_{args.dataset_length}_{safe_model_name}_{timestamp}"
+
+ summary_path = os.path.join(args.results_dir, base_name + "_summary.json")
+ details_path = os.path.join(args.results_dir, base_name + "_details.jsonl")
+
+ logger.info("Saving summary to %s", summary_path)
+ with open(summary_path, "w", encoding="utf-8") as f:
+ json.dump(
+ {
+ "timestamp": timestamp,
+ "model": args.model,
+ "dataset": "simonjegou/ruler",
+ "dataset_length": args.dataset_length,
+ "num_examples": len(dataset),
+ "overall_avg_score": overall_avg,
+ "per_task": per_task_summary,
+ "arguments": vars(args), # all CLI args
+ },
+ f,
+ ensure_ascii=False,
+ indent=2,
+ )
+
+ logger.info("Saving per-example details to %s", details_path)
+ with open(details_path, "w", encoding="utf-8") as f:
+ for row in per_example:
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+if __name__ == "__main__":
+ main(parse_args())
+
+
+
+# Example invocation: HIP_LAUNCH_BLOCKING=1 TORCHDYNAMO_DISABLE=1 python eval_ruler.py --dataset-parquet /home/laibao/proj/kvpress/compactor-vllm/evaluate/test-00000-of-00001.parquet --dataset-split train --model /mnt/data/llm-models/Qwen3-8B/ --compression-method compactor --seq-compression-ratio 1 --enforce-eager
\ No newline at end of file
diff --git a/vllm/compactor-vllm/evaluate/longbench_config/dataset2maxlen.json b/vllm/compactor-vllm/evaluate/longbench_config/dataset2maxlen.json
new file mode 100644
index 0000000000000000000000000000000000000000..79d0d9990e5799c845ebcf839c1ee1a4ff14873e
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/longbench_config/dataset2maxlen.json
@@ -0,0 +1,23 @@
+{
+ "narrativeqa": 128,
+ "qasper": 128,
+ "multifieldqa_en": 64,
+ "multifieldqa_zh": 64,
+ "hotpotqa": 32,
+ "2wikimqa": 32,
+ "musique": 32,
+ "dureader": 128,
+ "gov_report": 512,
+ "qmsum": 512,
+ "multi_news": 512,
+ "vcsum": 512,
+ "trec": 64,
+ "triviaqa": 32,
+ "samsum": 128,
+ "lsht": 64,
+ "passage_count": 32,
+ "passage_retrieval_en": 32,
+ "passage_retrieval_zh": 32,
+ "lcc": 64,
+ "repobench-p": 64
+}
\ No newline at end of file
diff --git a/vllm/compactor-vllm/evaluate/longbench_config/dataset2prompt.json b/vllm/compactor-vllm/evaluate/longbench_config/dataset2prompt.json
new file mode 100644
index 0000000000000000000000000000000000000000..faf6cc0f847baadc42c5178c6f1f8e93bb0730b7
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/longbench_config/dataset2prompt.json
@@ -0,0 +1,23 @@
+{
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: \n\n\n{context}\n \n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: \n\n\n{context}\n \n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n\n{context}\n \n\n Now, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n\n{context}\n \n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:\n\n\n{context}\n \n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n\n\n{context}\n \n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n\n\n{context}\n \n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n\n\n{context}\n \n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n\n\n{context}\n \n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n\n{context}\n \n\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n\n{context}\n \n\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n\n{context}\n \n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n\n{context}\n \n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
+ "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n\n{context}\n \n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
+}
\ No newline at end of file
diff --git a/vllm/compactor-vllm/evaluate/longbench_config/dataset2prompt_taskagnostic.json b/vllm/compactor-vllm/evaluate/longbench_config/dataset2prompt_taskagnostic.json
new file mode 100644
index 0000000000000000000000000000000000000000..faf6cc0f847baadc42c5178c6f1f8e93bb0730b7
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/longbench_config/dataset2prompt_taskagnostic.json
@@ -0,0 +1,23 @@
+{
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: \n\n\n{context}\n \n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: \n\n\n{context}\n \n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n\n{context}\n \n\n Now, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n\n{context}\n \n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n\n\n{context}\n \n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:\n\n\n{context}\n \n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n\n\n{context}\n \n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n\n\n{context}\n \n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n\n\n{context}\n \n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n\n\n{context}\n \n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n\n{context}\n \n\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n\n{context}\n \n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n\n{context}\n \n\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n\n{context}\n \n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n\n{context}\n \n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
+ "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n\n{context}\n \n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
+}
\ No newline at end of file
diff --git a/vllm/compactor-vllm/evaluate/longbench_metrics.py b/vllm/compactor-vllm/evaluate/longbench_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfdac7a67267b6390aa2be1abba1b9de6002fa0d
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/longbench_metrics.py
@@ -0,0 +1,176 @@
+import re
+import string
+from collections import Counter
+
+import jieba
+from fuzzywuzzy import fuzz
+from rouge import Rouge
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def normalize_zh_answer(s):
+ """Lower text and remove punctuation, extra whitespace."""
+
+ def white_space_fix(text):
+ return "".join(text.split())
+
+ def remove_punc(text):
+ cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+ all_punctuation = set(string.punctuation + cn_punctuation)
+ return "".join(ch for ch in text if ch not in all_punctuation)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_punc(lower(s)))
+
+
+def count_score(prediction, ground_truth, **kwargs):
+ numbers = re.findall(r"\d+", prediction)
+ right_num = 0
+ for number in numbers:
+ if str(number) == str(ground_truth):
+ right_num += 1
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+ return float(final_score)
+
+
+def retrieval_score(prediction, ground_truth, **kwargs):
+ pattern = r"Paragraph (\d+)"
+ matches = re.findall(pattern, ground_truth)
+ ground_truth_id = matches[0]
+ numbers = re.findall(r"\d+", prediction)
+ right_num = 0
+ for number in numbers:
+ if str(number) == str(ground_truth_id):
+ right_num += 1
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+ return float(final_score)
+
+
+def retrieval_zh_score(prediction, ground_truth, **kwargs):
+ pattern = r"段落(\d+)"
+ matches = re.findall(pattern, ground_truth)
+ ground_truth_id = matches[0]
+ numbers = re.findall(r"\d+", prediction)
+ right_num = 0
+ for number in numbers:
+ if str(number) == str(ground_truth_id):
+ right_num += 1
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+ return float(final_score)
+
+
+def code_sim_score(prediction, ground_truth, **kwargs):
+ all_lines = prediction.lstrip("\n").split("\n")
+ prediction = ""
+ for line in all_lines:
+ if ("`" not in line) and ("#" not in line) and ("//" not in line):
+ prediction = line
+ break
+ return fuzz.ratio(prediction, ground_truth) / 100
+
+
+def classification_score(prediction, ground_truth, **kwargs):
+    """Return 1/k if ground_truth is among the k class names matched in prediction, else 0.0.
+
+    kwargs must contain "all_classes", the iterable of candidate class names.
+    """
+    em_match_list = [name for name in kwargs["all_classes"] if name in prediction]
+    # Discard matches that are strict substrings of the ground truth. Build a new
+    # list: calling remove() on the list being iterated skips the next element.
+    em_match_list = [
+        term for term in em_match_list if term == ground_truth or term not in ground_truth
+    ]
+    if ground_truth in em_match_list:
+        return 1.0 / len(em_match_list)
+    return 0.0
+
+
+def rouge_score(prediction, ground_truth, **kwargs):
+    """Return the ROUGE-L F-score of prediction against ground_truth (0.0 if scoring raises)."""
+    try:
+        scores = Rouge().get_scores([prediction], [ground_truth], avg=True)
+    except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
+        return 0.0
+    return scores["rouge-l"]["f"]
+
+
+def rouge_zh_score(prediction, ground_truth, **kwargs):
+ prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
+ ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
+ score = rouge_score(prediction, ground_truth)
+ return score
+
+
+def f1_score(prediction, ground_truth, **kwargs):
+    """Return the multiset-overlap F1 of two token sequences (0.0 when nothing overlaps)."""
+    common = Counter(prediction) & Counter(ground_truth)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0.0  # float on every path (was int 0)
+    precision = num_same / len(prediction)
+    recall = num_same / len(ground_truth)
+    return (2 * precision * recall) / (precision + recall)
+
+
+def qa_f1_score(prediction, ground_truth, **kwargs):
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(ground_truth)
+
+ prediction_tokens = normalized_prediction.split()
+ ground_truth_tokens = normalized_ground_truth.split()
+ return f1_score(prediction_tokens, ground_truth_tokens)
+
+
+def qa_f1_zh_score(prediction, ground_truth, **kwargs):
+ prediction_tokens = list(jieba.cut(prediction, cut_all=False))
+ ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
+ prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
+ ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
+ prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
+ ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
+ return f1_score(prediction_tokens, ground_truth_tokens)
+
+
+dataset2metric = {
+ "narrativeqa": qa_f1_score,
+ "qasper": qa_f1_score,
+ "multifieldqa_en": qa_f1_score,
+ "multifieldqa_zh": qa_f1_zh_score,
+ "hotpotqa": qa_f1_score,
+ "2wikimqa": qa_f1_score,
+ "musique": qa_f1_score,
+ "dureader": rouge_zh_score,
+ "gov_report": rouge_score,
+ "qmsum": rouge_score,
+ "multi_news": rouge_score,
+ "vcsum": rouge_zh_score,
+ "trec": classification_score,
+ "triviaqa": qa_f1_score,
+ "samsum": rouge_score,
+ "lsht": classification_score,
+ "passage_retrieval_en": retrieval_score,
+ "passage_count": count_score,
+ "passage_retrieval_zh": retrieval_zh_score,
+ "lcc": code_sim_score,
+ "repobench-p": code_sim_score,
+}
diff --git a/vllm/compactor-vllm/evaluate/ruler_metrics.py b/vllm/compactor-vllm/evaluate/ruler_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dbda53e7d30214d2e1f42a5e1a51baa2d32251f
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/ruler_metrics.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import re
+from typing import List
+
+import pandas as pd
+
+
def string_match_part(preds, refs):
    """Percentage of predictions that contain at least one of their references.

    Matching is a case-insensitive substring test.

    :param preds: sequence of predicted strings.
    :param refs: sequence, parallel to ``preds``, of iterables of acceptable
        reference strings for each prediction.
    :return: score in ``[0, 100]`` rounded to two decimals. ``0.0`` is
        returned when ``preds`` is empty, and a prediction whose reference
        list is empty counts as a miss (both cases used to raise).
    """
    if len(preds) == 0:
        return 0.0

    hits = 0.0
    for pred, ref in zip(preds, refs):
        pred_lower = pred.lower()  # lower-case once per prediction, not per reference
        if any(r.lower() in pred_lower for r in ref):
            hits += 1.0
    return round(hits / len(preds) * 100, 2)
+
+
def string_match_all(preds, refs):
    """Average fraction of references found in each prediction, as a percentage.

    Matching is a case-insensitive substring test. Each prediction
    contributes ``(#references found) / (#references)``.

    :param preds: sequence of predicted strings.
    :param refs: sequence, parallel to ``preds``, of sized iterables of
        reference strings that should all appear in the prediction.
    :return: score in ``[0, 100]`` rounded to two decimals. ``0.0`` is
        returned when ``preds`` is empty, and a prediction with an empty
        reference list contributes ``0`` (both cases used to raise
        ``ZeroDivisionError``).
    """
    if len(preds) == 0:
        return 0.0

    total = 0.0
    for pred, ref in zip(preds, refs):
        if len(ref) == 0:
            continue  # nothing to match: counts as zero recall
        pred_lower = pred.lower()  # lower-case once per prediction
        found = sum(1.0 for r in ref if r.lower() in pred_lower)
        total += found / len(ref)
    return round(total / len(preds) * 100, 2)
+
+
def calculate_metrics(df: pd.DataFrame) -> dict:
    """Compute a ``string_match`` score for every task in ``df``.

    The ``predicted_answer`` column is cleaned in place (surrounding
    whitespace and ASCII control characters ``\\x00``-``\\x1f`` removed), so
    the caller's frame is modified. Rows are then grouped by the ``task``
    column; tasks whose name starts with ``qa`` (the part before the first
    underscore) are scored with ``string_match_part``, all others with
    ``string_match_all``.

    :param df: frame with ``task``, ``predicted_answer`` and ``answer`` columns.
    :return: ``{task: {"string_match": score}}``.
    """
    control_chars = re.compile(r"[\x00-\x1f]")

    def _clean(text):
        return control_chars.sub("", text.strip()).strip()

    df["predicted_answer"] = df["predicted_answer"].apply(_clean)

    results = {}
    for task_name, group in df.groupby("task"):
        is_qa = task_name.split("_")[0] == "qa"
        scorer = string_match_part if is_qa else string_match_all
        predictions = group["predicted_answer"].tolist()
        references = group["answer"].tolist()
        results[task_name] = {"string_match": scorer(predictions, references)}
    return results
+
+
def score_function(*, generated, ground_truth: List[str], task_category: str):
    """Score a single generation against its reference strings.

    ``generated`` is stripped and has ASCII control characters
    (``\\x00``-``\\x1f``) removed. The part of ``task_category`` before the
    first underscore selects the metric: ``qa`` uses ``string_match_part``,
    anything else uses ``string_match_all``.

    :return: the selected metric evaluated on this one example.
    """
    cleaned = re.compile(r"[\x00-\x1f]").sub("", generated.strip()).strip()
    family = task_category.split("_")[0]
    if family == "qa":
        scorer = string_match_part
    else:
        scorer = string_match_all
    return scorer([cleaned], [ground_truth])
diff --git a/vllm/compactor-vllm/evaluate/test-00000-of-00001.parquet b/vllm/compactor-vllm/evaluate/test-00000-of-00001.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..a4899bd5767998cd65777b07f2d5c4fd7a4973c9
Binary files /dev/null and b/vllm/compactor-vllm/evaluate/test-00000-of-00001.parquet differ
diff --git a/vllm/compactor-vllm/evaluate/test.py b/vllm/compactor-vllm/evaluate/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfd4e2275190ed3f12b1f7f6fae6da04ce7e2349
--- /dev/null
+++ b/vllm/compactor-vllm/evaluate/test.py
@@ -0,0 +1,218 @@
+import argparse
+import inspect
+import logging
+import os
+import sys
+from pathlib import Path
+
+
def _maybe_add_src_to_path() -> None:
    """Put the sibling ``src`` directory at the front of ``sys.path`` if present.

    The directory is looked up two levels above this file (``<root>/src`` for
    a script at ``<root>/<dir>/<script>.py``), which lets the script run
    without ``pip install -e .``. Nothing happens when the directory does not
    exist or is already on ``sys.path``.
    """
    candidate = Path(__file__).resolve().parents[1] / "src"
    if not candidate.is_dir():
        return
    entry = str(candidate)
    if entry not in sys.path:
        sys.path.insert(0, entry)
+
+
+_maybe_add_src_to_path()
+
+from compactor_vllm import LLM, LLMConfig, SamplingParams # noqa: E402
+from compactor_vllm.compression import ( # noqa: E402
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+)
+from compactor_vllm.config.engine_config import AttentionBackend # noqa: E402
+
+
def _parse_args() -> argparse.Namespace:
    """Build the smoke-test command line and parse ``sys.argv``.

    Defaults for ``--model``, ``--tp``, ``--nccl-port`` and
    ``--gpu-memory-utilization`` are read from the ``MODEL``, ``TP``,
    ``NCCL_PORT`` and ``GPU_MEMORY_UTILIZATION`` environment variables when
    they are set. Boolean options come as ``--x`` / ``--no-x`` pairs writing
    to the same destination.
    """
    env = os.environ
    parser = argparse.ArgumentParser(
        description="Minimal smoke test for compactor-vllm (no speculative decoding)."
    )
    add = parser.add_argument

    def add_toggle(flag, dest, default, help_on, help_off):
        # Register --<flag> / --no-<flag> on one destination with a default.
        add(f"--{flag}", dest=dest, action="store_true", help=help_on)
        add(f"--no-{flag}", dest=dest, action="store_false", help=help_off)
        parser.set_defaults(**{dest: default})

    # --- engine configuration -------------------------------------------
    add(
        "--model",
        type=str,
        default=env.get("MODEL", "/mnt/data/llm-models/Qwen3-8B"),
        help="Local model directory or HF id. In the container this is usually a local dir.",
    )
    add(
        "--tp",
        type=int,
        default=int(env.get("TP", "1")),
        help="Tensor parallel size (world size).",
    )
    add(
        "--nccl-port",
        type=int,
        default=int(env.get("NCCL_PORT", "1218")),
        help="TCP port for torch.distributed init (only used for NCCL init_method=tcp://localhost:).",
    )
    add("--max-model-len", type=int, default=2048)
    add("--max-num-seqs", type=int, default=2)
    add(
        "--gpu-memory-utilization",
        type=float,
        default=float(env.get("GPU_MEMORY_UTILIZATION", "0.9")),
        help="Fraction of total GPU memory used for KV cache + activations.",
    )
    add(
        "--attention-backend",
        type=str,
        default="compactor_triton",
        choices=[backend.name.lower() for backend in AttentionBackend],
    )

    # --- compression ------------------------------------------------------
    add(
        "--compression-method",
        type=str,
        default="compactor",
        choices=[method.name.lower() for method in CompressionMethod],
    )
    add(
        "--compression-ratio",
        type=float,
        default=0.8,
        help="Sequence-level compression ratio (e.g. 0.8 keeps 80%% of tokens).",
    )
    add("--chunk-size", type=int, default=512)
    # Chunked compression is on by default; only the "off" switch exists.
    add(
        "--no-chunked-compression",
        dest="do_chunked_compression",
        action="store_false",
    )
    parser.set_defaults(do_chunked_compression=True)

    # --- generation -------------------------------------------------------
    add("--prompt", type=str, default="用一句话介绍你自己,给我讲一个故事,200字左右。")
    add("--max-new-tokens", type=int, default=64)
    add(
        "--temperature",
        type=float,
        default=0.0,
        help="0.0 = greedy decoding (recommended for smoke tests).",
    )

    # --- tokenizer.apply_chat_template switches ---------------------------
    chat_template_toggles = (
        ("enable_thinking", False),
        ("add_generation_prompt", True),
        ("continue_final_message", False),
    )
    for kwarg, default in chat_template_toggles:
        add_toggle(
            "tokenizer-" + kwarg.replace("_", "-"),
            "tokenizer_" + kwarg,
            default,
            f"Pass {kwarg}=True to tokenizer.apply_chat_template (if supported).",
            f"Pass {kwarg}=False to tokenizer.apply_chat_template (if supported).",
        )

    # --- decoding / logging -----------------------------------------------
    add_toggle(
        "skip-special-tokens",
        "skip_special_tokens",
        True,
        "Skip special tokens in output decoding (recommended).",
        "Keep special tokens in output decoding (e.g. <|im_end|>).",
    )
    add(
        "--log-level",
        type=str,
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
    )
    return parser.parse_args()
+
+
def main() -> None:
    """Run one chat generation through compactor-vllm and print the result.

    Parses the command line, builds an ``LLMConfig`` / ``LLM`` from it, sends
    the single ``--prompt`` as a user message with the requested compression
    settings, prints the first output and shuts the engine down.

    ``llm.exit()`` is called in a ``finally`` block so the engine is torn
    down even when tokenization or generation raises.
    """
    args = _parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper()),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    # CLI choices are the lower-cased enum member names; map them back.
    attention_backend = AttentionBackend[args.attention_backend.upper()]
    compression_method = CompressionMethod[args.compression_method.upper()]

    model = args.model
    cfg = LLMConfig(
        model=model,
        path=model,
        tensor_parallel_size=int(args.tp),
        nccl_port=int(args.nccl_port),
        max_model_len=int(args.max_model_len),
        max_num_seqs=int(args.max_num_seqs),
        gpu_memory_utilization=float(args.gpu_memory_utilization),
        enforce_eager=True,
        attention_backend=attention_backend,
        show_progress_bar=False,
    )
    llm = LLM(cfg)
    try:
        tokenizer_kwargs = {
            "add_generation_prompt": bool(args.tokenizer_add_generation_prompt),
            "enable_thinking": bool(args.tokenizer_enable_thinking),
            "continue_final_message": bool(args.tokenizer_continue_final_message),
        }
        if (
            tokenizer_kwargs["add_generation_prompt"]
            and tokenizer_kwargs["continue_final_message"]
        ):
            # HF tokenizer API rejects these being simultaneously True.
            tokenizer_kwargs["continue_final_message"] = False

        # Be defensive: only pass kwargs supported by this tokenizer build.
        try:
            supported = set(
                inspect.signature(llm.tokenizer.apply_chat_template).parameters
            )
        except (TypeError, ValueError):
            # The signature cannot be introspected; keep the best-effort
            # behaviour of passing every kwarg through, but leave a trace.
            logging.debug(
                "could not inspect apply_chat_template; passing all tokenizer kwargs"
            )
        else:
            tokenizer_kwargs = {
                key: value
                for key, value in tokenizer_kwargs.items()
                if key in supported
            }

        outs = llm.generate_chat(
            [[{"role": "user", "content": args.prompt}]],
            sampling_params=SamplingParams(
                temperature=float(args.temperature),
                max_new_tokens=int(args.max_new_tokens),
            ),
            batch_compression_params=BatchCompressionParams(
                compression_method=compression_method,
                do_chunked_compression=bool(args.do_chunked_compression),
                chunk_size=int(args.chunk_size),
            ),
            per_sequence_compression_params=SequenceCompressionParams(
                compression_ratio=float(args.compression_ratio),
            ),
            tokenizer_kwargs=tokenizer_kwargs,
            detokenizer_kwargs={"skip_special_tokens": bool(args.skip_special_tokens)},
        )
        print(outs[0])
    finally:
        # Always release the engine, including on failure.
        llm.exit()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vllm/compactor-vllm/flash_attn_vs_triton_h100.png b/vllm/compactor-vllm/flash_attn_vs_triton_h100.png
new file mode 100644
index 0000000000000000000000000000000000000000..e74103b1a0236116330917b10c01f4deb3181d50
Binary files /dev/null and b/vllm/compactor-vllm/flash_attn_vs_triton_h100.png differ
diff --git a/vllm/compactor-vllm/pip_list.txt b/vllm/compactor-vllm/pip_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4fb0b822c41f12e16a11766d8b88f863ea2a4790
--- /dev/null
+++ b/vllm/compactor-vllm/pip_list.txt
@@ -0,0 +1,274 @@
+Package Version
+---------------------------------- ------------------------------------------
+accelerate 1.12.0
+addict 2.4.0
+aiofiles 25.1.0
+aiohappyeyeballs 2.6.1
+aiohttp 3.13.2
+aiohttp-cors 0.8.1
+aiosignal 1.4.0
+airportsdata 20250909
+amdsmi 24.5.3+02cbffb.dirty
+annotated-doc 0.0.4
+annotated-types 0.7.0
+anyio 4.12.0
+apex 1.5.0+das.opt1.dtk25042
+astor 0.8.1
+async-timeout 5.0.1
+attrs 25.4.0
+backports.asyncio.runner 1.2.0
+blake3 1.0.8
+blinker 1.9.0
+boto3 1.42.10
+botocore 1.42.10
+cachetools 6.2.4
+certifi 2025.11.12
+charset-normalizer 3.4.4
+click 8.2.1
+cloudpickle 3.1.2
+cmake 3.29.0
+coloredlogs 15.0.1
+colorful 0.5.8
+compressed-tensors 0.10.2
+contourpy 1.3.2
+cryptography 3.4.8
+cupy 12.3.0
+cycler 0.12.1
+datasets 4.4.1
+dbus-python 1.2.18
+dcu-megatron 0.13.0+das.opt1.dtk25042
+deepspeed 0.15.4+das.opt1.dtk25042
+depyf 0.18.0
+dgl 2.2.1+das.opt1.dtk25042
+dill 0.4.0
+diskcache 5.6.3
+distlib 0.4.0
+distro 1.7.0
+dnspython 2.8.0
+dropout_layer_norm 2.6.1+das.opt1.dtk2504
+eft 0.0.7
+einops 0.8.1
+email-validator 2.3.0
+exceptiongroup 1.3.1
+fastapi 0.124.4
+fastapi-cli 0.0.16
+fastapi-cloud-cli 0.6.0
+fastar 0.8.0
+fastpt 2.1.1+das.dtk25042
+fastrlock 0.8.3
+filelock 3.20.1
+flash_attn 2.6.1+das.opt1.dtk2504.20251216.gbd5c0f0c
+flash_mla 1.0.0+das.opt1.dtk2504.20251210.g124c5ef1
+Flask 3.1.2
+flatbuffers 25.9.23
+fonttools 4.61.1
+frozenlist 1.8.0
+fsspec 2025.12.0
+fused_dense_lib 2.6.1+das.opt1.dtk2504
+future 1.0.0
+gguf 0.17.1
+google-api-core 2.28.1
+google-auth 2.45.0
+googleapis-common-protos 1.72.0
+greenlet 3.3.0
+grouped-gemm 0.5.0+das.dtk2504
+grouped-gemm-int4 0.5.0+das.dtk2504
+grpcio 1.76.0
+h11 0.16.0
+h2 4.3.0
+hf-xet 1.2.0
+hiredis 3.3.0
+hjson 3.1.0
+hpack 4.1.0
+httpcore 1.0.9
+httplib2 0.20.2
+httptools 0.7.1
+httpx 0.28.1
+huggingface-hub 0.36.0
+humanfriendly 10.0
+humanize 4.14.0
+Hypercorn 0.18.0
+hyperframe 6.1.0
+hypothesis 5.35.1
+idna 3.11
+importlib_metadata 8.7.0
+iniconfig 2.3.0
+interegular 0.3.3
+itsdangerous 2.2.0
+jeepney 0.7.1
+Jinja2 3.1.6
+jiter 0.12.0
+jmespath 1.0.1
+jsonschema 4.25.1
+jsonschema-specifications 2025.9.1
+keyring 23.5.0
+kiwisolver 1.4.9
+lark 1.2.2
+launchpadlib 1.10.16
+lazr.restfulclient 0.14.4
+lazr.uri 1.0.6
+libnacl 2.1.0
+lightop 0.6.0+das.dtk25042.20251216.g3830d4e2
+llguidance 0.7.30
+llvmlite 0.44.0
+lm-format-enforcer 0.10.12
+lmslim 0.3.1+das.opt1.dtk25042.20251202.g07a5af3e
+markdown-it-py 4.0.0
+MarkupSafe 3.0.3
+matplotlib 3.10.8
+mdurl 0.1.2
+megatron-core 0.13.2
+mistral_common 1.8.6
+mmcv 2.2.0+das.opt1.dtk25042
+mmengine 0.10.7
+moe-w8a8 0.0.1+das.dtk2504
+moe-w8a8-prefill-gemm 0.0.1+das.dtk2504
+more-itertools 8.10.0
+mpmath 1.3.0
+msgpack 1.1.2
+msgspec 0.20.0
+multidict 6.7.0
+multiprocess 0.70.18
+nest-asyncio 1.6.0
+networkx 3.4.2
+ninja 1.11.1
+numa 1.4.6
+numba 0.61.2
+numpy 1.25.0
+nvidia-cublas-cu12 12.4.5.8
+nvidia-cuda-cupti-cu12 12.4.127
+nvidia-cuda-nvrtc-cu12 12.4.127
+nvidia-cuda-runtime-cu12 12.4.127
+nvidia-cudnn-cu12 9.1.0.70
+nvidia-cufft-cu12 11.2.1.3
+nvidia-curand-cu12 10.3.5.147
+nvidia-cusolver-cu12 11.6.1.9
+nvidia-cusparse-cu12 12.3.1.170
+nvidia-nccl-cu12 2.21.5
+nvidia-nvjitlink-cu12 12.4.127
+nvidia-nvtx-cu12 12.4.127
+oauthlib 3.2.0
+onnxruntime 1.19.2+das.opt1.dtk25042
+openai 1.90.0
+opencensus 0.11.4
+opencensus-context 0.1.3
+opencv-python 4.12.0.88
+opencv-python-headless 4.12.0.88
+opentelemetry-api 1.39.1
+opentelemetry-exporter-prometheus 0.60b1
+opentelemetry-proto 1.39.1
+opentelemetry-sdk 1.39.1
+opentelemetry-semantic-conventions 0.60b1
+outlines 0.1.11
+outlines_core 0.1.26
+packaging 25.0
+pandas 2.3.3
+partial-json-parser 0.2.1.1.post7
+peft 0.18.0
+pillow 12.0.0
+pip 25.3
+platformdirs 4.5.1
+pluggy 1.6.0
+priority 2.0.0
+prometheus_client 0.23.1
+prometheus-fastapi-instrumentator 7.1.0
+propcache 0.4.1
+proto-plus 1.26.1
+protobuf 6.33.2
+psutil 7.1.3
+py-cpuinfo 9.0.0
+py-spy 0.4.1
+pyarrow 22.0.0
+pyasn1 0.6.1
+pyasn1_modules 0.4.2
+pybase64 1.4.3
+pycountry 24.6.1
+pydantic 2.12.5
+pydantic_core 2.41.5
+pydantic-extra-types 2.10.6
+Pygments 2.19.2
+PyGObject 3.42.1
+PyHive 0.7.0
+PyJWT 2.3.0
+PyMySQL 1.1.2
+pyparsing 3.2.5
+pytest 9.0.2
+pytest-asyncio 1.3.0
+python-apt 2.4.0+ubuntu4
+python-dateutil 2.9.0.post0
+python-dotenv 1.2.1
+python-json-logger 4.0.0
+python-multipart 0.0.20
+PyTrie 0.4.0
+pytz 2025.2
+PyYAML 6.0.3
+pyzmq 27.1.0
+Quart 0.20.0
+ray 2.48.0
+redis 7.1.0
+referencing 0.37.0
+regex 2025.11.3
+requests 2.32.5
+rich 14.2.0
+rich-toolkit 0.17.0
+rignore 0.7.6
+rotary_emb 2.6.1+das.opt1.dtk2504
+rpds-py 0.30.0
+rsa 4.9.1
+runai-model-streamer 0.11.0
+runai-model-streamer-s3 0.11.0
+s3transfer 0.16.0
+safetensors 0.7.0
+scipy 1.15.3
+SecretStorage 3.3.1
+sentencepiece 0.2.1
+sentry-sdk 2.47.0
+setuptools 80.8.0
+setuptools-scm 9.2.2
+shellingham 1.5.4
+six 1.16.0
+smart_open 7.5.0
+sniffio 1.3.1
+sortedcontainers 2.4.0
+SQLAlchemy 2.0.45
+starlette 0.50.0
+sympy 1.13.1
+taskgroup 0.2.2
+tensorboardX 2.6.4
+tensorizer 2.12.0
+termcolor 3.2.0
+threadpoolctl 3.6.0
+tiktoken 0.12.0
+tokenizers 0.22.1
+tomli 2.3.0
+torch 2.5.1+das.opt1.dtk25042
+torchaudio 2.5.1+das.opt1.dtk25042
+torchdata 0.8.0
+torchvision 0.20.1+das.opt1.dtk25042
+tqdm 4.67.1
+transformer_engine 2.5.0+das.opt1.dtk25042
+transformers 4.57.3
+triton 3.1+das.opt1.dtk25042
+typer 0.20.0
+typer-slim 0.20.0
+typing_extensions 4.15.0
+typing-inspection 0.4.2
+tzdata 2025.3
+urllib3 2.6.2
+uvicorn 0.38.0
+uvloop 0.22.1
+virtualenv 20.35.4
+vllm 0.9.2+das.opt2.ffcc47b.dtk25042
+wadllib 1.3.6
+watchfiles 1.1.1
+websockets 15.0.1
+Werkzeug 3.1.4
+wheel 0.37.1
+wrapt 2.0.1
+wsproto 1.3.2
+xentropy_cuda_lib 2.6.1+das.opt1.dtk2504
+xgrammar 0.1.19
+xxhash 3.6.0
+yapf 0.43.0
+yarl 1.22.0
+zipp 3.23.0
diff --git a/vllm/compactor-vllm/pyproject.toml b/vllm/compactor-vllm/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..2196dc3fc933d5d9d1083835efa7fbeb1d66c7ea
--- /dev/null
+++ b/vllm/compactor-vllm/pyproject.toml
@@ -0,0 +1,23 @@
+[project]
+name = "compactor-vllm"
+description = "Fast KV Cache Compression for LLMs"
+version = "0.0.1"
+dependencies = [
+ # "triton>=3.5.0",
+ "transformers",
+ # "torch>=2.9.0",
+ "safetensors",
+ "tqdm",
+ "flash-attn",
+ "pytest"
+]
+requires-python = ">= 3.8"
+authors = [
+ {name = "Vivek Chari", email = "viveknchari@gmail.com"},
+]
+[project.optional-dependencies]
+evaluate = ["rouge", "pandas", "fuzzywuzzy"]
+[tool.ruff]
+exclude = [
+ "triton_kernels"
+]
diff --git a/vllm/compactor-vllm/src/compactor_vllm/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc62b89a2e56745a077fce61a70c15c4a391014
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/__init__.py
@@ -0,0 +1,17 @@
+from compactor_vllm.compression import CompressionMethod
+from compactor_vllm.config.engine_config import AttentionBackend, LLMConfig
+from compactor_vllm.config.sampling_params import SamplingParams
+from compactor_vllm.core.llm_engine import LLMEngine as _LLMEngine
+
+
class LLM(_LLMEngine):
    """Public name for the engine class; adds nothing on top of ``LLMEngine``."""
+
+
+__all__ = [
+ "LLMConfig",
+ "LLM",
+ "SamplingParams",
+ "AttentionBackend",
+ "CompressionMethod",
+]
diff --git a/vllm/compactor-vllm/src/compactor_vllm/attention/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/attention/compile_kernels.py b/vllm/compactor-vllm/src/compactor_vllm/attention/compile_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..9341ca6dbfe8c92d0a9eb1422658d582ded874b9
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/attention/compile_kernels.py
@@ -0,0 +1,261 @@
+import argparse
+import logging
+import math
+
+import torch
+from compactor_vllm.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+
+logger = logging.getLogger(__name__)
+
+
def build_mock_paged_cache_from_lengths(
    L_cache_per_b: torch.Tensor,
    HKV: int,
    D: int,
    PAGE_SIZE: int,
    N_LOGICAL_PAGES_MAX: int,
    device,
    dtype,
):
    """Build a random paged KV cache whose sequences have the given lengths.

    Every ``(batch, kv_head)`` pair owns ``N_LOGICAL_PAGES_MAX`` consecutive
    physical pages, numbered in ``(b, h, logical_page)`` order, so its tokens
    occupy one contiguous run of ``PAGE_SIZE * N_LOGICAL_PAGES_MAX`` cache
    rows. The first ``L_cache_per_b[b]`` rows of each run are filled with
    standard-normal values; the remaining rows stay zero.

    The fill is done with one slice assignment per batch entry instead of one
    assignment (and one host sync) per token, which matters when this is
    called repeatedly with lengths in the thousands.

    :param L_cache_per_b: 1-D integer tensor with the cached length of each
        batch entry; every value must be ``<= PAGE_SIZE * N_LOGICAL_PAGES_MAX``.
    :param HKV: number of KV heads.
    :param D: per-head dimension.
    :param PAGE_SIZE: tokens per physical page.
    :param N_LOGICAL_PAGES_MAX: page-table width per ``(batch, head)``.
    :param device: device for all returned tensors.
    :param dtype: dtype of the K/V caches.
    :return: ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where
        the caches are ``[CACHE_SIZE, D]``, ``page_table`` is int32
        ``[B, HKV, N_LOGICAL_PAGES_MAX]`` and ``seq_lens_bh`` is int32
        ``[B, HKV]``.
    """
    B = len(L_cache_per_b)
    max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
    assert (L_cache_per_b <= max_len).all()

    # Same length for every KV head of a batch entry.
    seq_lens_bh = (
        L_cache_per_b.to(device=device, dtype=torch.int32)
        .reshape(B, 1)
        .expand(B, HKV)
        .contiguous()
    )

    num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
    CACHE_SIZE = num_phys_pages * PAGE_SIZE

    # Unique physical page per (b, h, lp), assigned in row-major order.
    page_table = torch.arange(
        num_phys_pages, device=device, dtype=torch.int32
    ).reshape(B, HKV, N_LOGICAL_PAGES_MAX)

    K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
    V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)

    # With the row-major page table, token i of (b, h) lives at cache row
    # (b * HKV + h) * max_len + i, so a 4-D view exposes it as [b, h, i].
    k_view = K_cache.view(B, HKV, max_len, D)
    v_view = V_cache.view(B, HKV, max_len, D)
    for b, length in enumerate(L_cache_per_b.tolist()):  # one host sync in total
        length = int(length)
        if length <= 0:
            continue
        k_view[b, :, :length] = torch.randn(HKV, length, D, device=device, dtype=dtype)
        v_view[b, :, :length] = torch.randn(HKV, length, D, device=device, dtype=dtype)

    return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
+
+
def autotune_causal_sparse_varlen_with_cache(
    *,
    max_length: int = 16384,
    HKV: int = 8,
    HQ: int = 32,
    D: int = 128,
    PAGE_SIZE: int = 128,
    device: str = "cuda",
    dtype=torch.float16,
):
    """
    Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.

    A mock paged KV cache is built for a fixed batch of 4 sequences, and the
    kernel is invoked once for every ``(cached_length, appended_length)``
    pair drawn from ``[0, 256, 512, 1024, ...]`` (powers of two strictly
    below ``max_length``), skipping the all-zero pair. The calls exist only
    for their side effect of triggering kernel tuning; outputs are discarded.

    :param max_length: upper bound (exclusive) on the swept lengths; also
        sizes the mock page table.
    :param HKV: number of KV heads.
    :param HQ: number of query heads.
    :param D: per-head dimension; must be a power of two.
    :param PAGE_SIZE: tokens per physical KV page.
    :param device: torch device for all tensors.
    :param dtype: dtype for Q/K/V tensors.
    """
    import itertools

    import tqdm

    # Page-table width is a number of *pages*: ceil(max_length / PAGE_SIZE).
    # (Multiplying by PAGE_SIZE again would size the mock cache for
    # PAGE_SIZE times more tokens than are ever used.)
    N_LOGICAL_PAGES_MAX = (max_length + PAGE_SIZE - 1) // PAGE_SIZE
    B = 4

    # D must be a power of two (kernel requirement).
    assert (D & (D - 1)) == 0

    lengths_to_sweep = [0, 256]
    exponent = 9
    while (length := (1 << exponent)) < max_length:
        lengths_to_sweep.append(length)
        exponent += 1

    combos = list(itertools.product(lengths_to_sweep, repeat=2))
    logger.info(
        "tuning kernels. this may take a few minutes, "
        "but only needs to be run once per LLMConfig"
    )

    # These do not depend on the swept lengths; build them once.
    sm_scale = 1.0 / math.sqrt(D)
    # Identity batch mapping (local batch index == global)
    batch_mapping = torch.arange(B, device=device, dtype=torch.int32)

    for cache_l, append_l in tqdm.tqdm(combos):
        if cache_l + append_l == 0:
            continue

        L_cache_per_b = torch.tensor(
            [cache_l] * B,
            device=device,
            dtype=torch.int32,
        )
        K_cache, V_cache, page_table, seq_lens_bh, _ = (
            build_mock_paged_cache_from_lengths(
                L_cache_per_b=L_cache_per_b,
                HKV=HKV,
                D=D,
                PAGE_SIZE=PAGE_SIZE,
                N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
                device=device,
                dtype=dtype,
            )
        )

        # Every sequence appends the same number of tokens, so the cumulative
        # offsets are simply multiples of append_l.
        cu_seqlens_qk = torch.tensor(
            [append_l * j for j in range(B + 1)],
            dtype=torch.int32,
            device=device,
        )
        N = int(cu_seqlens_qk[-1].item())

        max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
        max_seqlen_k = seq_lens_bh.max().item()
        q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
        k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
        v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)

        causal_sparse_varlen_with_cache(
            q=q_raw,
            k_cache=K_cache,
            v_cache=V_cache,
            k=k_append_raw,
            v=v_append_raw,
            seq_lens_bh=seq_lens_bh,
            global_page_table=page_table,
            batch_mapping=batch_mapping,
            cu_seqlens_q=cu_seqlens_qk,
            HKV=HKV,
            PAGE_SIZE=PAGE_SIZE,
            sm_scale=sm_scale,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k_cache=max_seqlen_k,
        )
+
+
def _parse_args() -> argparse.Namespace:
    """Parse the command line of the kernel autotuning script.

    :return: namespace with ``max_length``, ``HKV``, ``HQ``, ``D``,
        ``page_size``, ``device``, ``dtype`` (still a string) and
        ``log_level``.
    """
    parser = argparse.ArgumentParser(
        # Adjacent literals are concatenated verbatim, so every fragment but
        # the last ends with a space to keep the sentences apart in --help.
        description="Autotune Triton kernels. "
        "Results are cached, so this should only need to be run once per configuration. "
        "This script doesn't need to be run, as the kernels will be autotuned at runtime "
        "if no cached autotuning data exists. Running this beforehand will prevent run-time "
        "autotuning, which will accelerate compactor-vllm at inference time."
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=16384,
        help="Maximum total sequence length to consider.",
    )
    parser.add_argument(
        "--HKV",
        type=int,
        default=8,
        help="Number of KV heads.",
    )
    parser.add_argument(
        "--HQ",
        type=int,
        default=32,
        help="Number of query heads.",
    )
    parser.add_argument(
        "--D",
        type=int,
        default=128,
        help="Per-head hidden dimension (must be power of 2).",
    )
    parser.add_argument(
        "--page-size",
        type=int,
        default=128,
        help="Page size (tokens per physical page).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="Dtype for tensors: one of {float16, fp16, bfloat16, bf16, float32, fp32}.",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
        help="Logging level.",
    )
    return parser.parse_args()
+
+
def _resolve_dtype(dtype_str: str):
    """Translate a dtype name given on the command line into a torch dtype.

    The lookup is case-insensitive. Accepted names are ``float16`` / ``fp16``
    / ``half``, ``bfloat16`` / ``bf16`` and ``float32`` / ``fp32``.

    :raises ValueError: if the name is not one of the accepted aliases.
    """
    aliases = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    resolved = aliases.get(dtype_str.lower())
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return resolved
+
+
def main():
    """Entry point: parse arguments, set up logging and run the autotune sweep."""
    args = _parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Collected once so the log line and the call below cannot drift apart;
    # insertion order matches the placeholders in the message.
    sweep_settings = {
        "max_length": args.max_length,
        "HKV": args.HKV,
        "HQ": args.HQ,
        "D": args.D,
        "PAGE_SIZE": args.page_size,
        "device": args.device,
        "dtype": _resolve_dtype(args.dtype),
    }
    logger.info(
        "Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
        "device=%s, dtype=%s",
        *sweep_settings.values(),
    )

    autotune_causal_sparse_varlen_with_cache(**sweep_settings)
+
+
if __name__ == "__main__":
    # Logging is configured inside main() from --log-level. Calling
    # logging.basicConfig() here first would install a root handler, after
    # which the call in main() does nothing and the flag would be ignored.
    main()
diff --git a/vllm/compactor-vllm/src/compactor_vllm/attention/sparse_decode_kernel.py b/vllm/compactor-vllm/src/compactor_vllm/attention/sparse_decode_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5f1e4f0943da5fe6a7c1871cd243644d285bb66
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/attention/sparse_decode_kernel.py
@@ -0,0 +1,401 @@
+import functools
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from compactor_vllm.utils.triton_compat import (
+ autotune as triton_autotune,
+ maybe_set_allocator,
+)
+
+
def head_sparse_decode_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale: float = None,
    key_split: int = None,
):
    """
    Decode-time head-sparse attention over a paged KV cache.

    This is a wrapper around the Triton decode kernel used during incremental
    generation. For each batch, we read the cached keys
    and values from a global paged KV buffer, apply causal attention with one
    new query token, and return the attention output.

    The KV cache is stored in a single global K/V tensor of shape
    ``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
    (batch, kv_head, token_idx) is mapped to a physical row in the cache by:

    1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
    2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.

    Grouped-query attention (GQA / MQA) is supported by passing more query
    heads than KV heads (``HQ`` must be a multiple of ``HKV``).

    The computation runs in two kernel launches: stage 1 produces one partial
    output and log-sum-exp per key split, stage 2 reduces them. With a single
    split the stage-1 result is returned directly.

    Args:
        :param q: Query tensor of shape ``[B, HQ, D]`` or ``[B, HQ, 1, D]``
            containing the new decode tokens for each sequence in the launch
            batch. Must be contiguous. (A 4-D input is unpacked as
            ``B, HQ, S, D`` and its ``S == 1`` axis is squeezed.)
        :param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
            backing buffer for all (batch, head) KV pages.
        :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
        :param seq_lens_bh: Tensor of shape ``[B, HKV]`` giving, for each
            local batch index and KV head, the number of valid cached tokens
            in the paged KV cache. It is converted to int32 internally.
        :param global_page_table: Tensor of shape
            ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
            ``(true_batch_idx, kv_head, logical_page)`` to a physical page id
            in the global cache.
        :param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
            index used by this call to the true batch row used to index
            ``global_page_table``.
        :param HKV: Number of KV heads.
        :param PAGE_SIZE: Number of tokens stored per physical KV page; must
            be a multiple of 32.
        :param sm_scale: Optional scaling factor applied to the attention logits
            before softmax. If ``None``, ``1 / sqrt(D)`` is used.
        :param key_split: Optional number of splits along the key sequence length.
            If > 1, the kernel will process the KV sequence in ``key_split``
            chunks to reduce on-chip memory usage. If ``None`` or 0, a
            heuristic is used.

    Returns:
        :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
            device and dtype as ``q``.
    """

    with torch.cuda.device(q.device):
        if q.ndim == 4:
            B, HQ, S, D = q.shape
            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
            q = q.squeeze(-2)
        else:
            assert q.ndim == 3
            B, HQ, D = q.shape

        CACHE_SIZE = k.shape[0]
        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
        GROUP_M = HQ // HKV
        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"

        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
        # Element type the kernels compute in; anything that is neither fp8
        # nor bf16 is treated as fp16.
        if FP8:
            kernel_dtype = tl.float8e5
        elif q.dtype == torch.bfloat16:
            kernel_dtype = tl.bfloat16
        else:
            kernel_dtype = tl.float16

        seq_lens_bh = seq_lens_bh.to(torch.int32)
        assert B <= 32767, "too many batches"
        assert global_page_table.shape[1] == HKV
        assert q.is_contiguous()
        assert (D & (D - 1)) == 0, "D must be a power of 2"
        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]

        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale
        if not key_split:  # None or 0 -> choose automatically
            # Bucket the longest sequence to a power of two (the smallest one
            # strictly above it) so the memoized heuristic sees few distinct
            # inputs. Reading the max forces a device-to-host sync.
            key_split = num_splits_heuristic(
                B * HKV,
                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
                num_sms=torch.cuda.get_device_properties(
                    q.device
                ).multi_processor_count,
                max_splits=12,
            )

        maybe_set_allocator(
            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
        )

        # stage 1 scratch: one partial output / log-sum-exp per key split
        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)
        # processes all queries for a KV head together
        # pointers are lowercase, CONSTANTS are upper
        grid1 = (B, HKV, key_split)
        _varkv_stage1_groupM[grid1](
            q=q,
            k=k,
            v=v,
            mid_o=mid_o,
            mid_lse=mid_lse,
            page_table_bhl=global_page_table,
            batch_mapping=batch_mapping,
            seq_lens_bh=seq_lens_bh.contiguous(),
            SM_SCALE=sm_scale,
            B=B,
            HKV=HKV,
            HQ=HQ,
            CACHE_SIZE=CACHE_SIZE,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            D=D,
            KEY_SPLIT=key_split,
            GROUP_M=GROUP_M,
            DTYPE=kernel_dtype,
            PAGE_SIZE=PAGE_SIZE,
        )

        if key_split == 1:
            # Nothing to reduce: the single partial result is the answer.
            return mid_o.squeeze(1).contiguous()

        # reduce partial results across splits
        output = torch.empty_like(q)
        grid2 = (B, HQ)
        _varkv_stage2_reduce[grid2](
            mid_o=mid_o,
            mid_lse=mid_lse,
            output=output,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            STRIDE_OBS=output.stride(0),
            STRIDE_OH=output.stride(1),
            B=B,
            HQ=HQ,
            D=D,  # type: ignore
            KEY_SPLIT=key_split,  # type: ignore
            DTYPE=kernel_dtype,
        )
        return output
+
+
# similar to flash attention split heuristic
@functools.lru_cache(maxsize=128)
def num_splits_heuristic(
    total_mblocks: int,
    max_seq_len: int,
    num_sms: int,
    max_splits: int,
) -> int:
    """Choose how many key-axis splits to launch for a decode step.

    Returns 1 when ``total_mblocks`` already reaches 80% of ``num_sms`` or
    the sequence is at most 1024 tokens. Otherwise each candidate split
    count ``s`` (while ``max_seq_len / s`` stays above 512 and ``s`` does not
    exceed ``min(max_splits, num_sms)``) is rated by wave efficiency
    ``n_waves / ceil(n_waves)`` with ``n_waves = total_mblocks * s / num_sms``,
    and the smallest ``s`` reaching 75% of the best efficiency is returned.
    Results are memoized per argument tuple.
    """
    already_saturated = total_mblocks >= 0.8 * num_sms
    if already_saturated or max_seq_len <= 1024:
        return 1

    efficiencies = []
    for n_splits in range(1, min(max_splits, num_sms) + 1):
        if (max_seq_len / n_splits) <= 512:
            break  # chunks would become too short to be worth another split
        waves = float(total_mblocks * n_splits) / float(num_sms)
        efficiencies.append(waves / math.ceil(waves) if waves > 0 else 0.0)

    cutoff = 0.75 * max(efficiencies, default=0.0)
    first_good = (
        count
        for count, efficiency in enumerate(efficiencies, start=1)
        if efficiency >= cutoff
    )
    return next(first_good, 1)
+
+
def prune_invalid_configs(configs, _, **kwargs):
    """Drop autotune configs whose ``BLOCK_N`` is larger than ``PAGE_SIZE``.

    :param configs: candidate configs, each exposing a ``kwargs`` mapping.
    :param _: unused second positional argument of the pruning hook.
    :param kwargs: kernel launch arguments; ``PAGE_SIZE`` is required.
    :return: the configs to keep, in their original order. A config without
        a ``BLOCK_N`` entry is treated as ``BLOCK_N == 0``.
    """
    page_size = kwargs["PAGE_SIZE"]
    kept = []
    for candidate in configs:
        block_n = candidate.kwargs.get("BLOCK_N", 0)
        if block_n <= page_size:
            kept.append(candidate)
    return kept
+
+
+# Stage 1 of the split-KV attention path. Grid = (batch, kv head, split): each
+# program runs a streaming softmax over its slice of the logical token range
+# [0, L) for the GROUP_M query heads that share one kv head, then stores a
+# normalised partial output in mid_o and the matching log-sum-exp in mid_lse.
+# The autotuner sweeps BLOCK_N / num_warps / num_stages / WARPSPEC and uses
+# prune_invalid_configs to drop candidates with BLOCK_N > PAGE_SIZE.
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
+ num_warps=w,
+ num_stages=s,
+ )
+ for BLOCK_N in [32, 64, 128]
+ for MIN_BLOCK_KV in [8]
+ for s in [2, 3, 4]
+ for w in [4, 8]
+ for ws in [True, False]
+ ],
+ key=[
+ "HKV",
+ "GROUP_M",
+ "D",
+ "PAGE_SIZE", # "B"
+ ],
+ cache_results=True,
+ prune_configs_by={"early_config_prune": prune_invalid_configs},
+)
+@triton.jit
+def _varkv_stage1_groupM(
+ q, # [B, HQ, D] contiguous
+ k, # GLOBAL cache: [CACHE_SIZE, D], contiguous
+ v, # GLOBAL cache: [CACHE_SIZE, D], contiguous
+ mid_o, # partial outputs; row = b * (KEY_SPLIT * HQ) + split * HQ + q_head, D values per row
+ mid_lse, # per-split log-sum-exp, addressed with STRIDE_LBS / STRIDE_LS / STRIDE_LH
+ page_table_bhl, # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
+ batch_mapping, # int32 [B] maps local pid_b -> true batch index
+ seq_lens_bh, # int32 [B*H_kv] valid tokens per (b,h)
+ SM_SCALE,
+ B, # not referenced in the kernel body
+ HKV,
+ HQ,
+ CACHE_SIZE, # CACHE_SIZE = N_PAGES * PAGE_SIZE
+ STRIDE_LBS,
+ STRIDE_LS,
+ STRIDE_LH,
+ # constexprs
+ N_LOGICAL_PAGES_MAX: tl.constexpr, # page table width per (b,h)
+ D: tl.constexpr,
+ KEY_SPLIT: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ DTYPE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ MIN_BLOCK_KV: tl.constexpr,
+ WARPSPEC: tl.constexpr, # autotune meta-parameter; not referenced in the kernel body
+ PAGE_SIZE: tl.constexpr,
+):
+ pid_b = tl.program_id(0) # batch
+ pid_kvh = tl.program_id(1) # kv head
+ pid_s = tl.program_id(2) # split
+
+ # valid length L for this (b,h)
+ bh_stride = HKV
+ L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
+ # Early exit: for an empty (b,h) nothing is written to mid_o / mid_lse.
+ if L == 0:
+ return
+
+ tl.assume(L > 0)
+
+ # split sizing on logical token axis [0..L)
+ # Each split covers ceil(L / KEY_SPLIT) tokens rounded up to a multiple of
+ # MIN_BLOCK_KV, so trailing splits can be empty (split_start >= L).
+ base = tl.cdiv(L, KEY_SPLIT)
+ per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
+ split_start = pid_s * per_split_len
+ split_end = tl.minimum(split_start + per_split_len, L)
+
+ # query heads mapped to this kv head
+ base_qh = pid_kvh * GROUP_M
+ # Row count is padded up to at least 16 (presumably the minimum tl.dot tile
+ # size - TODO confirm); padded rows are masked on load and on store.
+ GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
+ offs_m = tl.arange(0, GROUP_M_PAD)
+ mask_m = offs_m < GROUP_M
+ offs_d = tl.arange(0, D)
+
+ # load Q tile [M, D]
+ q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
+ q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE) # [M, D]
+
+ # streaming softmax state per query
+ e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
+ e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
+ acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)
+
+ if split_end > split_start:
+ # logical pages covering [split_start, split_end)
+ lp0 = split_start // PAGE_SIZE
+ lp1 = tl.cdiv(split_end, PAGE_SIZE) # exclusive
+
+ mapped_b = tl.load(batch_mapping + pid_b)
+ tl.assume(mapped_b >= 0)
+ # page table base for this (b,h)
+ pt_stride = N_LOGICAL_PAGES_MAX
+ pt_base = (mapped_b * HKV + pid_kvh) * pt_stride
+
+ for lp in tl.range(lp0, lp1):
+ phys = tl.load(
+ page_table_bhl + pt_base + lp, cache_modifier=".cg"
+ ) # physical page id
+ # bounds within the logical page
+ # Only the first / last page of the split are clipped to the split range.
+ local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
+ local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)
+
+ # NOTE(review): phys * PAGE_SIZE and key_idx * D below are computed without
+ # an explicit int64 upcast - confirm CACHE_SIZE * D fits the integer width used.
+ page_base = phys * PAGE_SIZE
+ page_base = tl.multiple_of(page_base, BLOCK_N)
+ for s in tl.range(local_start, local_end, BLOCK_N):
+ s = tl.multiple_of(s, MIN_BLOCK_KV)
+ offs_bn = tl.arange(0, BLOCK_N)
+ key_idx = page_base + s + offs_bn
+ k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
+ # Loads are only guarded against running past the cache buffer; positions
+ # at or beyond local_end are excluded from the softmax via mask_n below.
+ k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
+ qk = tl.dot(q, k_blk.T) * SM_SCALE # [M, BN]
+
+ offs_n = s + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < local_end
+ qk = tl.where(mask_n[None, :], qk, -float("inf"))
+
+ # Streaming softmax update: rescale the running accumulator and sum to the
+ # new row maximum before adding this block's contribution.
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max) # [M]
+ re_scale = tl.exp(e_max - n_e_max) # [M]
+ acc = acc * re_scale[:, None] # [M, D]
+ v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
+ v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
+ p = tl.exp(qk - n_e_max[:, None]) # [M, BN]
+ acc = tl.dot(p.to(DTYPE), v_blk, acc)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # write mid outputs [M, D] for this split
+ # The partial output is normalised by this split's own softmax denominator.
+ tmp = (acc / e_sum[:, None]).to(DTYPE)
+ row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
+ mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
+ tl.store(mid_ptrs, tmp, mask=mask_m[:, None])
+
+ ml_ptrs = (
+ mid_lse
+ + pid_b * STRIDE_LBS
+ + pid_s * STRIDE_LS
+ + (base_qh + offs_m) * STRIDE_LH
+ )
+ # Padded rows use 1.0 as the log argument; their stores are masked out anyway.
+ safe_sum = tl.where(mask_m, e_sum, 1.0)
+ tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
+ else:
+ # empty split
+ # Zero output with LSE = -inf for every real query head of this kv head.
+ zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
+ row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
+ mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
+ tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
+ ml_ptrs = (
+ mid_lse
+ + pid_b * STRIDE_LBS
+ + pid_s * STRIDE_LS
+ + (base_qh + offs_m) * STRIDE_LH
+ )
+ tl.store(ml_ptrs, -float("inf"), mask=mask_m)
+
+
+# Stage 2 of the split-KV attention path. Grid = (batch, query head): each
+# program merges the KEY_SPLIT partial outputs in mid_o into one output row,
+# weighting every split by exp(lse_split - running_max) and dividing by the
+# accumulated weight sum.
+@triton.jit
+def _varkv_stage2_reduce(
+ mid_o, # partial outputs; row = b * (KEY_SPLIT * HQ) + split * HQ + q_head, D values per row
+ mid_lse, # per-split log-sum-exp, addressed with STRIDE_LBS / STRIDE_LS / STRIDE_LH
+ output, # final output, addressed with STRIDE_OBS / STRIDE_OH plus offs_d
+ STRIDE_LBS,
+ STRIDE_LS,
+ STRIDE_LH,
+ STRIDE_OBS,
+ STRIDE_OH,
+ B, # not referenced in the kernel body
+ HQ,
+ D: tl.constexpr,
+ KEY_SPLIT: tl.constexpr,
+ DTYPE: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
+ offs_d = tl.arange(0, D)
+
+ # across split LSE combine
+ e_sum = 0.0
+ e_max = -float("inf")
+ acc = tl.zeros([D], dtype=tl.float32)
+
+ for s in tl.range(KEY_SPLIT):
+ row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
+ tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
+ tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
+ tlogic = tl.load(tl_ptr)
+
+ # Rescale what has been accumulated so far to the new running maximum,
+ # then add this split weighted by exp(lse - new_max).
+ n_e_max = tl.maximum(e_max, tlogic)
+ old_scale = tl.exp(e_max - n_e_max)
+ acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
+ e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
+ e_max = n_e_max
+
+ # NOTE(review): if no split supplied a finite LSE for this (b, h) - e.g. the
+ # stage-1 kernel returned early for L == 0 without writing - e_sum is not a
+ # positive finite number here; confirm callers never rely on that output.
+ o = (acc / e_sum).to(DTYPE)
+ o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
+ tl.store(o_ptr, o)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/attention/sparse_varlen_kernel.py b/vllm/compactor-vllm/src/compactor_vllm/attention/sparse_varlen_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..62f32200479d6ed3860056152f352433f563ccf4
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/attention/sparse_varlen_kernel.py
@@ -0,0 +1,526 @@
+import logging
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from compactor_vllm.utils.triton_compat import (
+ autotune as triton_autotune,
+ cuda_capability_geq,
+ maybe_set_allocator,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def causal_sparse_varlen_with_cache(
+    q,
+    k,
+    v,
+    k_cache,
+    v_cache,
+    seq_lens_bh,
+    global_page_table,
+    batch_mapping,
+    cu_seqlens_q,
+    max_seqlen_q: int,
+    max_seqlen_k_cache: int,
+    HKV: int,
+    PAGE_SIZE: int,
+    sm_scale=None,
+):
+    """
+    Causal prefill attention over a paged KV cache plus a block of newly
+    appended tokens in a packed batch format.
+
+    This function wraps the Triton kernel
+    ``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
+    a batch of variable-length sequences, where:
+
+      * Past keys/values are stored in a paged global KV cache
+        (``k_cache``, ``v_cache``) with a (per-layer) page table.
+      * New tokens for this step are given as K/V blocks
+        (``k``, ``v``), together with a packed query block ``q``.
+      * The result is equivalent to applying causal attention over the
+        concatenation of::
+
+            [ cached KV prefix || (K_app, V_app) for this step ]
+
+        for each sequence in the batch.
+
+    Grouped-query attention (GQA / MQA) is supported by allowing more query
+    heads than KV heads: ``HQ`` must be divisible by ``HKV``.
+
+    Args:
+        :param q:
+            Query tensor of shape ``[N, HQ, D]`` (float16 / bfloat16 / float32).
+            ``N`` is the total number of new tokens across the batch
+            (i.e. ``N = sum_b seqlen_q[b]``), packed according to
+            ``cu_seqlens_q``. ``HQ`` is the number of query heads, ``D`` the
+            head dimension (must be a power of two).
+        :param k:
+            New key tensor of shape ``[N, HKV, D]`` for the same tokens as
+            ``q``. These are the K values appended to the cache for this
+            prefill step.
+        :param v:
+            New value tensor of shape ``[N, HKV, D]`` for the same tokens as
+            ``q``.
+        :param k_cache:
+            Global key cache backing buffer of shape ``[CACHE_SIZE, D]``.
+            Keys for all cached tokens and heads are stored here; the mapping
+            from (batch, head, token index) to a row in this buffer is
+            given by ``global_page_table``.
+        :param v_cache:
+            Global value cache of shape ``[CACHE_SIZE, D]``. Must have the
+            same layout as ``k_cache`` (same ``CACHE_SIZE`` and ``D``).
+        :param seq_lens_bh:
+            Tensor of shape ``[B, HKV]`` (converted to int32 here) giving, for
+            each local batch index and KV head, the number of cached tokens
+            already present in the paged KV cache before this prefill step.
+        :param global_page_table:
+            Tensor of shape ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32)
+            mapping ``(true_batch_idx, kv_head, logical_page)`` to a physical
+            page id in the global KV cache. A physical page id ``p`` refers to
+            the slice ``k_cache[p * PAGE_SIZE : (p + 1) * PAGE_SIZE]``.
+        :param batch_mapping:
+            Tensor of shape ``[B]`` (converted to int16 here) mapping the local
+            batch index used in this kernel launch to the global batch index
+            used to index ``global_page_table``. This allows the same global
+            cache to be shared across multiple microbatches.
+        :param cu_seqlens_q:
+            Tensor of shape ``[B + 1]`` (converted to int32 here) with
+            cumulative sequence lengths for the *new* tokens (q/k/v) in packed
+            form. For batch element ``b``:
+            ``seqlen_q[b] = cu_seqlens_q[b + 1] - cu_seqlens_q[b]``.
+            The total number of tokens satisfies ``N = cu_seqlens_q[-1]``.
+        :param max_seqlen_q:
+            Maximum new query sequence length across the batch, i.e.
+            ``max_b seqlen_q[b]``.
+        :param max_seqlen_k_cache:
+            Maximum cached sequence length across (batch, KV head), i.e.
+            ``max_{b,h} seq_lens_bh[b, h]``. Only used to build an autotune key.
+        :param HKV:
+            Number of KV heads. Must divide ``HQ``.
+        :param PAGE_SIZE:
+            Number of tokens stored per physical page in the paged KV cache.
+            ``CACHE_SIZE`` must be divisible by ``PAGE_SIZE``.
+        :param sm_scale:
+            Optional scaling factor applied to the attention logits before
+            softmax. If ``None``, defaults to ``1.0 / sqrt(D)``.
+        :returns torch.Tensor:
+            Attention output of shape ``[N, HQ, D]``, with the same dtype and
+            device as ``q``. The output is laid out in the same packed
+            varlen format as the input queries, i.e. the first
+            ``seqlen_q[0]`` rows correspond to batch 0, the next
+            ``seqlen_q[1]`` rows to batch 1, etc.
+        :raises AssertionError:
+            If ``q`` is not 3-D, ``D`` is not a power of two, the batch is
+            empty, ``HQ`` is not divisible by ``HKV``, the cache shapes are
+            inconsistent, or a last dimension is not contiguous.
+    """
+    assert q.ndim == 3, "q should be [N, HQ, D]"
+    N, HQ, D = q.shape
+    assert (D & (D - 1)) == 0, "D must be power of two"
+
+    B = cu_seqlens_q.numel() - 1
+    assert B > 0
+    # The check requires HKV to divide HQ; the message now says so (the previous
+    # wording stated the relation the wrong way round).
+    assert HQ % HKV == 0, (
+        "Number of query heads must be divisible by the number of KV heads"
+    )
+    H_g = HQ // HKV
+    # view Q as [HKV, N, QUERY_GROUP_SIZE, D]
+    out = torch.empty_like(q)
+    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
+    out = out.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
+
+    # K_app/V_app: [N, HKV, D] -> [HKV, N, D]
+    k_app = k.view(N, HKV, D).permute(1, 0, 2)
+    v_app = v.view(N, HKV, D).permute(1, 0, 2)
+
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
+    seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
+    # NOTE(review): int16 silently wraps global batch indices above 32767 -
+    # confirm MAX_NUM_BATCHES can never get that large.
+    batch_mapping = batch_mapping.to(dtype=torch.int16, device=q.device)
+
+    N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
+    CACHE_SIZE = k_cache.shape[0]
+    assert v_cache.shape[0] == CACHE_SIZE
+    assert k_cache.shape[1] == D and v_cache.shape[1] == D
+    assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # strides for Q [G, N, QUERY_GROUP_SIZE, D]
+    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
+    STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
+    # [G, N, D]
+    STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
+    STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()
+
+    # OUT [G, N, QUERY_GROUP_SIZE, D]
+    STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()
+    # Register an allocator that hands out int8 scratch buffers on q's device.
+    maybe_set_allocator(
+        lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
+    )
+    # The kernel adds offs_d without a stride, so every last dim must be dense.
+    assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
+        "final dimension must be contiguous"
+    )
+
+    # launch grid: (kv head, batch element, query tile)
+    def grid(META):
+        return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])
+
+    # On a fresh batch, max_seqlen_k_cache==0 (no KV prefix yet). Passing
+    # `triton.next_power_of_2(0)` into autotune constexpr keys breaks
+    # kernel selection / tuning and can yield garbage outputs.
+    _k_max_autotune = max(int(max_seqlen_k_cache), 1)
+    AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
+    AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)
+    _causal_head_sparse_varlen_with_cache[grid](
+        Q=q,
+        K_cache=k_cache,
+        V_cache=v_cache,
+        K_app=k_app,
+        V_app=v_app,
+        cu_seqlens_qk=cu_seqlens_q,
+        seq_lens_bh=seq_lens_bh,
+        page_table=global_page_table,
+        batch_mapping=batch_mapping,
+        OUT=out,
+        HKV=HKV,
+        QUERY_GROUP_SIZE=H_g,
+        PAGE_SIZE=PAGE_SIZE,
+        N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+        STRIDE_Q_G=STRIDE_Q_G,
+        STRIDE_Q_N=STRIDE_Q_N,
+        STRIDE_Q_H=STRIDE_Q_H,
+        STRIDE_KC=STRIDE_KC,
+        STRIDE_VC=STRIDE_VC,
+        STRIDE_KA_G=STRIDE_KA_G,
+        STRIDE_KA_N=STRIDE_KA_N,
+        STRIDE_VA_G=STRIDE_VA_G,
+        STRIDE_VA_N=STRIDE_VA_N,
+        STRIDE_OUT_G=STRIDE_OUT_G,
+        STRIDE_OUT_N=STRIDE_OUT_N,
+        STRIDE_OUT_H=STRIDE_OUT_H,
+        sm_scale=sm_scale,
+        D=D,
+        AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
+        AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
+    )
+    # Undo the [HKV, N, G, D] permutation; this lands back on the original
+    # [N, HQ, D] buffer, so the view needs no copy.
+    return out.permute(1, 0, 2, 3).view(N, HQ, D)  # already contiguous
+
+
+# Hand-picked candidates returned by get_autotune_configs() when
+# cuda_capability_geq(9, 0) holds. Each row is
+# (BLOCK_N, BLOCK_M, WARPSPEC, num_warps, num_stages); order is preserved.
+autotune_configs_cc9 = [
+    triton.Config(
+        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": warpspec},
+        num_warps=n_warps,
+        num_stages=n_stages,
+    )
+    for block_n, block_m, warpspec, n_warps, n_stages in (
+        (64, 64, True, 16, 3),
+        (64, 64, True, 8, 3),
+        (64, 32, True, 8, 4),
+        (64, 32, True, 8, 3),
+        (64, 32, False, 4, 3),
+        (64, 16, True, 8, 3),
+        (64, 16, True, 8, 4),
+        (64, 16, False, 4, 4),
+        (32, 32, True, 8, 4),
+        (32, 32, False, 8, 4),
+        (32, 16, False, 8, 3),
+        (32, 16, False, 4, 4),
+    )
+]
+
+# Fallback candidates used otherwise: BLOCK_M is fixed at 64 and
+# BLOCK_N x num_warps x num_stages is swept (num_stages varies fastest).
+autotune_configs_cc8 = [
+    triton.Config(
+        {"BLOCK_N": block_n, "BLOCK_M": 64, "WARPSPEC": True},
+        num_warps=n_warps,
+        num_stages=n_stages,
+    )
+    for block_n in (16, 32)
+    for n_warps in (4, 8)
+    for n_stages in (2, 3)
+]
+
+
+def prune_invalid_configs(configs, _, **kwargs):
+    """Drop autotune candidates that pair ``BLOCK_N == 32`` with 4 pipeline stages.
+
+    ``triton.Config`` keeps ``num_stages`` as an attribute of the config, not
+    inside ``Config.kwargs`` (which only holds the meta-parameters), so the
+    stage count is read from the attribute first. Looking it up in ``kwargs``
+    alone never matched and the filter silently kept every candidate.
+
+    NOTE(review): the autotune decorator visible in this module does not pass
+    ``prune_configs_by``, so this hook appears to be unused here; wiring it in
+    would remove candidates that ``autotune_configs_cc9`` lists explicitly.
+
+    :param configs: candidate configs exposing ``kwargs`` and ``num_stages``.
+    :param _: second positional argument supplied by the autotuner (ignored).
+    :param kwargs: launch arguments (ignored).
+    :return: the surviving configs, in their original order.
+    """
+
+    def _num_stages(conf):
+        # Prefer the real attribute; fall back to a kwargs entry for objects
+        # that carry the stage count there.
+        stages = getattr(conf, "num_stages", None)
+        if stages is None:
+            stages = conf.kwargs.get("num_stages")
+        return stages
+
+    return [
+        conf
+        for conf in configs
+        if not (conf.kwargs.get("BLOCK_N") == 32 and _num_stages(conf) == 4)
+    ]
+
+
+def get_autotune_configs():
+    """Pick the autotune candidate list for the current device.
+
+    :return: ``autotune_configs_cc9`` when ``cuda_capability_geq(9, 0)`` is
+        truthy, otherwise ``autotune_configs_cc8``.
+    """
+    return autotune_configs_cc9 if cuda_capability_geq(9, 0) else autotune_configs_cc8
+
+
+# Causal prefill attention kernel. Grid = (kv head, batch element, query tile).
+# Each program handles up to BLOCK_M new queries of one sequence for all
+# QUERY_GROUP_SIZE query heads sharing one kv head. It first attends over the
+# cached prefix found through the page table, then over the newly appended
+# K_app / V_app tokens with a causal mask, using one streaming softmax.
+@triton_autotune(
+ configs=get_autotune_configs(),
+ key=[
+ "HKV",
+ "QUERY_GROUP_SIZE",
+ "D",
+ "PAGE_SIZE",
+ "AUTOTUNE_MAX_K_LEN",
+ "AUTOTUNE_MAX_Q_LEN",
+ ],
+ cache_results=True,
+)
+@triton.jit
+def _causal_head_sparse_varlen_with_cache(
+ Q, # [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
+ K_cache,
+ V_cache, # [CACHE_SIZE, D]
+ K_app,
+ V_app, # [HKV, N, D]
+ cu_seqlens_qk, # [B+1]
+ seq_lens_bh, # [B, HKV]
+ page_table, # [B_total, HKV, N_LOGICAL_PAGES_MAX]
+ batch_mapping, # [B], maps local b -> global batch index
+ OUT, # [HKV, N, QUERY_GROUP_SIZE, D]
+ #
+ HKV: tl.constexpr,
+ QUERY_GROUP_SIZE: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ N_LOGICAL_PAGES_MAX,
+ STRIDE_Q_G,
+ STRIDE_Q_N,
+ STRIDE_Q_H,
+ STRIDE_KC,
+ STRIDE_VC,
+ STRIDE_KA_G,
+ STRIDE_KA_N,
+ STRIDE_VA_G,
+ STRIDE_VA_N,
+ STRIDE_OUT_G,
+ STRIDE_OUT_N,
+ STRIDE_OUT_H,
+ sm_scale,
+ #
+ D: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ WARPSPEC: tl.constexpr, # autotune meta-parameter; not referenced in the kernel body
+ AUTOTUNE_MAX_Q_LEN: tl.constexpr, # used for autotune key
+ AUTOTUNE_MAX_K_LEN: tl.constexpr, # used for autotune key
+):
+ TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
+ pid_g = tl.program_id(0) # kv_head id in [0, HKV)
+ pid_b = tl.program_id(1) # batch id
+ pid_m = tl.program_id(2) # query-tile id within batch
+
+ # batch segment [qb, qe) in N
+ off_b = tl.load(cu_seqlens_qk + pid_b)
+ off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
+ seq_len_append = off_b1 - off_b
+
+ q_start = off_b + pid_m * BLOCK_M
+ q_end = tl.minimum(q_start + BLOCK_M, off_b1)
+ # number of queries in this tile for this batch
+ M = q_end - q_start
+ # Tiles past the end of this sequence have nothing to do.
+ if M <= 0:
+ return
+
+ # cached length for (b, kv_head=pid_g)
+ L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)
+ # row indices flattened over [QUERY_GROUP_SIZE, M]
+ # row_m is the query position inside the tile, row_h the head inside the group.
+ offs_row = tl.arange(0, TOTAL_N_QUERIES)
+ row_m = offs_row % BLOCK_M
+ row_h = offs_row // BLOCK_M
+ # valid rows: only those with row_m < M
+ row_mask = row_m < M
+
+ # global query index per row
+ q_idx = q_start + row_m
+ offs_d = tl.arange(0, D)
+ # Q tile: [TOTAL_N_QUERIES, D]
+ # Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
+ q_ptrs = (
+ Q
+ + pid_g * STRIDE_Q_G
+ + q_idx[:, None] * STRIDE_Q_N
+ + row_h[:, None] * STRIDE_Q_H
+ + offs_d[None, :]
+ )
+ q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+
+ # Streaming softmax state: running max, running sum and unnormalised output.
+ e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
+ e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
+ acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)
+
+ offs_block_n = tl.arange(0, BLOCK_N)
+ # 1.44269504 ~= log2(e): logits are scaled so exp2 can be used in place of exp.
+ qk_scale = sm_scale * 1.44269504
+
+ # 1) attend over cached K/V
+ # No causal mask here: every cached token is visible to every query in the tile.
+ if L_cache > 0:
+ # map local (b) to global batch index
+ mapped_b = tl.load(batch_mapping + pid_b)
+ pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
+ # iterate logical pages
+ num_lp = tl.cdiv(L_cache, PAGE_SIZE)
+ for lp in tl.range(0, num_lp):
+ # can overflow in 32 bits so upcast
+ phys = tl.load(page_table + pt_base + lp).to(tl.int64)
+ page_start = phys * PAGE_SIZE
+ # how many valid tokens in this page for this (b,g)
+ remain = L_cache - lp * PAGE_SIZE
+ page_len = tl.minimum(PAGE_SIZE, remain)
+ # iterate over this page in BLOCK_N chunks
+ for ks in tl.range(0, page_len, BLOCK_N):
+ offs_n = ks + offs_block_n
+ mask_n = offs_n < page_len
+
+ key_idx = page_start + offs_n
+ k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]
+
+ k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) # [BN, D]
+ qk = tl.dot(q, k.T) * qk_scale # [TOTAL_N_QUERIES, BN]
+ # Masked logits get a large negative finite value rather than -inf.
+ qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
+
+ # softmax update
+ cur_max = tl.max(qk, 1)
+ n_e_max = tl.maximum(e_max, cur_max)
+ re_scale = tl.math.exp2(e_max - n_e_max)
+ p = tl.math.exp2(qk - n_e_max[:, None])
+
+ v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
+ v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) # [BN, D]
+
+ acc = acc * re_scale[:, None]
+ acc = tl.dot(p.to(v.dtype), v, acc)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # 2) attend over appended K_app/V_app (causal)
+ # appended tokens for batch b are in [off_b, off_b1)
+ # query tile is [q_start, q_end)
+ # for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
+ if q_end > off_b:
+ # exactly one appended token
+ # Fast path: one key/value row, so an elementwise product + row sum is used
+ # instead of tl.dot; the single query trivially sees its own key.
+ if seq_len_append == 1:
+ ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
+ k = tl.load(ka_ptrs) # [D]
+ qk = tl.sum(q * k[None, :], 1) * qk_scale
+ qk = tl.where(row_mask, qk, -1.0e6)
+ n_e_max = tl.maximum(e_max, qk)
+ re_scale = tl.math.exp2(e_max - n_e_max)
+ p = tl.math.exp2(qk - n_e_max)
+ va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
+ v = tl.load(va_ptrs) # [D]
+ acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
+ e_sum = e_sum * re_scale + p
+ else:
+ # off-band: k in [off_b, q_start)
+ # for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
+ # so no causal mask needed.
+ off_band_start = off_b
+ off_band_end = q_start
+
+ if off_band_end > off_band_start:
+ for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
+ offs_n = ks + offs_block_n
+ mask_n = offs_n < off_band_end
+
+ ka_ptrs = (
+ K_app
+ + pid_g * STRIDE_KA_G
+ + offs_n[:, None] * STRIDE_KA_N
+ + offs_d[None, :]
+ )
+ k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
+
+ qk = tl.dot(q, k.T) * qk_scale
+ qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
+
+ cur_max = tl.max(qk, 1)
+ n_e_max = tl.maximum(e_max, cur_max)
+
+ re_scale = tl.math.exp2(e_max - n_e_max)
+ p = tl.math.exp2(qk - n_e_max[:, None])
+
+ va_ptrs = (
+ V_app
+ + pid_g * STRIDE_VA_G
+ + offs_n[:, None] * STRIDE_VA_N
+ + offs_d[None, :]
+ )
+ v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
+
+ acc = acc * re_scale[:, None]
+ acc = tl.dot(p.to(v.dtype), v, acc)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # on-band remaining k
+ # Keys in [q_start, q_end) overlap the tile's own queries, so the causal
+ # mask (key index <= query index) is applied per element.
+ on_band_start = tl.maximum(q_start, off_b)
+ if on_band_start < q_end:
+ for ks in tl.range(on_band_start, q_end, BLOCK_N):
+ offs_n = ks + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < q_end
+
+ ka_ptrs = (
+ K_app
+ + pid_g * STRIDE_KA_G
+ + offs_n[:, None] * STRIDE_KA_N
+ + offs_d[None, :]
+ )
+
+ k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
+
+ qk = tl.dot(q, k.T) * qk_scale
+
+ caus_mask = offs_n[None, :] <= q_idx[:, None]
+ full_mask = row_mask[:, None] & mask_n[None, :] & caus_mask
+
+ qk = tl.where(full_mask, qk, -1.0e6)
+
+ cur_max = tl.max(qk, 1)
+ n_e_max = tl.maximum(e_max, cur_max)
+ re_scale = tl.math.exp2(e_max - n_e_max)
+ p = tl.math.exp2(qk - n_e_max[:, None])
+
+ va_ptrs = (
+ V_app
+ + pid_g * STRIDE_VA_G
+ + offs_n[:, None] * STRIDE_VA_N
+ + offs_d[None, :]
+ )
+ v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
+
+ acc = acc * re_scale[:, None]
+ acc = tl.dot(p.to(v.dtype), v, acc)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # 3) write outputs
+ # Normalise by the softmax denominator; rows beyond M are not stored.
+ o = (acc / e_sum[:, None]).to(q.dtype)
+ out_ptrs = (
+ OUT
+ + pid_g * STRIDE_OUT_G
+ + q_idx[:, None] * STRIDE_OUT_N
+ + row_h[:, None] * STRIDE_OUT_H
+ + offs_d[None, :]
+ )
+ tl.store(out_ptrs, o, mask=row_mask[:, None])
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/benchmark/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/compression/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..568ad7890f9e1598264e8f741676a1d64f682b7b
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/__init__.py
@@ -0,0 +1,41 @@
+from compactor_vllm.compression.common import (
+ BaseCompressionMethod,
+ NoCompression,
+)
+from compactor_vllm.compression.criticalkv import CriticalAdaKVCompression
+from compactor_vllm.compression.compactor import CompactorCompression
+from compactor_vllm.compression.compression_config import (
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+)
+from compactor_vllm.compression.snapkv import SnapKVCompression
+
+# Lookup table from a CompressionMethod value to the class that provides its
+# pre_rope_scoring / post_rope_scoring hooks; used by the apply_* helpers below.
+COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
+ CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
+ CompressionMethod.COMPACTOR: CompactorCompression,
+ CompressionMethod.SNAPKV: SnapKVCompression,
+ CompressionMethod.NONE: NoCompression,
+}
+
+
+def apply_prerope_compression(q, k, v, context):
+    """Run the pre-RoPE scoring phase of the compression method chosen in ``context``.
+
+    The method is read from ``context.compression_context.compression_method``
+    and resolved through ``COMPRESSION_REGISTRY``.
+
+    :param q: query tensor, forwarded unchanged.
+    :param k: key tensor, forwarded unchanged.
+    :param v: value tensor, forwarded unchanged.
+    :param context: object carrying ``compression_context``; also forwarded to
+        the scoring hook as the ``context`` keyword.
+    :return: whatever the selected class's ``pre_rope_scoring`` returns.
+    :raises KeyError: if the method has no entry in ``COMPRESSION_REGISTRY``.
+    """
+    scorer = COMPRESSION_REGISTRY[context.compression_context.compression_method]
+    return scorer.pre_rope_scoring(q, k, v, context=context)
+
+
+def apply_postrope_compression(q, k, v, prerope_scores, context):
+    """Run the post-RoPE scoring phase of the compression method chosen in ``context``.
+
+    The method is read from ``context.compression_context.compression_method``
+    and resolved through ``COMPRESSION_REGISTRY``.
+
+    :param q: query tensor, forwarded unchanged.
+    :param k: key tensor, forwarded unchanged.
+    :param v: value tensor, forwarded unchanged.
+    :param prerope_scores: result of the pre-RoPE phase, forwarded positionally.
+    :param context: object carrying ``compression_context``; also forwarded to
+        the scoring hook as the ``context`` keyword.
+    :return: whatever the selected class's ``post_rope_scoring`` returns.
+    :raises KeyError: if the method has no entry in ``COMPRESSION_REGISTRY``.
+    """
+    scorer = COMPRESSION_REGISTRY[context.compression_context.compression_method]
+    return scorer.post_rope_scoring(q, k, v, prerope_scores, context=context)
+
+
+# Names exported by ``from compactor_vllm.compression import *``.
+__all__ = [
+ "apply_prerope_compression",
+ "apply_postrope_compression",
+ "CompressionMethod",
+ "BatchCompressionParams",
+ "SequenceCompressionParams",
+ "COMPRESSION_REGISTRY"
+]
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/common.py b/vllm/compactor-vllm/src/compactor_vllm/compression/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1661912e4b23ceb637666f4cbe97e62cbf50c37
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/common.py
@@ -0,0 +1,243 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+from compactor_vllm.kv_cache.store_kv_cache import prefill_store_topk_kv
+
+
+class BaseCompressionMethod(ABC):
+ """
+ Abstract interface for KV cache compression methods.
+
+ A compression method is implemented as a pair of optional scoring phases
+ that run before and after rotary position embedding (RoPE) is applied:
+
+ 1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.
+
+ 2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
+ - refine / reweight the pre-RoPE scores, or
+ - compute new, potentially position-aware scores.
+
+ Concrete subclasses are expected to implement both
+ static methods and return a single tensor of scores (or ``None`` if the
+ phase is a no-op), which the caller can then feed into the shared
+ “scores → top-k indices → KV extraction” pipeline.
+ """
+
+ @staticmethod
+ @abstractmethod
+ def pre_rope_scoring(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ context,
+ ) -> Optional[torch.Tensor]:
+ """
+ Compute per-token importance scores from pre-RoPE queries/keys.
+
+ Args:
+ :param q:
+ Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
+ :param k:
+ Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
+ :param v:
+ Value tensor. Shape ``[total_tokens, HKV, D]``.
+ :param context:
+ compactor_vllm.utils.context.Context object carrying additional metadata,
+ such as batch mappings or temporary buffers.
+
+ Returns:
+ :return Optional[torch.Tensor]:
+ A tensor of scores (e.g. per-token, per-head importance values)
+ to be passed to ``post_rope_scoring`` or directly into the
+ top-k selection step. If this phase is a no-op, implementations
+ should return ``None``. Shape ``[total_tokens, HKV]``.
+ """
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def post_rope_scoring(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ pre_rope_scores: Optional[torch.Tensor],
+ context,
+ ) -> Optional[torch.Tensor]:
+ """
+ Compute or refine importance scores from post-RoPE queries/keys.
+
+ This method is called after rotary embeddings have been applied. It can
+ optionally use both the post-RoPE Q/K and any scores produced by
+ ``pre_rope_scoring`` to produce final scores used for token selection.
+
+ Common patterns include:
+ * Using ``pre_rope_scores`` as a base signal and applying a
+ position-aware correction.
+ * Only computing scores that depend on absolute or relative positions.
+ * Simply passing through ``pre_rope_scores`` unchanged.
+
+ Args:
+ :param q:
+ Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
+ :param k:
+ Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
+ :param v:
+ Value tensor. Shape ``[total_tokens, HKV, D]``.
+ :param pre_rope_scores:
+ Optional scores returned by ``pre_rope_scoring``. May be
+ ``None`` if the pre-RoPE phase returned ``None``.
+ :param context:
+ compactor_vllm.utils.context.Context object carrying additional metadata,
+ such as batch mappings or temporary buffers.
+ Returns:
+ :return Optional[torch.Tensor]:
+ Final importance scores to be consumed by the compression
+ pipeline (for top-k token selection). If this phase is a
+ no-op, implementations may return ``pre_rope_scores``. If
+ ``None`` is returned, no compression will be applied.
+ """
+ pass
+
+
+class NoCompression(BaseCompressionMethod):
+ """
+ Trivial compression method that disables KV cache compression.
+
+ Neither phase computes anything: the pre-RoPE phase yields ``None`` and the
+ post-RoPE phase hands back whatever scores it was given.
+ """
+
+ @staticmethod
+ def pre_rope_scoring(
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+ ) -> Optional[torch.Tensor]:
+ """Return ``None``; all arguments are ignored."""
+ return None
+
+ @staticmethod
+ def post_rope_scoring(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ pre_rope_scores: torch.Tensor,
+ context,
+ ) -> Optional[torch.Tensor]:
+ """Return ``pre_rope_scores`` unchanged; the other arguments are ignored."""
+ return pre_rope_scores
+
+
+def extract_and_store_top_kv(
+ scores: torch.Tensor,
+ cu_seqlens_k: torch.Tensor,
+ max_k_len: int,
+ top_k: int,
+ H: int,
+ new_keys: torch.Tensor, # [N_total, H, D]
+ new_vals: torch.Tensor, # [N_total, H, D]
+ num_tokens_to_retain: torch.Tensor, # [B] int32
+ page_table: torch.Tensor, # [B_total, H, N_LOGICAL_PAGES_MAX] int32
+ batch_mapping: torch.Tensor, # [B] int32 (local -> true batch rows)
+ bh_lens: torch.Tensor, # [B, H] int32 (contiguous), UPDATED atomically
+ k_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
+ v_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
+ PAGE_SIZE: int,
+ PAD_TO_PAGE_SIZE: bool = True,
+ K_TILE: int = 16,
+ padding: float = -float("inf"),
+):
+ """
+ Helper that selects the top-k (token, head) indices from ``scores`` and
+ stores the matching keys/values into the KV cache, bundled into one call
+ so that both steps can be executed in a single stream.
+
+ The indices come from ``scores_to_retain_indices`` (which receives
+ ``scores``, ``cu_seqlens_k``, ``max_k_len``, ``top_k``, ``H`` and
+ ``padding``); every other argument is forwarded by keyword to
+ ``prefill_store_topk_kv`` together with those indices and ``cu_seqlens_k``.
+
+ Returns:
+ :return None:
+ Nothing is returned; see ``prefill_store_topk_kv`` for the effects
+ on the cache tensors.
+ """
+ # Step 1: per-sequence top-k over the padded (token, head) score grid.
+ indices_topk = scores_to_retain_indices(
+ scores,
+ cu_seqlens_k=cu_seqlens_k,
+ max_k_len=max_k_len,
+ top_k=top_k,
+ H=H,
+ padding=padding,
+ )
+ # Step 2: hand the selected indices and the new K/V to the cache writer.
+ prefill_store_topk_kv(
+ new_keys=new_keys,
+ new_vals=new_vals,
+ indices_topk=indices_topk,
+ num_tokens_to_retain=num_tokens_to_retain,
+ page_table=page_table,
+ batch_mapping=batch_mapping,
+ bh_lens=bh_lens,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ cu_seqlens_k=cu_seqlens_k,
+ PAGE_SIZE=PAGE_SIZE,
+ PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
+ K_TILE=K_TILE,
+ )
+
+
+def scores_to_retain_indices(
+    scores: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_k_len: int,
+    top_k: int,
+    H: int,
+    padding: float = -float("inf"),
+) -> torch.Tensor:
+    """
+    Select global top-k token-head indices per sequence from packed scores.
+
+    This helper takes per-token, per-head scores in packed varlen form and
+    returns, for each batch element, the indices of the top-k (token, head)
+    pairs in the flattened global layout.
+
+    Inputs are assumed to follow the usual packed varlen convention:
+
+      * ``scores`` is laid out as ``[N_total, H]``, where
+        ``N_total = sum_b seqlen_k[b]`` and ``H`` is the number of KV heads.
+      * ``cu_seqlens_k`` is ``[B + 1]``, giving cumulative lengths for the
+        keys per batch:
+        ``seqlen_k[b] = cu_seqlens_k[b + 1] - cu_seqlens_k[b]``.
+      * ``max_k_len`` is an upper bound on ``seqlen_k[b]`` across the batch.
+
+    The function pads each sequence to length ``max_k_len`` with ``padding``
+    (default: ``-inf``), flattens the per-sequence scores into shape
+    ``[B, max_k_len * H]``, and runs a per-batch top-k. The returned indices
+    are shifted so that they directly index into the flattened global
+    score layout of shape ``[N_total * H]``::
+
+        global_index = (token_global_offset * H) + head_index
+
+    Args:
+        :param scores:
+            Tensor of shape ``[N_total, H]`` containing scores for each
+            (token, head) pair in packed varlen format.
+        :param cu_seqlens_k:
+            1-D integer tensor of shape ``[B + 1]`` with cumulative key
+            sequence lengths for each batch element. The total number of
+            tokens satisfies ``N_total = cu_seqlens_k[-1]``.
+        :param max_k_len:
+            Maximum key sequence length across the batch (i.e.
+            ``max_b seqlen_k[b]``). Used to allocate the padded buffer; a
+            sequence longer than this makes the copy fail.
+        :param top_k:
+            Number of (token, head) entries to retain **per batch element**.
+            If ``top_k > max_k_len * H``, it is clamped to ``max_k_len * H``.
+        :param H:
+            Number of key heads; must match ``scores.shape[1]``.
+        :param padding:
+            Padding value used when extending sequences shorter than
+            ``max_k_len``. Defaults to ``-inf``, so padded positions are only
+            selected once a sequence has fewer than ``k_eff`` real entries.
+
+    Returns:
+        :return torch.Tensor:
+            Tensor of shape ``[B, k_eff]`` (int64) where
+            ``k_eff = min(top_k, max_k_len * H)``, sorted by descending score.
+            Each entry is a global index into the flattened score array of
+            shape ``[N_total * H]`` (i.e. ``scores.view(-1)``). When a
+            sequence has fewer than ``k_eff`` real entries, the surplus
+            indices come from padding and point past that sequence's own
+            rows; callers must ignore them.
+    """
+    # idea: pad and then select top-k.
+    B, device = cu_seqlens_k.numel() - 1, scores.device
+    padded = torch.full(
+        (B, max_k_len, H), fill_value=padding, dtype=scores.dtype, device=device
+    )
+    # One device-to-host transfer for all boundaries instead of two blocking
+    # scalar reads per batch element.
+    bounds = cu_seqlens_k.tolist()
+    for b, (s, e) in enumerate(zip(bounds[:-1], bounds[1:])):
+        padded[b, : e - s, :].copy_(scores[s:e, :])
+    flat = padded.view(B, max_k_len * H)
+    idx = torch.topk(
+        flat, k=min(top_k, max_k_len * H), dim=1, largest=True, sorted=True
+    ).indices
+    # Shift padded-layout indices by each sequence's first flattened position.
+    return idx + (cu_seqlens_k[:-1] * H).unsqueeze(-1)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/compactor.py b/vllm/compactor-vllm/src/compactor_vllm/compression/compactor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6a26c6d1f8ac3a5d5d0158db1306e86fd385956
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/compactor.py
@@ -0,0 +1,704 @@
+"""
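+Compactor compression, aligned with the kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress``
+algorithms (Cholesky leverage scores, right Gaussian sketch, non-causal chunked attention without 1/sqrt(d) scaling,
+×||V||, avg_pool, global z-score, blending, and head/tail sink padding).
+
+Non-causal chunked attention and ``×||V||`` + ``avg_pool1d(k=3)`` run as Triton kernels on CUDA; off CUDA they fall back to PyTorch.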
+"""
+
+from __future__ import annotations
+
+import math
+from typing import List, Optional
+
+import torch
+import triton
+import triton.language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+
+
+def resolve_kvpress_compactor_blending(compression_context) -> float:
+ """Resolve the blending weight: ``compactor_blending`` if set, else ``compression_ratio``, else 0.35 (intended to mirror kvpress ``CompactorPress.score``)."""
+ if compression_context is None:
+ return 0.35  # no context at all: use the hard-coded default
+ b = getattr(compression_context, "compactor_blending", None)
+ if b is not None:
+ return float(b)
+ cr = getattr(compression_context, "compression_ratio", None)  # second choice when no explicit blending is set
+ if cr is not None:
+ return float(cr)
+ return 0.35
+
+
+class CompactorCompression(BaseCompressionMethod):
+ """Compression method whose default ``chunk_size=256`` is stated to match kvpress ``CompactorPress`` / ``NonCausalAttnPress``."""
+
+ chunk_size: int = 256  # chunk length forwarded to the post-RoPE scorer
+
+ @staticmethod
+ def pre_rope_scoring(
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+ ) -> Optional[torch.Tensor]:  # only ``k`` is scored here; ``q`` and ``v`` are unused
+ compression_context = context.compression_context
+ return maybe_execute_in_stream(
+ kvpress_leverage_scores_packed,
+ k,
+ context.cu_seqlens_q,  # NOTE(review): query boundaries are used to slice keys — presumably q and k share lengths here; confirm
+ compression_context,
+ STORE_STREAM=context.STORE_STREAM,
+ )
+
+ @staticmethod
+ def post_rope_scoring(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ pre_rope_scores: torch.Tensor,
+ context,
+ ) -> Optional[torch.Tensor]:  # blends the pre-RoPE scores with post-RoPE attention scores
+ compression_context = context.compression_context
+ blending = resolve_kvpress_compactor_blending(compression_context)
+ return maybe_execute_in_stream(
+ kvpress_compactor_post_rope,
+ q,
+ k,
+ v,
+ context.cu_seqlens_q,
+ pre_rope_scores,
+ compression_context,
+ context.max_seqlen_q,  # accepted by the callee but deleted there without being used
+ chunk_size=CompactorCompression.chunk_size,
+ blending=float(blending),
+ STORE_STREAM=context.STORE_STREAM,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Cholesky leverage scores (kvpress ``LeverageScorePress``)
+# ---------------------------------------------------------------------------
+
+
+def chol_with_jitter(
+ G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5
+) -> torch.Tensor:  # lower Cholesky factor of ``G + jitter * I``, growing the jitter until the factorization succeeds
+ identity = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
+ cur = float(jitter)
+ for _ in range(max_tries):
+ L, info = torch.linalg.cholesky_ex(G + cur * identity, upper=False)  # success is read from ``info`` below rather than from an exception
+ if bool((info == 0).all()):  # accept only if every matrix in the batch reports info == 0
+ return L
+ cur = max(1e-8, (1e-2 if cur == 0.0 else 10.0 * cur))  # first retry jumps 0 -> 1e-2, later retries multiply by 10
+ raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
+
+
+def compute_leverage_scores_mid(
+ key_states: torch.Tensor, sketch_dimension: int
+) -> torch.Tensor:
+ """
+ Intended to match kvpress ``LeverageScorePress.compute_leverage_scores``; takes ``[L, H, D]`` and
+ returns ``[L, H]`` (not z-scored).
+
+ Axis order follows kvpress's ``(B, H, S, D)``: reshape to ``[1, H, L, D]``, center along the sequence axis
+ (``dim=-2``), then batch-matmul with ``Phi`` of shape ``(1, H, D, K)`` to obtain ``[1, H, L, K]``.
+ """
+ d, k = key_states.shape[-1], sketch_dimension
+ device, dtype = key_states.device, key_states.dtype
+ H = key_states.shape[1]
+ Phi = torch.randn(1, H, d, k, device=device, dtype=dtype) * (1.0 / math.sqrt(k))  # fresh Gaussian sketch drawn on every call
+ # [L, H, d] -> [1, H, L, d], same layout as kvpress (B, H, S, d)
+ X0 = key_states.transpose(0, 1).unsqueeze(0).contiguous()
+ X = X0 - X0.mean(dim=-2, keepdim=True)
+ X = torch.matmul(X, Phi).to(torch.float32)  # sketch in the input dtype, then do the linear algebra in float32
+ XT = X.transpose(-2, -1)
+ G = XT @ X  # [1, H, K, K] Gram matrix of the sketched keys
+ L = chol_with_jitter(
+ 0.5 * (G + G.transpose(-2, -1)), jitter=1e-2, max_tries=5  # symmetrize before factorizing
+ )
+ inv_Xt = torch.cholesky_solve(XT, L, upper=False)  # solves (G + jitter * I) Z = X^T using the factor L
+ scores = (X * inv_Xt.transpose(-2, -1)).sum(dim=-1).clamp_min(0)  # row-wise x^T (G + jitter * I)^-1 x, floored at 0
+ # [1, H, L] -> [L, H]
+ return scores.squeeze(0).transpose(0, 1).contiguous()
+
+
+def kvpress_leverage_scores_packed(
+ key_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ compression_ctx,
+) -> torch.Tensor:  # [N, Hkv] float32: globally z-scored leverage scores on each sequence's middle span, 0 elsewhere
+ device = key_states.device
+ N, Hkv, _D = key_states.shape
+ sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48))  # getattr defaults also cover compression_ctx=None
+ sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
+ sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
+
+ out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
+ mids_flat: list[torch.Tensor] = []
+ mid_ranges: list[tuple[int, int, int]] = []
+
+ for b in range(cu_seqlens.numel() - 1):
+ k_beg = int(cu_seqlens[b].item())  # .item() reads the boundary back to the host
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)  # leading sink tokens are excluded from scoring
+ right_keep = min(sink_end, max(0, L - left_keep))  # trailing sink tokens, limited to what is left after the head sink
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if mid_start >= mid_end:
+ continue  # sequence consists of sink tokens only; its rows stay 0
+ k_mid = key_states[mid_start:mid_end, :, :].contiguous()
+ raw = compute_leverage_scores_mid(k_mid, sketch_dim)
+ mids_flat.append(raw.reshape(-1))
+ mid_ranges.append((mid_start, mid_end, Hkv))
+
+ if not mids_flat:
+ return out
+
+ flat = torch.cat(mids_flat, dim=0)
+ z = _zscore_flat_f32_global(flat)  # one mean/std over all sequences and heads together
+ offset = 0
+ for (mid_start, mid_end, _Hkv), r in zip(mid_ranges, mids_flat):
+ n = r.numel()
+ seg = z[offset : offset + n].view(mid_end - mid_start, Hkv)  # undo the flattening for this sequence
+ out[mid_start:mid_end, :] = seg
+ offset += n
+ return out
+
+
+# ---------------------------------------------------------------------------
+# Non-causal chunked attention (kvpress ``NonCausalAttnPress.non_causal_chunked_attn``) — Triton
+# ---------------------------------------------------------------------------
+
+
+def _non_causal_chunked_attn_pytorch(
+ q: torch.Tensor, k: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+ """Reference implementation, intended to match kvpress operator for operator: ``[L, H, d]`` q/k -> ``[L, H]`` per-key sums of chunk-local softmax."""
+ assert chunk_size > 0 and q.shape == k.shape
+ L, H, d = q.shape
+ B = 1
+ q = q.permute(1, 0, 2).unsqueeze(0).contiguous()  # [L, H, d] -> [1, H, L, d]
+ k = k.permute(1, 0, 2).unsqueeze(0).contiguous()
+ _B, H, S, _d = k.shape
+ S_pad = math.ceil(S / chunk_size) * chunk_size  # length rounded up to a whole number of chunks
+ pad_len = S_pad - S
+
+ if pad_len > 0:
+ q_padded = torch.cat(
+ [q, torch.zeros(B, H, pad_len, d, device=q.device, dtype=q.dtype)], dim=2
+ )
+ k_padded = torch.cat(
+ [k, torch.zeros(B, H, pad_len, d, device=k.device, dtype=k.dtype)], dim=2
+ )
+ last_chunk_start = (S // chunk_size) * chunk_size
+ in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S  # True marks padded positions of the last chunk
+ query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)
+ else:
+ q_padded, k_padded = q, k
+ last_chunk_start = ((S - 1) // chunk_size) * chunk_size
+ in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S  # all False here: S_pad == S, so nothing is padded
+ query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)
+
+ num_chunks = S_pad // chunk_size
+ q_chunks = q_padded.view(B, H, num_chunks, chunk_size, d)
+ k_chunks = k_padded.view(B, H, num_chunks, chunk_size, d)
+ dots = torch.matmul(q_chunks, k_chunks.transpose(-2, -1))  # raw dot products; no 1/sqrt(d) scaling is applied
+ dots[:, :, -1].masked_fill_(query_mask.unsqueeze(-1), 0)  # only the last chunk can contain padding
+ dots[:, :, -1].masked_fill_(key_mask.unsqueeze(-2), -1e-9)  # NOTE(review): fill is -1e-9 (about 0), not a large negative, so padded keys still get softmax mass — confirm this is the intended upstream parity
+ attn = torch.softmax(dots.to(torch.float32), dim=-1)
+ out = attn.sum(dim=-2).view(B, H, S_pad)[..., :S]  # sum over queries = attention mass received per key; padding is sliced off
+ return out.squeeze(0).transpose(0, 1).contiguous()
+
+
+@triton.jit
+def _non_causal_chunk_row_kernel(
+ Q_ptr,
+ K_ptr,
+ Out_ptr,
+ stride_qh,
+ stride_qs,
+ stride_qd,
+ stride_kh,
+ stride_ks,
+ stride_kd,
+ stride_oh,
+ stride_os,
+ S,
+ S_pad,
+ num_chunks,
+ CHUNK_SIZE: tl.constexpr,
+ D: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ ND: tl.constexpr,
+):
+ """
+ Each program handles one head, one chunk, and one query row.
+ It applies softmax(dim=-1) to that row of logits, then atomic_adds the probabilities into the output per key column j (equivalent to summing over queries).
+ """
+ h = tl.program_id(0)
+ c = tl.program_id(1)
+ iq = tl.program_id(2)
+
+ g_i = c * CHUNK_SIZE + iq  # global index of this program's query row
+
+ offs_j = tl.arange(0, CHUNK_SIZE)
+ logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
+
+ for db in range(ND):  # accumulate the dot product over the head dimension in BLOCK_D slices
+ offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
+ mask_d = offs_d < D
+ q_off = (
+ h * stride_qh + g_i * stride_qs + offs_d * stride_qd
+ )
+ qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
+
+ g_j = c * CHUNK_SIZE + offs_j
+ k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
+ kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)  # masked on d only — assumes Q/K hold S_pad rows; the visible wrapper pads them
+ logits += tl.sum(qd[None, :] * kj, axis=1)
+
+ row_invalid = g_i >= S  # this query row is padding
+ g_j_all = c * CHUNK_SIZE + offs_j
+ col_invalid = g_j_all >= S  # key columns that are padding
+
+ logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
+ logits = tl.where(
+ row_invalid,
+ logits,
+ tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),  # same -1e-9 fill as the PyTorch reference
+ )
+
+ m = tl.max(logits)
+ logits = logits - m  # subtract the row max before exponentiating
+ exp_v = tl.exp(logits)
+ denom = tl.sum(exp_v)
+ p = exp_v / denom
+
+ out_base = h * stride_oh + g_j_all * stride_os
+ tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)  # padded query rows also contribute here, as zero-logit rows
+
+
+def _non_causal_chunked_attn_triton(
+ q: torch.Tensor, k: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+ """CUDA Triton path; stated to implement the same algorithm as ``_non_causal_chunked_attn_pytorch``."""
+ assert q.is_cuda and k.is_cuda and q.shape == k.shape
+ L, H, d = q.shape
+ assert chunk_size > 0
+ S_pad = math.ceil(L / chunk_size) * chunk_size  # length rounded up to a whole number of chunks
+ pad_len = S_pad - L
+ if pad_len > 0:
+ zq = torch.zeros(
+ pad_len, H, d, device=q.device, dtype=q.dtype, requires_grad=False
+ )
+ zk = torch.zeros(
+ pad_len, H, d, device=k.device, dtype=k.dtype, requires_grad=False
+ )
+ q = torch.cat([q, zq], dim=0)  # zero rows so the kernel can read a full last chunk
+ k = torch.cat([k, zk], dim=0)
+
+ Q = q.transpose(0, 1).contiguous().to(dtype=torch.float32)  # [H, S_pad, d] in float32
+ K = k.transpose(0, 1).contiguous().to(dtype=torch.float32)
+
+ num_chunks = S_pad // chunk_size
+ out_acc = torch.zeros(H, S_pad, device=q.device, dtype=torch.float32)  # must start at zero: the kernel accumulates with atomic_add
+
+ S = int(L)
+ grid = (H, num_chunks, chunk_size)  # one program per (head, chunk, query row)
+ BLOCK_D = 32 if d <= 128 else 64
+ ND = (d + BLOCK_D - 1) // BLOCK_D  # number of BLOCK_D slices needed to cover d
+ _non_causal_chunk_row_kernel[grid](
+ Q,
+ K,
+ out_acc,
+ Q.stride(0),
+ Q.stride(1),
+ Q.stride(2),
+ K.stride(0),
+ K.stride(1),
+ K.stride(2),
+ out_acc.stride(0),
+ out_acc.stride(1),
+ S,
+ S_pad,
+ int(num_chunks),
+ CHUNK_SIZE=chunk_size,
+ D=d,
+ BLOCK_D=BLOCK_D,
+ ND=ND,
+ num_warps=4,
+ )
+ return out_acc[:, :S].transpose(0, 1).contiguous()  # drop padding, back to [L, H]
+
+
+def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
+ """q, k: ``[L, H, d]`` -> ``[L, H]``; **no** ``1/sqrt(d)`` scaling. Uses Triton when both inputs are on CUDA, PyTorch otherwise."""
+ if q.is_cuda and k.is_cuda:
+ return _non_causal_chunked_attn_triton(q, k, chunk_size)
+ return _non_causal_chunked_attn_pytorch(q, k, chunk_size)
+
+
+# ---------------------------------------------------------------------------
+# ×||V|| + avg_pool1d(k=3) — Triton (CUDA)
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _mul_vnorm_avgpool3_kernel(
+ A_ptr,
+ V_ptr,
+ OUT_ptr,
+ stride_al,
+ stride_ah,
+ stride_vl,
+ stride_vh,
+ stride_vd,
+ stride_ol,
+ stride_oh,
+ L,
+ D: tl.constexpr,
+):
+ """Triton does not support nested defs, so the per-position ``t_at`` logic is unrolled once each for ``l-1``, ``l`` and ``l+1``."""
+ l = tl.program_id(0)
+ h = tl.program_id(1)
+ offs = tl.arange(0, D)  # the whole head dimension is loaded at once
+
+ pos_m1 = l - 1
+ inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
+ ps_m1 = tl.where(inb_m1, pos_m1, 0)  # clamp to row 0 so the address stays in range even when masked off
+ a_m1 = tl.load(
+ A_ptr + ps_m1 * stride_al + h * stride_ah,
+ mask=inb_m1,
+ other=0.0,
+ ).to(tl.float32)
+ v_m1 = tl.load(
+ V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
+ mask=inb_m1,
+ other=0.0,
+ ).to(tl.float32)
+ s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)  # a * ||v|| at l-1, or 0 outside the sequence
+
+ inb_0 = (l >= 0) & (l < L)
+ ps0 = tl.where(inb_0, l, 0)
+ a0 = tl.load(
+ A_ptr + ps0 * stride_al + h * stride_ah,
+ mask=inb_0,
+ other=0.0,
+ ).to(tl.float32)
+ v0 = tl.load(
+ V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
+ mask=inb_0,
+ other=0.0,
+ ).to(tl.float32)
+ s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)  # a * ||v|| at l
+
+ pos_p1 = l + 1
+ inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
+ ps_p1 = tl.where(inb_p1, pos_p1, 0)
+ a_p1 = tl.load(
+ A_ptr + ps_p1 * stride_al + h * stride_ah,
+ mask=inb_p1,
+ other=0.0,
+ ).to(tl.float32)
+ v_p1 = tl.load(
+ V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
+ mask=inb_p1,
+ other=0.0,
+ ).to(tl.float32)
+ s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)  # a * ||v|| at l+1, or 0 outside the sequence
+
+ out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)  # always divides by 3, so out-of-range neighbours count as zeros
+ tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
+
+
+def _mul_vnorm_avgpool3_fused(
+ a: torch.Tensor, v: torch.Tensor, out: torch.Tensor | None = None
+) -> torch.Tensor:  # launches the fused a * ||v|| + 3-tap average kernel; a is [L, H], v is [L, H, D], result is [L, H]
+ assert a.dim() == 2 and v.dim() == 3 and a.shape[0] == v.shape[0] and a.shape[1] == v.shape[1]
+ L, H, D = v.shape
+ a = a.contiguous()
+ v = v.contiguous()
+ if a.dtype != torch.float32:
+ a = a.float()
+ if out is None:
+ out = torch.empty((L, H), device=v.device, dtype=torch.float32)  # uninitialized; the kernel writes every (l, h)
+ if L == 0 or H == 0:
+ return out  # nothing to launch for an empty grid
+ grid = (L, H)  # one program per (token, head)
+ _mul_vnorm_avgpool3_kernel[grid](
+ a,
+ v,
+ out,
+ a.stride(0),
+ a.stride(1),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ out.stride(0),
+ out.stride(1),
+ L,
+ D=D,
+ num_warps=4,
+ )
+ return out
+
+
+def _maybe_mul_vnorm_avgpool3_fused(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:  # a * ||v|| then 3-tap average along tokens; Triton on CUDA, PyTorch otherwise
+ if not a.is_cuda or not v.is_cuda:
+ import torch.nn.functional as F
+
+ s = a * v.norm(dim=-1)
+ return (
+ F.avg_pool1d(s.transpose(0, 1).unsqueeze(0), kernel_size=3, padding=1, stride=1)  # pool over the token axis, viewed as [1, H, L]
+ .squeeze(0)
+ .transpose(0, 1)
+ )
+ return _mul_vnorm_avgpool3_fused(a, v)
+
+
+@triton.jit
+def _zscore_elem_1d_kernel(
+ X_ptr,
+ OUT_ptr,
+ n,
+ mean,
+ inv_std,
+ BLOCK: tl.constexpr,
+):  # elementwise OUT[i] = (X[i] - mean) * inv_std for i < n; mean and inv_std are supplied by the caller
+ pid = tl.program_id(0)
+ offs = pid * BLOCK + tl.arange(0, BLOCK)
+ mask = offs < n  # guard the ragged last block
+ x = tl.load(X_ptr + offs, mask=mask, other=0.0)
+ tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)
+
+
+def _zscore_flat_f32_global(x: torch.Tensor) -> torch.Tensor:
+ """
+ One-dimensional global z-score in the spirit of kvpress ``(t - t.mean()) / t.std()``, with the std floored at 1e-6.
+ ``mean/std`` come from PyTorch; on CUDA the scaling pass is written elementwise by a Triton kernel. Empty input is returned as is.
+ """
+ if x.numel() == 0:
+ return x
+ mu = x.mean()
+ sig = torch.nan_to_num(x.std(), nan=0.0).clamp_min(1e-6)  # std() is NaN for a single element and clamp_min keeps NaN, so map it to 0 first
+ inv = 1.0 / sig
+ if not x.is_cuda:
+ return (x - mu) * inv
+ x = x.contiguous()
+ out = torch.empty_like(x)
+ n = x.numel()
+ BLOCK = 1024
+ grid = (triton.cdiv(n, BLOCK),)
+ _zscore_elem_1d_kernel[grid](
+ x,
+ out,
+ n,
+ float(mu.item()),  # .item() reads the scalar back to the host before the launch
+ float(inv.item()),
+ BLOCK=BLOCK,
+ num_warps=4,
+ )
+ return out
+
+
+def _attn_scores_kvpress_middle(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ sink_start: int,
+ sink_end: int,
+ chunk_size: int,
+ do_zscore: bool = True,
+) -> torch.Tensor:
+ """Non-causal scores + ×||V|| + avg_pool on each sequence's middle span only; returns full-length ``[N, Hkv]`` with 0 outside the middle."""
+ N, HQ, D = q.shape
+ Hkv = k.shape[1]
+ G = HQ // Hkv  # query heads per KV head
+ device = q.device
+ attn_out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
+ parts: list[torch.Tensor] = []
+
+ for b in range(cu_seqlens.numel() - 1):
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)  # leading sink tokens are excluded from scoring
+ right_keep = min(sink_end, max(0, L - left_keep))
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if mid_start >= mid_end:
+ continue
+ q_m = q[mid_start:mid_end, :, :].contiguous()
+ k_m = k[mid_start:mid_end, :, :].contiguous()
+ v_m = v[mid_start:mid_end, :, :].contiguous()
+ # HF ``repeat_kv`` convention: ``[batch, num_kv_heads, seq_len, head_dim]``
+ k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous() # [1, Hkv, Lm, D]
+ k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous() # [Lm, HQ, D]
+ A = non_causal_chunked_attn(q_m, k_rep, chunk_size)
+ Lm, HQa = A.shape
+ assert HQa == HQ
+ A = A.view(Lm, Hkv, G).mean(dim=-1)  # average the G query heads that share each KV head
+ scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
+ parts.append(scores.reshape(-1))
+
+ if not parts:
+ return attn_out
+
+ flat_a = torch.cat(parts, dim=0)
+ if do_zscore:
+ z_a = _zscore_flat_f32_global(flat_a)  # one mean/std across every sequence and head
+ else:
+ z_a = flat_a
+ offset = 0
+ for b in range(cu_seqlens.numel() - 1):  # second pass recomputes the same middle spans to scatter results back
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)
+ right_keep = min(sink_end, max(0, L - left_keep))
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if mid_start >= mid_end:
+ continue
+ n = (mid_end - mid_start) * Hkv  # number of flattened entries this sequence contributed
+ attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(
+ mid_end - mid_start, Hkv
+ )
+ offset += n
+ return attn_out
+
+
+def non_causal_attn_scores(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens_qk: torch.Tensor,
+ max_seqlen_qk: int,
+ chunk_size: int,
+ sm_scale: float = None,
+ normalize: bool = True,
+ context_lens: Optional[List[int]] = None,
+ protected_first_tokens: Optional[List[int]] = None,
+ protected_last_tokens: Optional[List[int]] = None,
+ *,
+ accum_scores: torch.Tensor = None,
+ accum_blending: float = None,
+) -> torch.Tensor:
+ """
+ Intended to match the kvpress non-causal branch (``sm_scale`` is **ignored**: dot products are not multiplied by ``1/sqrt(d)``).
+ ``normalize=True``: global z-score over the concatenated middle spans (as in the standalone non-causal press).
+ Then ``out += accum_blending * accum_scores`` (if given); finally the protected head/tail rows may be set to ``inf``.
+ """
+ del sm_scale, max_seqlen_qk  # accepted for signature compatibility only
+ sink_start, sink_end = 8, 4  # fixed sink sizes; this entry point does not read them from a context
+ out = _attn_scores_kvpress_middle(
+ q,
+ k,
+ v,
+ cu_seqlens_qk,
+ sink_start,
+ sink_end,
+ chunk_size,
+ do_zscore=normalize,
+ )
+
+ if accum_scores is not None:
+ w = 0.5 if accum_blending is None else float(accum_blending)  # weight defaults to 0.5 when not given
+ out = out + w * accum_scores.to(device=out.device, dtype=out.dtype)
+
+ if protected_first_tokens is not None and protected_last_tokens is not None and context_lens:
+ start = 0  # running row offset of the current sequence in the packed layout
+ for first, last, Lc in zip(
+ protected_first_tokens, protected_last_tokens, context_lens
+ ):
+ out[start : start + int(first)].fill_(torch.inf)  # inf so these rows always outrank finite scores
+ out[start + int(Lc) - int(last) : start + int(Lc)].fill_(torch.inf)
+ start += int(Lc)
+ return out
+
+
+def kvpress_compactor_post_rope(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ pre_rope_scores: torch.Tensor,
+ compression_ctx,
+ max_seqlen_q: int,
+ chunk_size: int,
+ blending: float,
+) -> torch.Tensor:  # [N, Hkv] float32: blending * pre-RoPE scores + attention scores on middle spans, padded sinks, inf on protected rows
+ del max_seqlen_q  # accepted for signature compatibility only
+ Hkv = k.shape[1]
+ device = q.device
+
+ sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
+ sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
+ context_lens: Optional[List[int]] = getattr(
+ compression_ctx, "context_lens", None
+ )
+ protected_first: Optional[List[int]] = getattr(
+ compression_ctx, "protected_first_tokens", None
+ )
+ protected_last: Optional[List[int]] = getattr(
+ compression_ctx, "protected_last_tokens", None
+ )
+
+ attn_out = _attn_scores_kvpress_middle(
+ q, k, v, cu_seqlens, sink_start, sink_end, chunk_size
+ )
+ lev = pre_rope_scores.to(device=device, dtype=torch.float32)
+ blended = torch.zeros_like(lev)
+ for b in range(cu_seqlens.numel() - 1):  # pass 1: blend the two scores on each middle span
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)
+ right_keep = min(sink_end, max(0, L - left_keep))
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if mid_start >= mid_end:
+ continue
+ blended[mid_start:mid_end, :] = (
+ blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]  # only the leverage term is weighted by ``blending``
+ )
+
+ pad_val = blended.max()  # sink rows are still 0 here, so this is the max over middle scores and those zeros
+ if not torch.isfinite(pad_val) or pad_val == 0:
+ pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)  # fall back to 1.0 so sink rows get a non-zero score
+ for b in range(cu_seqlens.numel() - 1):  # pass 2: give every sink row the global maximum
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)
+ right_keep = min(sink_end, max(0, L - left_keep))
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if left_keep > 0:
+ blended[k_beg:mid_start, :] = pad_val
+ if right_keep > 0:
+ blended[mid_end:k_end, :] = pad_val
+
+ if protected_first is not None and protected_last is not None and context_lens:
+ start = 0  # running row offset of the current sequence in the packed layout
+ for first, last, Lc in zip(
+ protected_first, protected_last, context_lens
+ ):
+ blended[start : start + int(first)].fill_(torch.inf)  # inf so these rows always outrank finite scores
+ blended[start + int(Lc) - int(last) : start + int(Lc)].fill_(torch.inf)
+ start += int(Lc)
+
+ return blended
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/compactor_origin.py b/vllm/compactor-vllm/src/compactor_vllm/compression/compactor_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..36e2e45c758af75206a28f2143fd4e43ddf9607b
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/compactor_origin.py
@@ -0,0 +1,600 @@
+import logging
+import math
+from typing import List, Optional
+
+import torch
+import triton
+from tqdm.contrib.logging import logging_redirect_tqdm
+from triton import language as tl
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+logger = logging.getLogger(__name__)
+
+
+class CompactorCompression(BaseCompressionMethod):  # scores keys with sketched leverage scores (pre-RoPE) plus non-causal attention (post-RoPE)
+ chunk_size: int = 128  # chunk length passed to the post-RoPE attention scorer
+
+ @staticmethod
+ def pre_rope_scoring(
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+ ) -> Optional[torch.Tensor]:  # only ``k`` is scored here; ``q`` and ``v`` are unused
+ compression_context = context.compression_context
+ scores = maybe_execute_in_stream(
+ approximate_leverage_scores,
+ k,
+ compression_context.context_lens,
+ compression_context.PHI,
+ normalize=True,
+ chunk_size=compression_context.compression_chunk_size,  # leverage-score chunking comes from the context, not from the class attribute
+ STORE_STREAM=context.STORE_STREAM,
+ )
+ return scores
+
+ @staticmethod
+ def post_rope_scoring(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ pre_rope_scores: torch.Tensor,
+ context,
+ ) -> Optional[torch.Tensor]:
+ compression_context = context.compression_context
+ return maybe_execute_in_stream(  # NOTE(review): unlike pre_rope_scoring, no STORE_STREAM is forwarded here — confirm this is intended
+ non_causal_attn_scores,
+ q,
+ k,
+ v,
+ context.cu_seqlens_q,
+ context.max_seqlen_q,
+ chunk_size=CompactorCompression.chunk_size,
+ sm_scale=1.0,
+ normalize=True,
+ accum_scores=pre_rope_scores,
+ context_lens=compression_context.context_lens,
+ protected_first_tokens=compression_context.protected_first_tokens,
+ protected_last_tokens=compression_context.protected_last_tokens,
+ accum_blending=0.5,
+ )
+
+
+def split_into_chunks(xs, chunk_size):
+ """
+ Convert a list of sequence lengths into a sequence of coalesced chunk lengths.
+
+ Given an iterable of per-sequence context lengths ``xs`` and a target ``chunk_size``,
+ this helper produces two parallel lists:
+
+ * ``coalesced_chunks`` – lengths of contiguous segments in the
+ **concatenated** sequence space, where each segment is either the run of
+ all full ``chunk_size`` chunks of one sequence (a multiple of
+ ``chunk_size``) or a residual "epilogue" tail shorter than ``chunk_size``.
+
+ * ``chunks`` – the actual chunk sizes used within each original sequence.
+ For a length ``n``, we produce ``n // chunk_size`` entries of
+ ``chunk_size`` (the "prologue") and at most one final entry equal to
+ ``n % chunk_size`` (the "epilogue").
+
+ ``chunks`` reflects how each input length is decomposed into
+ fixed-size (plus optional tail) processing blocks, while
+ ``coalesced_chunks`` describes those same blocks after merging the consecutive
+ full chunks of each sequence into a single segment.
+
+ Example:
+ xs = [257, 127], chunk_size = 128
+ coalesced_chunks = [256, 1, 127]
+ chunks = [128, 128, 1, 127]
+
+ Args:
+ :param xs:
+ Iterable of non-negative integers
+ :param chunk_size:
+ Target chunk size; must be non-zero because it is used as a divisor
+
+ Returns:
+ :return Tuple[List[int], List[int]]:
+ ``(coalesced_chunks, chunks)`` as described above.
+ """
+ coalesced_chunks, chunks = [], []
+ for n in xs:
+ nchunks = n // chunk_size
+ prologue = nchunks * chunk_size  # total length covered by full chunks
+ epilogue = n - prologue  # leftover tail, shorter than chunk_size
+ if prologue > 0:
+ coalesced_chunks.append(prologue)
+ chunks.extend([chunk_size] * nchunks)
+ if epilogue > 0:
+ coalesced_chunks.append(epilogue)
+ chunks.append(epilogue)
+ return coalesced_chunks, chunks
+
+
+def approximate_leverage_scores(
+ key_states: torch.Tensor, # [N, H, D]
+ context_lens: List[int], # [B]
+ PHI: torch.Tensor, # [D, k]
+ regularizer: float = 5e-3,
+ normalize: bool = False,
+ chunk_size: int = 512,
+) -> torch.Tensor: # returns [N, H]
+ """
+ Approximate leverage scores for keys via randomized sketching.
+
+ This implements a randomized approximation to per-token leverage scores for
+ the key matrix, as described in Compactor: Calibrated Query-Agnostic KV Cache
+ Compression with Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).
+ Args:
+ :param key_states:
+ Tensor of shape ``[N, H, D]`` containing pre-RoPE key states for
+ all tokens across the batch, packed along the sequence dimension.
+ ``N = sum(context_lens)``.
+ :param context_lens:
+ List of per-sequence context lengths, length ``B``.
+ :param PHI:
+ Random projection matrix of shape ``[D, k]`` used to sketch the
+ keys into a lower-dimensional subspace (k < D).
+ :param regularizer:
+ Small positive scalar added to the diagonal of each Gram matrix
+ before SVD to improve numerical stability. Defaults to ``5e-3``.
+ :param normalize:
+ If True, z-score the scores in place over all heads and tokens of
+ each chunk (one normalization segment per entry of the chunk list).
+ :param chunk_size:
+ Target chunk size along the sequence dimension. If > 0, the
+ concatenated sequence is split into chunks of at most this size
+ before forming Gram matrices and SVD. If ≤ 0, the entire sequence
+ for each context is treated as a single chunk.
+ Returns:
+ :return torch.Tensor:
+ Approximate leverage scores of shape ``[N, H]``, where each row
+ corresponds to a token and each column to a head.
+ """
+ if chunk_size > 0:
+ coalesced_chunk_lens, chunks_lens = split_into_chunks(context_lens, chunk_size)
+ else:
+ coalesced_chunk_lens, chunks_lens = context_lens, context_lens
+ chunk_lens_cuda = torch.tensor([0] + list(chunks_lens), device=key_states.device)  # built on the keys' own device, not the current CUDA device
+ X = torch.matmul(key_states.transpose(0, 1), PHI)
+ H, N, k = X.shape
+ chunks = torch.split(X, coalesced_chunk_lens, dim=-2)
+ gram_matrices = []
+ for i, L in enumerate(coalesced_chunk_lens):
+ chunk = chunks[i]
+ if chunk_size <= 0 or L % chunk_size != 0:
+ chunk.sub_(chunk.mean(dim=-2, keepdim=True))  # in-place on a split of X, so X itself ends up centered
+ g = torch.matmul(chunk.transpose(-1, -2), chunk) # [H, k, k]
+ g = g.unsqueeze(1)
+ else:
+ chunk = chunk.view(H, -1, chunk_size, k) # [H, num_chunks, chunk_size, k]
+ chunk.sub_(chunk.mean(dim=-2, keepdim=True))
+ g = torch.matmul(chunk.transpose(-1, -2), chunk) # [H, num_chunks, k, k]
+ gram_matrices.append(g)
+ G = torch.cat(gram_matrices, dim=1).to(torch.float32)
+ diag = G.diagonal(dim1=-2, dim2=-1)
+ diag.add_(regularizer)  # in-place add on the diagonal of G
+ try:
+ V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
+ except RuntimeError:
+ try:
+ diag = G.diagonal(dim1=-2, dim2=-1)
+ diag.add_(regularizer * 10)  # added on top of the regularizer already applied above
+ V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
+ except RuntimeError:
+ with logging_redirect_tqdm():
+ logger.warning(
+ "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
+ "Try increasing chunk_size if this issue persists."
+ )
+ # this is over 50 times slower than using GESVDA
+ return _approximate_leverage_scores_qr_fallback(
+ X=X,
+ chunks_lens=chunks_lens,
+ chunk_lens_cuda=chunk_lens_cuda,
+ normalize=normalize,
+ chunk_size=chunk_size,
+ )
+ SV = (V * S.rsqrt().unsqueeze(-2)).to(X.dtype)  # scale each singular vector by S^-1/2 so that ||x @ SV||^2 = x^T G^-1 x
+ start = 0
+ all_scores = []
+ for i, L in enumerate(coalesced_chunk_lens):
+ chunk = chunks[i]
+ if chunk_size <= 0 or L % chunk_size != 0:
+ num_chunks = 1
+ sv = SV[:, start]
+ else:
+ num_chunks = L // chunk_size
+ chunk = chunk.view(H, -1, chunk_size, k) # [H, NC, CS, k]
+ sv = SV[:, start : start + num_chunks]
+ U = torch.matmul(chunk, sv)
+ scores = (U * U).sum(dim=-1).clamp_min_(0.0).view(H, -1)
+ all_scores.append(scores.transpose(-1, -2))  # [tokens, H] for this segment
+ start += num_chunks
+
+ scores = torch.cat(all_scores, dim=0)
+ if normalize:
+ grid = (len(chunks_lens),)  # one program per chunk
+ cu_k = chunk_lens_cuda.cumsum(dim=0)  # chunk boundaries as prefix sums
+ _zscore_per_batch_epilogue_no_window[grid](
+ scores, cu_k, scores.stride(0), scores.stride(1), H
+ )
+ return scores
+
+
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+ key=["HK"],
+ cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue_no_window(
+ OUT, # [Nk, Hk] scores, normalized in place
+ cu_k, # prefix offsets delimiting the segments; one program per segment
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ HK: tl.constexpr, # Hk
+ BLOCK_K: tl.constexpr, # e.g., 128
+):
+ b = tl.program_id(0)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ if k_end <= k_beg:
+ return  # empty segment: nothing to normalize
+
+ sumv = tl.zeros([], dtype=tl.float32)
+ sumsq = tl.zeros([], dtype=tl.float32)
+ count = ((k_end - k_beg) * HK).to(tl.float32)  # statistics are pooled over all tokens and heads of the segment
+
+ for ks in tl.range(k_beg, k_end, BLOCK_K):  # pass 1: accumulate sum and sum of squares
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ sumv += tl.sum(vals, 0)
+ sumsq += tl.sum(vals * vals, 0)
+
+ mean = sumv / count
+ var = tl.maximum(sumsq / count - mean * mean, 0.0)  # E[x^2] - mean^2, floored at 0 against rounding
+ invstd = 1.0 / tl.sqrt(tl.maximum(var, 1e-12))  # floor the variance: a constant segment (e.g. a one-token chunk) would otherwise give 0 * inf = NaN
+
+ for ks in tl.range(k_beg, k_end, BLOCK_K):  # pass 2: rewrite every value as its z-score
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ vals = (vals - mean) * invstd
+ tl.store(ptrs, vals, mask=kmask)
+
+
+def _approximate_leverage_scores_qr_fallback(
+ X: torch.Tensor, # [H, N, k], already sketched (KΦ) and centered in-place
+ chunks_lens: List[int], # [num_chunks]
+ chunk_lens_cuda: torch.Tensor, # [num_chunks + 1] (prefix base)
+ normalize: bool,
+ chunk_size: int,
+) -> torch.Tensor:  # [N, H] leverage scores computed as squared row norms of each chunk's reduced-QR Q factor
+ H, N, k = X.shape
+ device, dtype = X.device, X.dtype
+ offsets: List[int] = []
+ offset = 0
+ for L in chunks_lens:
+ offsets.append(offset)  # start row of each chunk in the packed token axis
+ offset += L
+ if offset != N:
+ raise RuntimeError(
+ f"QR fallback: sum(chunks_lens)={offset} does not match N={N}"
+ )
+
+ blocks = torch.split(X, chunks_lens, dim=-2)
+ scores = torch.empty(N, H, device=device, dtype=dtype)  # uninitialized; filled chunk by chunk below
+ if chunk_size > 0:
+ full_indices = [i for i, L in enumerate(chunks_lens) if L == chunk_size]
+ epi_indices = [i for i, L in enumerate(chunks_lens) if L != chunk_size]
+
+ if full_indices:
+ # stack full chunks
+ full_blocks = torch.stack(
+ [blocks[i] for i in full_indices], dim=0
+ ) # [M, H, CS, k]
+ M, Hf, Lf, kf = full_blocks.shape
+ assert Lf == chunk_size
+
+ # merge (M, H) into a single batch dim for torch.linalg.qr
+ full_blocks_2d = full_blocks.view(M * Hf, Lf, kf).to(torch.float32)
+
+ U_full, _ = torch.linalg.qr(full_blocks_2d, mode="reduced")
+ U_full = U_full.to(dtype)
+ scores_full = (U_full * U_full).sum(dim=-1).clamp_min(0.0) # [M * Hf, Lf]
+ scores_full = scores_full.view(M, Hf, Lf).transpose(-1, -2) # [M, CS, H] after the transpose
+ for m, chunk_idx in enumerate(full_indices):
+ start = offsets[chunk_idx]
+ Lc = chunks_lens[chunk_idx]
+ scores[start : start + Lc].copy_(scores_full[m])
+ else:
+ epi_indices = list(range(len(chunks_lens)))  # no fixed chunk size: every chunk is handled one at a time below
+
+ for chunk_idx in epi_indices:
+ block = blocks[chunk_idx]
+ _, Lc, _ = block.shape
+ if Lc == 0:
+ continue
+ U_epi, _ = torch.linalg.qr(block.to(torch.float32), mode="reduced")
+ scores_epi = (U_epi * U_epi).sum(dim=-1).to(dtype) # [H, Lc]
+ start = offsets[chunk_idx]
+ scores[start : start + Lc] = scores_epi.transpose(0, 1) # [Lc, H]
+
+ if normalize:
+ grid = (len(chunks_lens),)  # one program per chunk
+ cu_k = chunk_lens_cuda.cumsum(dim=0)  # chunk boundaries as prefix sums
+ _zscore_per_batch_epilogue_no_window[grid](
+ scores, cu_k, scores.stride(0), scores.stride(1), H
+ )
+ return scores
+
+
+@triton_autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s
+        )
+        for BM in [64]
+        for BK in [64]
+        for w in [4]
+        for s in [2]
+    ],
+    key=[
+        "QUERY_GROUP_SIZE",
+        "D",
+        "CHUNK_SIZE",
+    ],
+    cache_results=True,
+)
+@triton.jit
+def _non_causal_attn_kernel(
+    Q,
+    K,
+    V,
+    accum_scores,
+    cu_seqlens_qk,
+    #
+    STRIDE_Q_G,
+    STRIDE_Q_N,
+    STRIDE_Q_H,
+    STRIDE_Q_D,
+    STRIDE_K_G,
+    STRIDE_K_N,
+    STRIDE_K_D,
+    STRIDE_V_G,
+    STRIDE_V_N,
+    STRIDE_V_D,
+    STRIDE_OUT_N,
+    STRIDE_OUT_H,
+    sm_scale,
+    #
+    CHUNK_SIZE: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    D: tl.constexpr,
+    WARPSPEC: tl.constexpr,
+):
+    """Accumulate per-key attention mass from chunk-local, non-causal softmax attention.
+
+    The launch grid is (KV head, sequence, chunk within sequence). Each program
+    takes the token range ``[chunk_start, chunk_end)`` of one sequence (bounds
+    read from ``cu_seqlens_qk``) and lets every query row in that range attend
+    to every key in the same range; no causal mask is applied.
+
+    For each ``BLOCK_M`` query tile the keys are swept twice: pass 1 builds the
+    running row maximum and softmax denominator, pass 2 recomputes the logits,
+    normalizes them, sums the probabilities over all rows of the tile (which
+    includes the ``QUERY_GROUP_SIZE`` query heads flattened into the row axis)
+    and adds the per-key sums into ``accum_scores``.
+
+    ``V``, the ``STRIDE_V_*`` arguments and ``WARPSPEC`` are accepted but never
+    read in this body.
+    """
+    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
+    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE
+    pid_g = tl.program_id(0)  # KV head in [0, HKV)
+    pid_b = tl.program_id(1)  # batch id
+    pid_m = tl.program_id(2)  # chunk id within batch
+
+    # Token range of this sequence inside the packed [N, ...] layout.
+    off_b = tl.load(cu_seqlens_qk + pid_b)
+    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
+
+    # The last chunk of a sequence may be shorter than CHUNK_SIZE; programs
+    # whose chunk starts at or past the end of the sequence exit immediately.
+    chunk_start = off_b + pid_m * CHUNK_SIZE
+    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
+    M = chunk_end - chunk_start
+    if M <= 0:
+        return
+
+    offs_d = tl.arange(0, D)
+    offs_k = tl.arange(0, BLOCK_K)
+
+    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
+    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
+    row_m = offs_q % BLOCK_M  # token offset in this tile
+    row_h = offs_q // BLOCK_M  # query-group index
+
+    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
+    # Finite stand-in for -inf used to mask logits.
+    NEG_INF = -1.0e9
+
+    # Iterate over query tiles within this chunk
+    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
+        # Global query indices for rows in this tile
+        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
+        q_mask = q_idx < chunk_end  # mask for valid rows in this tile
+
+        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
+        q_ptrs = (
+            Q
+            + pid_g * STRIDE_Q_G
+            + q_idx[:, None] * STRIDE_Q_N
+            + row_h[:, None] * STRIDE_Q_H
+            + offs_d[None, :] * STRIDE_Q_D
+        )
+        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)
+
+        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
+        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
+        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)
+
+        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
+            k_idx = ks + offs_k  # [BLOCK_K]
+            k_mask = k_idx < chunk_end  # which keys are valid in this tile
+
+            k_ptrs = (
+                K
+                + pid_g * STRIDE_K_G
+                + k_idx[:, None] * STRIDE_K_N
+                + offs_d[None, :] * STRIDE_K_D
+            )
+            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]
+
+            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
+            qk = tl.dot(q, k.T) * qk_scale
+            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)
+
+            cur_max = tl.max(qk, 1)
+            new_max = tl.maximum(row_max, cur_max)
+
+            # rescale previous sum to new_max (base 2)
+            rescale = tl.math.exp2(row_max - new_max)
+            p = tl.math.exp2(qk - new_max[:, None])
+
+            row_sum = row_sum * rescale + tl.sum(p, 1)
+            row_max = new_max
+
+        # Avoid division by zero for inactive rows
+        denom = tl.where(q_mask, row_sum, 1.0)
+
+        # ---- Pass 2: recompute logits, normalize, and accumulate per-key sums ----
+        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
+            k_idx = ks + offs_k
+            k_mask = k_idx < chunk_end
+
+            k_ptrs = (
+                K
+                + pid_g * STRIDE_K_G
+                + k_idx[:, None] * STRIDE_K_N
+                + offs_d[None, :] * STRIDE_K_D
+            )
+            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)
+
+            qk = tl.dot(q, k.T) * qk_scale
+            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)
+
+            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
+            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
+            # Rows past chunk_end are NOT zeroed: they are given the constant
+            # INVERSE_CHUNK in every column. Masked-out key columns of valid
+            # rows are already ~0 via NEG_INF, and the masked store below drops
+            # out-of-range columns entirely.
+            # NOTE(review): padded rows only exist in the last BLOCK_M tile of a
+            # chunk, so the amount added depends on BLOCK_M as well as on the
+            # chunk length -- confirm this matches the intended compensation.
+            p = tl.where(
+                q_mask[:, None], p, INVERSE_CHUNK
+            )  # preserve attention mass in shorter chunks
+
+            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups
+
+            # Plain load / add / store (no atomics).
+            # NOTE(review): presumably safe because each (KV head, chunk) pair
+            # is handled by exactly one program; that holds when cu_seqlens_qk
+            # is non-decreasing -- confirm against the caller.
+            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
+            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
+            new = old + contrib.to(old.dtype)
+            tl.store(out_ptrs, new, mask=k_mask)
+
+
+def non_causal_attn_scores(
+    q: torch.Tensor,  # [N, HQ, D]
+    k: torch.Tensor,  # [N, HKV, D]
+    v: torch.Tensor,  # [N, HKV, D]
+    cu_seqlens_qk: torch.Tensor,  # [B + 1]
+    max_seqlen_qk: int,
+    chunk_size: int,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+    context_lens: Optional[List[int]] = None,
+    protected_first_tokens: Optional[List[int]] = None,
+    protected_last_tokens: Optional[List[int]] = None,
+    *,
+    accum_scores: Optional[torch.Tensor] = None,  # [N, HKV] (float32)
+    accum_blending: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Score every (token, KV head) pair by the attention mass it receives under
+    chunk-local, non-causal attention.
+
+    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
+    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
+    :param v: Tensor of shape ``[N, HKV, D]`` containing values. Only its strides are
+        forwarded; the kernel does not read the data.
+    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
+    :param max_seqlen_qk: int containing the maximum sequence length
+    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
+    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
+    :param normalize: bool specifying whether to z-score normalize final attention scores
+    :param context_lens: List[int] specifying the context lengths (CPU copy of the
+        per-sequence lengths). When ``None`` and protection is requested, the lengths
+        are derived from ``cu_seqlens_qk`` (this forces a device-to-host copy).
+    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
+        start of each sequence. ``None`` means no leading protection.
+    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
+        end of each sequence. ``None`` means no trailing protection.
+    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that should be accumulated into
+    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding the new
+        non-causal attention scores. Final output is equivalent to out + accum_blending * accum_scores
+    :return: float32 Tensor of shape ``[N, HKV]``; protected rows are set to ``+inf``.
+    """
+    assert q.ndim == 3 and k.ndim == 3
+    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
+    N, HQ, D = q.shape
+    HKV = k.shape[1]
+    # The condition checks HQ % HKV; the previous message stated the relation backwards.
+    assert HQ % HKV == 0, "Number of KV heads must divide number of query heads"
+    assert (D & (D - 1)) == 0, "D must be a power of two"
+
+    B = cu_seqlens_qk.numel() - 1
+    H_g = HQ // HKV  # query-group size per KV head
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)
+    # Put the KV-head axis first so one program can address a whole query group.
+    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
+    k = k.view(N, HKV, D).permute(1, 0, 2)
+    # v is intentionally left un-permuted: the kernel never dereferences it.
+
+    if cu_seqlens_qk.device != q.device:
+        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
+    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)
+
+    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
+    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
+    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
+    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()
+
+    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"
+
+    def attn_grid(_):
+        # (KV head, sequence, chunk within sequence)
+        return (
+            HKV,
+            B,
+            triton.cdiv(max_seqlen_qk, chunk_size),
+        )
+
+    _non_causal_attn_kernel[attn_grid](
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens_qk,
+        STRIDE_Q_G,
+        STRIDE_Q_N,
+        STRIDE_Q_H,
+        STRIDE_Q_D,
+        STRIDE_K_G,
+        STRIDE_K_N,
+        STRIDE_K_D,
+        STRIDE_V_G,
+        STRIDE_V_N,
+        STRIDE_V_D,
+        STRIDE_OUT_N,
+        STRIDE_OUT_H,
+        sm_scale,
+        CHUNK_SIZE=chunk_size,
+        QUERY_GROUP_SIZE=H_g,
+        D=D,
+    )
+    if normalize:
+        # One program per sequence z-scores that sequence's rows in place.
+        _zscore_per_batch_epilogue_no_window[(B,)](
+            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
+        )
+    if accum_scores is not None:
+        if accum_blending is not None:
+            out += accum_scores * accum_blending
+        else:
+            out += accum_scores
+    if protected_first_tokens is not None or protected_last_tokens is not None:
+        if context_lens is None:
+            # Fall back to the device-side boundaries rather than failing in zip().
+            context_lens = (cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).tolist()
+        num_seqs = len(context_lens)
+        # Either list may be omitted on its own; treat a missing one as "protect 0".
+        firsts = (
+            protected_first_tokens
+            if protected_first_tokens is not None
+            else [0] * num_seqs
+        )
+        lasts = (
+            protected_last_tokens
+            if protected_last_tokens is not None
+            else [0] * num_seqs
+        )
+        start = 0
+        for first, last, L in zip(firsts, lasts, context_lens):
+            # Clamp so an over-long protected span cannot spill into a
+            # neighbouring sequence (or wrap around via a negative slice start).
+            first = max(0, min(int(first), L))
+            last = max(0, min(int(last), L))
+            out[start : start + first].fill_(torch.inf)
+            out[start + L - last : start + L].fill_(torch.inf)
+            start += L
+    return out
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/compression_config.py b/vllm/compactor-vllm/src/compactor_vllm/compression/compression_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e861e663644b0ff6e9d0d2641e6940e6514bf6f3
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/compression_config.py
@@ -0,0 +1,45 @@
+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
+
+logger = logging.getLogger(__name__)
+
+
+class CompressionMethod(Enum):
+    """KV-cache compression algorithm selector."""
+
+    # Values are spelled out; they are the same integers auto() assigned
+    # (1-based, in declaration order).
+    CRITICALADAKV = 1
+    COMPACTOR = 2
+    SNAPKV = 3
+    NONE = 4
+
+
+# class CachingPolicy(Enum):
+# CACHE_PROMPT = auto()
+# DONT_CACHE = auto()
+
+
+# class CompressionType(Enum):
+# QUERY_AWARE = auto()
+# QUERY_AGNOSTIC = auto()
+
+
+@dataclass
+class SequenceCompressionParams:
+    """Compression settings attached to a single sequence."""
+
+    # NOTE(review): the exact meaning of the ratio (fraction kept vs. fraction
+    # removed) is not visible here; 1.0 is the default -- confirm with the caller.
+    compression_ratio: float = 1.0
+    # Number of tokens at the start of the sequence to protect
+    # (presumably exempt from pruning -- confirm against the scoring code).
+    protected_first_tokens: int = 16
+    # Number of tokens at the end of the sequence to protect (same caveat).
+    protected_last_tokens: int = 64
+
+
+@dataclass
+class BatchCompressionParams:
+    """Compression settings shared by every sequence in a batch.
+
+    ``do_chunked_compression`` is forced to ``False`` when the method is
+    ``CompressionMethod.SNAPKV``; a warning is logged only when that actually
+    overrides a ``True`` value.
+    """
+
+    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
+    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
+
+    do_chunked_compression: bool = True
+    # NOTE(review): presumably the chunk length in tokens used by chunked
+    # compression -- confirm against the scoring kernels.
+    chunk_size: int = 512
+
+    def __post_init__(self):
+        """Reconcile option combinations that cannot be used together."""
+        if (
+            self.compression_method == CompressionMethod.SNAPKV
+            and self.do_chunked_compression
+        ):
+            # Warn only when a setting is really being changed; previously the
+            # "Disabling it." message was emitted even if it was already off.
+            self.do_chunked_compression = False
+            logger.warning(
+                "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
+            )
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv-cursor.py b/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv-cursor.py
new file mode 100644
index 0000000000000000000000000000000000000000..20aaec214a77030ced076bb4bb40c7c2c03c1210
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv-cursor.py
@@ -0,0 +1,459 @@
+"""
+CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
+用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
+
+预算与 compactor_vllm 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
+(token, head) 对数量)及首/尾保护段长度。
+
+注意:不得在 import 时加载 ``compactor_vllm.utils.context``(其会再 import ``CompressionMethod``,
+与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.compression.compactor import (
+ CompactorCompression,
+ non_causal_attn_scores,
+)
+from compactor_vllm.compression.snapkv import SnapKVCompression
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+
+
+# ============================================================================
+# Triton Kernel 1: 计算 ||Wo @ V||₁ (L1 范数)
+# ============================================================================
+@triton_autotune(
+    configs=[
+        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+        for bk in [32, 64, 128]
+        for bd in [32, 64]
+        for nw in [4, 8]
+        for ns in [3, 4]
+    ],
+    key=["Hk", "D", "HIDDEN"],
+    cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+    V,
+    WO,
+    cu_k,
+    OUT,
+    STRIDE_V_NK,
+    STRIDE_V_HK,
+    STRIDE_V_D,
+    STRIDE_WO_HQ,
+    STRIDE_WO_D,
+    STRIDE_WO_HID,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    Hk: tl.constexpr,
+    Hq: tl.constexpr,
+    D: tl.constexpr,
+    HIDDEN: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Per (token, KV head): mean over the query group of ``sum(|V @ Wo_hq|)``.
+
+    The launch grid is (sequence, KV head, token tile). For every query head
+    ``hq`` in the group of KV head ``hk`` the value rows are multiplied by that
+    head's ``[D, HIDDEN]`` slice of ``WO``, absolute values are summed over the
+    hidden axis, and the result is averaged over the ``QUERY_GROUP_SIZE`` heads
+    before being stored to ``OUT``.
+
+    ``Hk`` and ``Hq`` are not read in the body (``Hk`` is an autotune key).
+    """
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+    ks = tl.program_id(2)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    # The grid's tile axis is sized for the longest sequence, so tiles that lie
+    # entirely past the end of a shorter sequence have nothing to compute.
+    tile_beg = k_beg + ks * BLOCK_K
+    if tile_beg >= k_end:
+        return
+
+    nk = tile_beg + tl.arange(0, BLOCK_K)
+    k_mask = nk < k_end
+    offs_d = tl.arange(0, D)
+
+    # The V tile depends only on (token, KV head), so load it once and reuse
+    # it for every query head of the group instead of reloading per head.
+    v_ptrs = (
+        V
+        + nk[:, None] * STRIDE_V_NK
+        + hk * STRIDE_V_HK
+        + offs_d[None, :] * STRIDE_V_D
+    )
+    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+
+    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+    for g in range(QUERY_GROUP_SIZE):
+        hq = hk * QUERY_GROUP_SIZE + g
+
+        # Sweep the hidden axis in BLOCK_D-wide column tiles of Wo[hq].
+        for hid_off in range(0, HIDDEN, BLOCK_D):
+            hid_idx = hid_off + tl.arange(0, BLOCK_D)
+            hid_mask = hid_idx < HIDDEN
+
+            wo_ptrs = (
+                WO
+                + hq * STRIDE_WO_HQ
+                + offs_d[:, None] * STRIDE_WO_D
+                + hid_idx[None, :] * STRIDE_WO_HID
+            )
+            # Columns past HIDDEN load as 0 and therefore add nothing to the L1 sum.
+            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+            wov_tile = tl.dot(v_blk, wo_tile)
+            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
+
+    l1_sum = l1_sum / QUERY_GROUP_SIZE
+    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+# ============================================================================
+# Triton Kernel 2: Stage 1 保护 + Stage 2 加权融合
+# ============================================================================
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
+    key=["Hk"],
+    cache_results=True,
+)
+@triton.jit
+def _critical_ada_fuse_kernel(
+    BASE_SCORES,
+    WO_V_NORM,
+    STAGE1_MASK,
+    cu_k,
+    OUT,
+    EPSILON: tl.constexpr,
+    STRIDE_BS_NK,
+    STRIDE_BS_HK,
+    STRIDE_WN_NK,
+    STRIDE_WN_HK,
+    STRIDE_S1_NK,
+    STRIDE_S1_HK,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    Hk: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """Write ``(base + EPSILON) * wo_v_norm`` to ``OUT``, or ``+inf`` where the stage-1 mask is 1.
+
+    The launch grid is (sequence, KV head); each program walks its sequence's
+    token range (taken from ``cu_k``) in ``BLOCK_K``-sized tiles.
+    """
+    seq_id = tl.program_id(0)
+    head = tl.program_id(1)
+
+    seq_start = tl.load(cu_k + seq_id)
+    seq_stop = tl.load(cu_k + seq_id + 1)
+
+    lane = tl.arange(0, BLOCK_K)
+
+    for tile_start in tl.range(seq_start, seq_stop, BLOCK_K):
+        tok = tile_start + lane
+        in_seq = tok < seq_stop
+
+        base = tl.load(
+            BASE_SCORES + tok * STRIDE_BS_NK + head * STRIDE_BS_HK,
+            mask=in_seq,
+            other=0.0,
+        )
+        wnorm = tl.load(
+            WO_V_NORM + tok * STRIDE_WN_NK + head * STRIDE_WN_HK,
+            mask=in_seq,
+            other=1.0,
+        )
+        protect_flag = tl.load(
+            STAGE1_MASK + tok * STRIDE_S1_NK + head * STRIDE_S1_HK,
+            mask=in_seq,
+            other=0,
+        ).to(tl.int32)
+
+        # Stage-2 re-weighting, overridden by +inf for stage-1 protected entries.
+        weighted = (base + EPSILON) * wnorm
+        result = tl.where(protect_flag == 1, float("inf"), weighted)
+
+        tl.store(
+            OUT + tok * STRIDE_OUT_NK + head * STRIDE_OUT_HK,
+            result,
+            mask=in_seq,
+        )
+
+
+def critical_ada_key_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    wo_weight: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    base_scores: torch.Tensor,
+    compression_ctx: Any,
+    *,
+    store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+    """
+    Re-weight ``base_scores`` with ``||Wo @ V||_1`` and protect per-head top-k entries.
+
+    The retention budget is ``compression_ctx.batch_tokens_to_retain`` (number of
+    (token, head) pairs per sequence). Inside each sequence's compressible region
+    (the tokens between the protected prefix and suffix) the steps are:
+
+    1. Safeguard: per head, mark the top ``n_kept * alpha_safeguard`` base scores.
+    2. Head budgets: take the top ``n_kept * Hk`` entries over the flattened
+       (token, head) region (safeguarded entries raised to ``+inf``) and count how
+       many fall on each head.
+    3. Stage 1: per head, mark the top ``head_budget * first_stage_ratio`` base scores.
+    4. Fuse: ``(base + epsilon) * ||Wo @ V||_1``, with marked entries set to ``+inf``.
+    5. Stage 2: per head, set the top ``head_budget`` fused scores to ``+inf``.
+
+    Finally the lowest-scoring ``total_pairs - keep_pairs`` entries of every
+    sequence are reported as pruned.
+
+    Args:
+        q: ``[N, Hq, D]`` queries; only the shape and device are used.
+        k: ``[N_k, Hk, D]`` keys; only the shape is used.
+        v: ``[N_k, Hk, D]`` values.
+        wo_weight: output projection, either 2-D ``[hidden, Hq * D]`` or already
+            reshaped to ``[Hq, D, hidden]``.
+        cu_seqlens: ``[B + 1]`` sequence boundaries. Must be the same boundaries
+            that were used to produce ``base_scores``.
+        base_scores: ``[N_k, Hk]`` base importance scores.
+        compression_ctx: duck-typed context providing ``batch_tokens_to_retain``,
+            ``protected_first_tokens``, ``protected_last_tokens``,
+            ``critical_ada_epsilon`` and ``critical_ada_first_stage_ratio``;
+            ``critical_ada_alpha_safeguard`` is optional (default 0.2, clamped to [0, 1]).
+        store_stream: if given, the returned scores are registered on this stream
+            via ``record_stream``.
+
+    Returns:
+        ``(final_scores, masked_key_indices)``. ``final_scores`` is float32
+        ``[N_k, Hk]``. ``masked_key_indices`` is ``None`` when nothing is pruned,
+        otherwise ``(batch_idx, head_idx, seq_idx)`` where ``seq_idx`` indexes
+        rows of the packed ``[N_k, ...]`` layout.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk and Hq % Hk == 0
+
+    B = cu_seqlens.numel() - 1
+    G = Hq // Hk
+
+    btr = compression_ctx.batch_tokens_to_retain
+    assert btr is not None and btr.numel() == B
+
+    # Pull boundaries and budgets to the host once instead of issuing one
+    # blocking .item() per sequence inside each loop below.
+    seq_bounds = [int(x) for x in cu_seqlens.reshape(-1).tolist()]
+    keep_pairs_per_seq = [
+        int(x) for x in btr.to(device=device, dtype=torch.int32).reshape(-1).tolist()
+    ]
+    k_lens = [seq_bounds[b + 1] - seq_bounds[b] for b in range(B)]
+    max_k_len = max(k_lens, default=0)
+
+    prot_first = compression_ctx.protected_first_tokens or [0] * B
+    prot_last = compression_ctx.protected_last_tokens or [0] * B
+    epsilon = compression_ctx.critical_ada_epsilon
+    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+    alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2))
+    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
+
+    def _compressible_region(b: int) -> Tuple[int, int]:
+        # Row range of sequence ``b`` left after removing the protected prefix/suffix.
+        s = int(prot_first[b]) if b < len(prot_first) else 0
+        e = int(prot_last[b]) if b < len(prot_last) else 0
+        return seq_bounds[b] + s, seq_bounds[b + 1] - e
+
+    if wo_weight.dim() == 2:
+        hidden_size, _ = wo_weight.shape
+        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+    else:
+        wo = wo_weight.contiguous()
+        hidden_size = wo.size(-1)
+
+    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    # Skip the launch for an empty batch (the old code called .max() on an
+    # empty tensor) and keep the grid callback free of device syncs, since the
+    # autotuner may invoke it repeatedly.
+    if max_k_len > 0:
+
+        def grid_wo(META):
+            return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+        _compute_wo_v_l1_kernel[grid_wo](
+            v,
+            wo,
+            cu_seqlens,
+            wo_v_norm,
+            *v.stride(),
+            *wo.stride(),
+            *wo_v_norm.stride(),
+            Hk=Hk,
+            Hq=Hq,
+            D=D,
+            HIDDEN=hidden_size,
+            QUERY_GROUP_SIZE=G,
+        )
+
+    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
+    # Per-sequence, per-head budgets (None when the sequence has nothing to compress).
+    head_budgets_by_batch: list = []
+
+    for b in range(B):
+        lo, hi = _compressible_region(b)
+        compressible = hi - lo
+        if k_lens[b] == 0 or compressible <= 0:
+            head_budgets_by_batch.append(None)
+            continue
+
+        # Per-head token budget derived from the (token, head)-pair budget.
+        n_kept_tokens = min(max(1, keep_pairs_per_seq[b] // Hk), compressible)
+
+        # Safeguard: every head keeps at least n_safe of its best tokens.
+        n_safe = int(n_kept_tokens * alpha_safeguard)
+        if n_safe > 0:
+            tk_safe = min(n_safe, compressible)
+            for hk in range(Hk):
+                safe_idx = torch.topk(base_scores[lo:hi, hk], tk_safe, sorted=False).indices
+                stage1_mask[lo + safe_idx, hk] = 1
+
+        # Adaptive allocation: top n_kept_tokens * Hk over the flattened
+        # (token, head) region, counted per head.
+        budget_scores = base_scores[lo:hi, :].clone()
+        if n_safe > 0:
+            budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf")
+        top_pairs = min(n_kept_tokens * Hk, budget_scores.numel())
+        if top_pairs <= 0:
+            head_budgets_by_batch.append(None)
+            continue
+        top_idx_flat = torch.topk(
+            budget_scores.reshape(-1), top_pairs, sorted=False
+        ).indices
+        # Row-major [token, head] flattening: the head is the fast axis.
+        top_head_idx = top_idx_flat % Hk
+        head_budgets = (
+            torch.bincount(top_head_idx, minlength=Hk).to(torch.int32).tolist()
+        )
+        head_budgets_by_batch.append(head_budgets)
+
+        # Stage 1: protect a first_stage_ratio share of each head's budget.
+        for hk in range(Hk):
+            phase1_budget = int(head_budgets[hk] * first_stage_ratio)
+            if phase1_budget <= 0:
+                continue
+            tk = min(phase1_budget, compressible)
+            top_idx = torch.topk(base_scores[lo:hi, hk], tk, sorted=False).indices
+            stage1_mask[lo + top_idx, hk] = 1
+
+    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    def grid_fuse(_META):
+        return (B, Hk)
+
+    # Strides are passed by name. EPSILON sits between OUT and the strides in
+    # the kernel signature, so combining ``EPSILON=...`` with positional
+    # ``*strides`` bound the first stride to EPSILON (duplicate argument) and
+    # left STRIDE_OUT_HK unset.
+    _critical_ada_fuse_kernel[grid_fuse](
+        base_scores,
+        wo_v_norm,
+        stage1_mask,
+        cu_seqlens,
+        final_scores,
+        EPSILON=epsilon,
+        STRIDE_BS_NK=base_scores.stride(0),
+        STRIDE_BS_HK=base_scores.stride(1),
+        STRIDE_WN_NK=wo_v_norm.stride(0),
+        STRIDE_WN_HK=wo_v_norm.stride(1),
+        STRIDE_S1_NK=stage1_mask.stride(0),
+        STRIDE_S1_HK=stage1_mask.stride(1),
+        STRIDE_OUT_NK=final_scores.stride(0),
+        STRIDE_OUT_HK=final_scores.stride(1),
+        Hk=Hk,
+    )
+
+    # Stage 2: after fusing, protect each head's top ``budget`` entries.
+    for b in range(B):
+        hb = head_budgets_by_batch[b]
+        if hb is None:
+            continue
+        lo, hi = _compressible_region(b)
+        if hi <= lo:
+            continue
+        region_len = hi - lo
+        for hk in range(Hk):
+            budget = hb[hk]
+            if budget <= 0:
+                continue
+            tk = min(budget, region_len)
+            idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices
+            final_scores[lo + idx, hk] = float("inf")
+
+    # Select the lowest-scoring (token, head) pairs of every sequence for pruning.
+    pruned_batch, pruned_head, pruned_seq = [], [], []
+    for b in range(B):
+        k_len = k_lens[b]
+        if k_len == 0:
+            continue
+        keep_pairs = keep_pairs_per_seq[b]
+        total_pairs = k_len * Hk
+        if keep_pairs >= total_pairs:
+            continue
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+        if n_prune_pairs <= 0:
+            continue
+
+        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
+        prune_idx = torch.topk(
+            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+        ).indices
+        pruned_batch.append(torch.full_like(prune_idx, b, dtype=torch.int64))
+        pruned_head.append(prune_idx % Hk)
+        pruned_seq.append(prune_idx // Hk + k_beg)
+
+    masked_key_indices = None
+    if pruned_batch:
+        masked_key_indices = (
+            torch.cat(pruned_batch),
+            torch.cat(pruned_head),
+            torch.cat(pruned_seq),
+        )
+
+    if store_stream is not None:
+        final_scores.record_stream(store_stream)
+
+    return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """
+    CriticalAda two-stage re-weighting on top of a base scorer.
+
+    The base scores come from CompactorCompression by default, or from
+    SnapKVCompression when ``compression_context.critical_ada_base_scorer`` is
+    ``"snapkv"``. The attention layer is expected to have set
+    ``compression_context.wo_weight`` before ``post_rope_scoring`` runs; when it
+    is ``None`` the base scores are returned unchanged.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Delegate pre-RoPE scoring to the configured base scorer."""
+        cc = context.compression_context
+        scorer_name = (
+            "compactor"
+            if cc is None
+            else getattr(cc, "critical_ada_base_scorer", "compactor")
+        )
+        scorer = (
+            SnapKVCompression
+            if str(scorer_name).lower() == "snapkv"
+            else CompactorCompression
+        )
+        return scorer.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Compute base post-RoPE scores, then apply the CriticalAda re-weighting."""
+        compression_context = context.compression_context
+        assert compression_context is not None
+        scorer_name = str(
+            getattr(compression_context, "critical_ada_base_scorer", "compactor")
+        ).lower()
+
+        if scorer_name == "snapkv":
+            base_scores = SnapKVCompression.post_rope_scoring(
+                q, k, v, pre_rope_scores, context
+            )
+        else:
+            # This call is meant to stay argument-for-argument the same as
+            # CompactorCompression.post_rope_scoring in compactor.py; routing it
+            # through a different wrapper would make the scores diverge from
+            # running COMPACTOR on its own.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            compactor_kwargs = dict(
+                chunk_size=CompactorCompression.chunk_size,
+                sm_scale=1.0,
+                normalize=True,
+                accum_scores=pre_rope_scores,
+                context_lens=compression_context.context_lens,
+                protected_first_tokens=compression_context.protected_first_tokens,
+                protected_last_tokens=compression_context.protected_last_tokens,
+                accum_blending=0.5,
+            )
+            base_scores = maybe_execute_in_stream(
+                non_causal_attn_scores,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                context.max_seqlen_q,
+                **compactor_kwargs,
+            )
+
+        wo_weight = compression_context.wo_weight
+        if wo_weight is None:
+            # No output projection was injected: fall back to the base scores.
+            return base_scores
+
+        # The pruned-index tuple returned alongside the scores is not used here.
+        scores, _pruned = maybe_execute_in_stream(
+            critical_ada_key_scores,
+            q,
+            k,
+            v,
+            wo_weight,
+            context.cu_seqlens_q,
+            base_scores,
+            compression_context,
+            STORE_STREAM=context.STORE_STREAM,
+            store_stream=context.STORE_STREAM,
+        )
+        return scores
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """Optionally cache Wo, reshaped per head, on ``module._critical_ada_wo_weight``.
+
+        Does nothing when the module lacks ``o_proj`` (or its weight is ``None``),
+        ``num_heads`` or ``head_dim``. The cached tensor is always float32;
+        ``dtype`` is not used.
+        """
+        if not (
+            hasattr(module, "o_proj")
+            and module.o_proj.weight is not None
+            and hasattr(module, "num_heads")
+            and hasattr(module, "head_dim")
+        ):
+            return
+        weight = module.o_proj.weight.data
+        hidden_size, _in_features = weight.shape
+        # [hidden, Hq * D] -> [Hq * D, hidden] -> [Hq, D, hidden]
+        per_head = weight.transpose(0, 1).view(
+            module.num_heads, module.head_dim, hidden_size
+        )
+        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv.py b/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..94ba718111580a9b04dd9521121eaf024356815c
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv.py
@@ -0,0 +1,451 @@
+"""
+CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
+用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
+
+预算与 vllm.kvprune 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
+(token, head) 对数量)。CriticalAda 主链在 **PyTorch** 中与 kvpress ``CriticalAdaKVPress.compress``
+对齐;``||Wo@V||_1`` 仍默认用 Triton ``_compute_wo_v_l1_kernel``(与 ``CriticalKVPress.vwl1norm`` 同式)。
+将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 可改走 ``_vwl1_norm_kvpress_reference``。
+
+注意:不得在 import 时加载 ``vllm.kvprune.utils.context``(其会再 import ``CompressionMethod``,
+与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.compression.compactor import (
+ CompactorCompression,
+ kvpress_compactor_post_rope,
+ resolve_kvpress_compactor_blending,
+)
+from compactor_vllm.compression.snapkv import SnapKVCompression
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+# Wo@V 的 L1:False = Triton(默认),True = PyTorch 参考(调试/对齐)
+_USE_WO_L1_REFERENCE_BACKEND = False
+
+
+def _vwl1_norm_kvpress_reference(
+    values_seg: torch.Tensor,
+    wo: torch.Tensor,
+    num_kv_heads: int,
+    num_query_groups: int,
+) -> torch.Tensor:
+    """
+    Optional PyTorch reference for the ``||Wo @ V||_1`` term (used for cross-checking only).
+
+    Selected when ``_USE_WO_L1_REFERENCE_BACKEND`` is ``True``; the Triton kernel
+    is the default path.
+
+    Algorithm: expand the KV heads with ``repeat_kv``, compute ``|V @ Wo_h|_1``
+    for each query head, then average over each GQA group.
+
+    ``values_seg`` is ``[k_len, Hk, D]`` and ``wo`` is ``[Hq, D, hidden]``; the
+    result is a contiguous ``[k_len, Hk]`` tensor.
+    """
+    seq_len, n_kv, head_dim = values_seg.shape
+    n_q, wo_head_dim, _hidden = wo.shape
+    assert (
+        head_dim == wo_head_dim
+        and n_kv == num_kv_heads
+        and n_q == n_kv * num_query_groups
+    )
+    # repeat_kv takes a [batch, heads, seq, dim] layout; use a batch of one.
+    expanded = repeat_kv(
+        values_seg.permute(1, 0, 2).unsqueeze(0).contiguous(), num_query_groups
+    )  # [1, Hq, k_len, D]
+    # V is cast to Wo's dtype before the matmul so mixed-precision inputs line up.
+    per_head_l1 = [
+        torch.norm(
+            expanded[0, h, :, :].to(dtype=wo.dtype).matmul(wo[h, :, :]), p=1, dim=-1
+        )
+        for h in range(n_q)
+    ]
+    grouped = torch.stack(per_head_l1, dim=0).view(n_kv, num_query_groups, seq_len)
+    return grouped.mean(dim=1).transpose(0, 1).contiguous()
+
+
+# ============================================================================
+# Triton:||Wo @ V||₁ 按 kvpress 定义(GQA 上对 query 组 L1 后取均值)
+# ============================================================================
+@triton_autotune(
+    configs=[
+        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+        for bk in [32, 64, 128]
+        for bd in [32, 64]
+        for nw in [4, 8]
+        for ns in [3, 4]
+    ],
+    key=["Hk", "D", "HIDDEN"],
+    cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+    V,
+    WO,
+    cu_k,
+    OUT,
+    STRIDE_V_NK,
+    STRIDE_V_HK,
+    STRIDE_V_D,
+    STRIDE_WO_HQ,
+    STRIDE_WO_D,
+    STRIDE_WO_HID,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    Hk: tl.constexpr,
+    Hq: tl.constexpr,
+    D: tl.constexpr,
+    HIDDEN: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Per KV head: ``sum(|V @ Wo|)`` for each of the G query heads, divided by G.
+
+    The launch grid is (sequence, KV head, token tile). The result for each
+    (token, KV head) pair is stored to ``OUT``. ``Hk`` and ``Hq`` are not read
+    in the body (``Hk`` is an autotune key).
+    """
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+    ks = tl.program_id(2)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    # The grid's tile axis is sized for the longest sequence, so tiles that lie
+    # entirely past the end of a shorter sequence have nothing to compute.
+    tile_beg = k_beg + ks * BLOCK_K
+    if tile_beg >= k_end:
+        return
+
+    nk = tile_beg + tl.arange(0, BLOCK_K)
+    k_mask = nk < k_end
+    offs_d = tl.arange(0, D)
+
+    # The V tile depends only on (token, KV head), so load it once and reuse
+    # it for every query head of the group instead of reloading per head.
+    v_ptrs = (
+        V
+        + nk[:, None] * STRIDE_V_NK
+        + hk * STRIDE_V_HK
+        + offs_d[None, :] * STRIDE_V_D
+    )
+    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+
+    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+    for g in range(QUERY_GROUP_SIZE):
+        hq = hk * QUERY_GROUP_SIZE + g
+
+        # Sweep the hidden axis in BLOCK_D-wide column tiles of Wo[hq].
+        for hid_off in range(0, HIDDEN, BLOCK_D):
+            hid_idx = hid_off + tl.arange(0, BLOCK_D)
+            hid_mask = hid_idx < HIDDEN
+
+            wo_ptrs = (
+                WO
+                + hq * STRIDE_WO_HQ
+                + offs_d[:, None] * STRIDE_WO_D
+                + hid_idx[None, :] * STRIDE_WO_HID
+            )
+            # Columns past HIDDEN load as 0 and therefore add nothing to the L1 sum.
+            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+            wov_tile = tl.dot(v_blk, wo_tile)
+            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
+
+    l1_sum = l1_sum / QUERY_GROUP_SIZE
+    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+def critical_ada_key_scores(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ wo_weight: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ base_scores: torch.Tensor,
+ compression_ctx: Any,
+ *,
+ store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+ """
+ 使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
+ 按 kvpress ``CriticalAdaKVPress.compress`` 的顺序实现:safeguard scatter →
+ head-major 展平做 head_budgets → Stage1 在 **已抬高** 的分数上 top-k →
+ ``(scores + ε) * ||WoV||₁`` → Stage2 scatter → 最终按 head-major 展平做 bottom-k。
+
+ ``||Wo@V||₁`` 仍用 Triton(``_compute_wo_v_l1_kernel``);中间 CriticalAda 步骤用 PyTorch
+ 与 kvpress 逐句对齐。仅 base 分数来自 Compactor/SnapKV。
+
+ Args:
+ compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
+ ``batch_tokens_to_retain``;可选 ``critical_ada_epsilon``、
+ ``critical_ada_first_stage_ratio``、``critical_ada_alpha_safeguard``。
+ """
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+ device = q.device
+ _, Hq, D = q.shape
+ N_k, Hk, Dk = k.shape
+ assert D == Dk and Hq % Hk == 0
+
+ # 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
+ # 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
+ B = cu_seqlens.numel() - 1
+ G = Hq // Hk
+ k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+ btr = compression_ctx.batch_tokens_to_retain
+ assert btr is not None and btr.numel() == B
+ btr = btr.to(device=device, dtype=torch.int32)
+
+ epsilon = compression_ctx.critical_ada_epsilon
+ first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+ alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
+ alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
+
+ if wo_weight.dim() == 2:
+ hidden_size, _ = wo_weight.shape
+ wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+ else:
+ wo = wo_weight.contiguous()
+ hidden_size = wo.size(-1)
+
+ wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+ if B > 0 and int(k_lengths.max().item()) > 0:
+ if _USE_WO_L1_REFERENCE_BACKEND:
+ for b in range(B):
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ if k_end <= k_beg:
+ continue
+ v_seg = v[k_beg:k_end, :, :].contiguous()
+ wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
+ v_seg, wo, Hk, G
+ )
+ else:
+
+ def grid_wo(META):
+ max_k_len = int(k_lengths.max().item())
+ return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+ _compute_wo_v_l1_kernel[grid_wo](
+ v,
+ wo,
+ cu_seqlens,
+ wo_v_norm,
+ *v.stride(),
+ *wo.stride(),
+ *wo_v_norm.stride(),
+ Hk=Hk,
+ Hq=Hq,
+ D=D,
+ HIDDEN=hidden_size,
+ QUERY_GROUP_SIZE=G,
+ )
+
+ # kvpress 用 finfo.max 抬高分数;与 inf 混用时 topk 行为一致
+ _score_max = float(torch.finfo(torch.float32).max)
+
+ final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+ head_budgets_by_batch: list[Optional[torch.Tensor]] = []
+
+ for b in range(B):
+ k_len = int(k_lengths[b].item())
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ if k_len == 0:
+ head_budgets_by_batch.append(None)
+ continue
+
+ scores_seg = base_scores[k_beg:k_end, :].float()
+ keep_pairs = int(btr[b].item())
+ n_kept_tokens = max(1, keep_pairs // Hk)
+ n_kept_tokens = min(n_kept_tokens, k_len)
+
+ # scores_work: 布局 [k_len, Hk],对应 kvpress [bsz=1, H, k_len] 的 transpose(0,2) 视角下沿 token 维的 topk
+ scores_work = scores_seg.clone()
+
+ # --- Alpha safeguard(kvpress L148–152)---
+ n_safe = int(n_kept_tokens * alpha_safeguard)
+ nk = min(n_safe, k_len) if n_safe > 0 else 0
+ if nk > 0:
+ for hk in range(Hk):
+ top_idx = torch.topk(scores_work[:, hk], nk, dim=0, largest=True).indices
+ scores_work[top_idx, hk] = _score_max
+
+ # --- Head budgets:kvpress L158–164,展平顺序与 [bsz, H, k_len] 一致(head-major:h*K + t)---
+ top_pairs = min(n_kept_tokens * Hk, k_len * Hk)
+ if top_pairs <= 0:
+ head_budgets_by_batch.append(None)
+ wn = wo_v_norm[k_beg:k_end, :]
+ final_scores[k_beg:k_end, :] = (scores_seg + epsilon) * wn
+ continue
+
+ budget_flat = scores_work.permute(1, 0).contiguous().reshape(-1)
+ top_idx_flat = torch.topk(
+ budget_flat, top_pairs, largest=True, sorted=False
+ ).indices
+ top_head_idx = top_idx_flat // k_len
+ head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int64)
+ head_budgets_by_batch.append(head_budgets)
+
+ # --- Stage 1(kvpress L166–171):在已 safeguard 的 scores_work 上沿 token 维 top-k ---
+ head_selection_budget_1st = (
+ (head_budgets.to(torch.float32) * float(first_stage_ratio))
+ .to(torch.int64)
+ .tolist()
+ )
+ M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
+ mk = min(M1, k_len) if M1 > 0 else 0
+ if mk > 0:
+ top_k_index = torch.topk(scores_work, mk, dim=0, largest=True, sorted=True).indices
+ for hk in range(Hk):
+ phase1_budget = int(head_selection_budget_1st[hk])
+ if phase1_budget <= 0:
+ continue
+ take = min(phase1_budget, mk)
+ scores_work[top_k_index[:take, hk], hk] = _score_max
+
+ # --- Stage 2 重加权(kvpress L173–175)---
+ wn = wo_v_norm[k_beg:k_end, :]
+ scores_fused = (scores_work + epsilon) * wn
+
+ # --- Stage 2 scatter(kvpress L176–179)---
+ M2 = int(head_budgets.max().item())
+ mk2 = min(M2, k_len) if M2 > 0 else 0
+ if mk2 > 0:
+ top_k_index2 = torch.topk(
+ scores_fused, mk2, dim=0, largest=True, sorted=True
+ ).indices
+ for hk in range(Hk):
+ budget = int(head_budgets[hk].item())
+ if budget <= 0:
+ continue
+ take = min(budget, mk2)
+ scores_fused[top_k_index2[:take, hk], hk] = _score_max
+
+ final_scores[k_beg:k_end, :] = scores_fused
+
+ masked_key_indices = None
+ for b in range(B):
+ k_len = int(k_lengths[b].item())
+ if k_len == 0:
+ continue
+ keep_pairs = int(btr[b].item())
+ total_pairs = k_len * Hk
+ if keep_pairs >= total_pairs:
+ continue
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+ if n_prune_pairs <= 0:
+ continue
+
+ # kvpress L187:``scores.reshape(bsz, -1)`` 即 [H, K] 按 head-major 展平(flat = h*K + t)
+ flat_scores = (
+ final_scores[k_beg:k_end, :].permute(1, 0).contiguous().reshape(-1)
+ )
+ prune_idx = torch.topk(
+ -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+ ).indices
+ batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
+ head_idx = prune_idx // k_len
+ seq_idx = prune_idx % k_len + k_beg
+ if masked_key_indices is None:
+ masked_key_indices = (batch_idx, head_idx, seq_idx)
+ else:
+ masked_key_indices = (
+ torch.cat([masked_key_indices[0], batch_idx]),
+ torch.cat([masked_key_indices[1], head_idx]),
+ torch.cat([masked_key_indices[2], seq_idx]),
+ )
+
+ if store_stream is not None:
+ final_scores.record_stream(store_stream)
+
+ return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """
+    CriticalAda scoring layered on top of a selectable base scorer.
+
+    The base is read from ``compression_context.critical_ada_base_scorer`` (default
+    ``"compactor"``). Only the ``"compactor"`` base goes through ``kvpress_compactor_post_rope``
+    (``blending * l_scores + attn_scores``, intended to match kvpress ``CompactorPress.score``);
+    any other value (e.g. SnapKV) is routed to ``SnapKVCompression``. The CriticalAda stage runs
+    only when ``compression_context.wo_weight`` is set, so Attention must inject it before
+    post-RoPE scoring.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Delegate pre-RoPE scoring to the configured base scorer."""
+        cc = context.compression_context
+        if cc is None:
+            scorer_name = "compactor"
+        else:
+            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
+        if str(scorer_name).lower() != "compactor":
+            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
+        return CompactorCompression.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """
+        Compute base scores, then apply the CriticalAda stage when ``wo_weight`` is available.
+
+        Returns the base scores unchanged if ``compression_context.wo_weight`` is ``None``;
+        otherwise returns the first element of what ``critical_ada_key_scores`` produces.
+        """
+        cc = context.compression_context
+        assert cc is not None
+        scorer_name = str(getattr(cc, "critical_ada_base_scorer", "compactor")).lower()
+
+        if scorer_name != "compactor":
+            base_scores = SnapKVCompression.post_rope_scoring(
+                q, k, v, pre_rope_scores, context
+            )
+        else:
+            # Special case: mirrors ``CompactorPress.score`` /
+            # ``CompactorCompression.post_rope_scoring``.
+            # The current stream waits on the store stream before scoring.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            blend = float(resolve_kvpress_compactor_blending(cc))
+            base_scores = maybe_execute_in_stream(
+                kvpress_compactor_post_rope,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                pre_rope_scores,
+                cc,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                blending=blend,
+                STORE_STREAM=context.STORE_STREAM,
+            )
+
+        projection = cc.wo_weight
+        if projection is None:
+            # Without Wo there is nothing to re-weight with; fall back to the base scores.
+            return base_scores
+
+        critical_scores, _pruned = maybe_execute_in_stream(
+            critical_ada_key_scores,
+            q,
+            k,
+            v,
+            projection,
+            context.cu_seqlens_q,
+            base_scores,
+            cc,
+            STORE_STREAM=context.STORE_STREAM,
+            store_stream=context.STORE_STREAM,
+        )
+        return critical_scores
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """
+        Optionally precompute and cache Wo on ``module._critical_ada_wo_weight``.
+
+        Inference itself relies on the ``cc.wo_weight`` injected in ``Attention.forward``.
+        Does nothing when the module lacks ``o_proj`` (or its weight), ``num_heads`` or
+        ``head_dim``. ``dtype`` is accepted but unused: the cache is always float32.
+        """
+        if not (hasattr(module, "o_proj") and module.o_proj.weight is not None):
+            return
+        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
+            return
+        raw = module.o_proj.weight.data
+        hidden_size, _in_features = raw.shape
+        # [hidden, Hq * head_dim] -> [Hq, head_dim, hidden]
+        per_head = raw.transpose(0, 1).view(module.num_heads, module.head_dim, hidden_size)
+        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
+
+
+
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv_origin.py b/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5964c95908ddd2529c97e5cb617ec11ebffa878
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv_origin.py
@@ -0,0 +1,502 @@
+"""
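+CriticalAdaKV: on top of the Compactor scores (pre-RoPE leverage scores fused with post-RoPE
+non-causal attention), Stage 2 re-weights by the L1 norm of the output projection Wo applied to
+Value; Stage 1 protects the within-budget top-k of the Compactor base scores.
+
+The budget matches the compactor_vllm engine: ``compression_context.batch_tokens_to_retain`` (the
+number of flattened (token, head) pairs). Stages 1/2 follow the kvpress paper/implementation;
+``||Wo@V||_1`` is **algorithmically** the same as ``CriticalKVPress.vwl1norm`` (per-query-head L1
+under GQA, then the mean over the group). **Triton is the default**
+(``_compute_wo_v_l1_kernel``); to align line by line with PyTorch, set the module-level
+``_USE_WO_L1_REFERENCE_BACKEND`` to ``True``, which selects ``_vwl1_norm_kvpress_reference``.
+
+Note: do not load ``compactor_vllm.utils.context`` at import time (it imports ``CompressionMethod``
+again, which forms a cycle with ``compression/__init__.py`` importing this module). At runtime only
+duck-typed objects with the same fields as ``CompressionContext`` are used.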
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.compression.compactor import (
+ CompactorCompression,
+ non_causal_attn_scores,
+)
+from compactor_vllm.compression.snapkv import SnapKVCompression
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+# Backend for the Wo@V L1 norm: False = Triton kernel (default),
+# True = PyTorch reference implementation (debugging / parity checks).
+_USE_WO_L1_REFERENCE_BACKEND = False
+
+
+def _vwl1_norm_kvpress_reference(
+    values_seg: torch.Tensor,
+    wo: torch.Tensor,
+    num_kv_heads: int,
+    num_query_groups: int,
+) -> torch.Tensor:
+    """
+    Optional PyTorch reference for kvpress ``CriticalKVPress.vwl1norm`` (parity checks only).
+
+    Selected when ``_USE_WO_L1_REFERENCE_BACKEND`` is ``True``; the Triton path is the default.
+    Algorithm: ``repeat_kv`` -> per-query-head ``|V @ Wo_h|_1`` -> mean over each GQA group.
+
+    Args:
+        values_seg: values of one sequence, unpacked as ``[k_len, Hk, D]``.
+        wo: output projection, unpacked as ``[Hq, D, hidden]``.
+        num_kv_heads: must equal ``Hk``.
+        num_query_groups: query heads per KV head; ``Hq`` must equal ``Hk * num_query_groups``.
+
+    Returns:
+        Contiguous ``[k_len, Hk]`` tensor of group-averaged L1 norms.
+    """
+    seq_len, n_kv, head_dim = values_seg.shape
+    n_q, wo_head_dim, _hidden = wo.shape
+    assert head_dim == wo_head_dim and n_kv == num_kv_heads and n_q == n_kv * num_query_groups
+    # [1, Hk, k_len, D], the layout HF ``repeat_kv`` expects; result is [1, Hq, k_len, D].
+    expanded = repeat_kv(
+        values_seg.permute(1, 0, 2).unsqueeze(0).contiguous(), num_query_groups
+    )
+    # V is cast to Wo's dtype before the matmul (Wo is injected as float32 by attention,
+    # while V is commonly bf16/fp16).
+    per_head_l1 = [
+        torch.norm(
+            expanded[0, head, :, :].to(dtype=wo.dtype).matmul(wo[head, :, :]),
+            p=1,
+            dim=-1,
+        )
+        for head in range(n_q)
+    ]
+    # [Hq, k_len] -> [Hk, G, k_len] -> mean over the group -> [Hk, k_len]
+    grouped = torch.stack(per_head_l1, dim=0).view(n_kv, num_query_groups, seq_len)
+    return grouped.mean(dim=1).transpose(0, 1).contiguous()
+
+
+
+
+# ============================================================================
+# Triton: ||Wo @ V||_1 as defined by kvpress (per-query-head L1 under GQA, then the group mean)
+# ============================================================================
+@triton_autotune(
+    configs=[
+        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+        for bk in [32, 64, 128]
+        for bd in [32, 64]
+        for nw in [4, 8]
+        for ns in [3, 4]
+    ],
+    key=["Hk", "D", "HIDDEN"],
+    cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+    V,
+    WO,
+    cu_k,
+    OUT,
+    STRIDE_V_NK,
+    STRIDE_V_HK,
+    STRIDE_V_D,
+    STRIDE_WO_HQ,
+    STRIDE_WO_D,
+    STRIDE_WO_HID,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    Hk: tl.constexpr,
+    Hq: tl.constexpr,
+    D: tl.constexpr,
+    HIDDEN: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """
+    For each KV head: compute ``sum(|V @ Wo|)`` for each of its G query heads, then divide by G
+    (the kvpress mean).
+
+    Grid: axis 0 selects the sequence ``b`` (bounds ``cu_k[b]`` / ``cu_k[b + 1]``), axis 1 the KV
+    head, axis 2 a ``BLOCK_K``-token tile inside the sequence. V is addressed as
+    ``[token, kv_head, d]``, WO as ``[query_head, d, hidden]`` and OUT as ``[token, kv_head]``
+    through the stride arguments. ``BLOCK_D`` is the tile width along HIDDEN, not along D.
+    """
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+    ks = tl.program_id(2)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    # Packed token rows handled by this program; rows at or past k_end are masked out.
+    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
+    nk = k_beg + nk_off
+    k_mask = nk < k_end
+
+    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+    for g in range(QUERY_GROUP_SIZE):
+        # Query head g of the group that shares KV head hk.
+        hq = hk * QUERY_GROUP_SIZE + g
+
+        # The V tile depends only on (nk, hk), not on g; it is recomputed per group member.
+        # NOTE(review): tl.arange(0, D) presumably requires D to be a power of two — confirm.
+        v_ptrs = (
+            V
+            + nk[:, None] * STRIDE_V_NK
+            + hk * STRIDE_V_HK
+            + tl.arange(0, D)[None, :] * STRIDE_V_D
+        )
+        v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+
+        # Walk the hidden dimension in BLOCK_D-wide tiles; columns past HIDDEN load as 0.0,
+        # so they contribute nothing to the absolute sum.
+        for hid_off in range(0, HIDDEN, BLOCK_D):
+            hid_idx = hid_off + tl.arange(0, BLOCK_D)
+            hid_mask = hid_idx < HIDDEN
+
+            wo_ptrs = (
+                WO
+                + hq * STRIDE_WO_HQ
+                + tl.arange(0, D)[:, None] * STRIDE_WO_D
+                + hid_idx[None, :] * STRIDE_WO_HID
+            )
+            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+            # [BLOCK_K, D] @ [D, BLOCK_D]; accumulate the row-wise L1 of the product.
+            wov_tile = tl.dot(v_blk, wo_tile)
+            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
+
+    # Mean over the query group.
+    l1_sum = l1_sum / QUERY_GROUP_SIZE
+    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+
+
+# ============================================================================
+# Triton: Stage 1 protection + Stage 2 weighted fusion (element-wise)
+# ============================================================================
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
+    key=["Hk"],
+    cache_results=True,
+)
+@triton.jit
+def _critical_ada_fuse_kernel(
+    BASE_SCORES,
+    WO_V_NORM,
+    STAGE1_MASK,
+    cu_k,
+    OUT,
+    STRIDE_BS_NK,
+    STRIDE_BS_HK,
+    STRIDE_WN_NK,
+    STRIDE_WN_HK,
+    STRIDE_S1_NK,
+    STRIDE_S1_HK,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    EPSILON: tl.constexpr,
+    Hk: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    # Element-wise fusion: OUT = (BASE_SCORES + EPSILON) * WO_V_NORM, with every position whose
+    # STAGE1_MASK entry equals 1 overwritten by +inf. All four buffers are addressed as
+    # [token, kv_head] through their stride pairs.
+    # Grid: axis 0 = sequence b (bounds from cu_k), axis 1 = KV head; each program walks its
+    # sequence's tokens in BLOCK_K-sized tiles.
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    for ks in tl.range(k_beg, k_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_end
+
+        bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
+        wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
+        s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
+
+        # Masked-off lanes get neutral values; they are never stored (same kmask below).
+        base = tl.load(bs_ptrs, mask=kmask, other=0.0)
+        wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
+        stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
+
+        fused = (base + EPSILON) * wnorm
+        # Stage-1 protected entries are forced to +inf.
+        fused = tl.where(stage1_protect == 1, float("inf"), fused)
+
+        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+        tl.store(out_ptrs, fused, mask=kmask)
+
+
+
+
+def critical_ada_key_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    wo_weight: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    base_scores: torch.Tensor,
+    compression_ctx: Any,
+    *,
+    store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+    """
+    Apply the two-stage CriticalAda re-scoring to ``base_scores``, one sequence at a time.
+
+    Uses the engine's retention budget ``batch_tokens_to_retain`` (number of (token, head) pairs
+    per sequence) and is intended to follow kvpress ``CriticalAdaKVPress.compress`` over the whole
+    ``k_len`` of each sequence (same top-k / scatter order); only the base scores come from
+    compactor_vllm's Compactor/SnapKV.
+
+    Args:
+        q: queries, unpacked as ``[N_q, Hq, D]``; only its shape and device are used.
+        k: keys, unpacked as ``[N_k, Hk, D]``; only its shape is used.
+        v: values, indexed as ``[token, kv_head, d]``.
+        wo_weight: output projection, either 2-D (reshaped to ``[Hq, D, hidden]``) or already 3-D.
+        cu_seqlens: cumulative sequence boundaries (``B + 1`` entries). Must be the same
+            boundaries the base scores were computed with (``context.cu_seqlens_q`` in prefill);
+            do not mix with ``cu_seqlens_k``.
+        base_scores: base scores indexed as ``[token, kv_head]``.
+        compression_ctx: any object with the ``CompressionContext`` fields used here (duck
+            typing): ``batch_tokens_to_retain``, ``critical_ada_epsilon``,
+            ``critical_ada_first_stage_ratio`` and ``critical_ada_alpha_safeguard``.
+        store_stream: if given, ``final_scores.record_stream(store_stream)`` is called.
+
+    Returns:
+        ``(final_scores, masked_key_indices)``. ``final_scores`` is float32 ``[N_k, Hk]`` with
+        protected entries set to ``+inf``. ``masked_key_indices`` is ``None`` when nothing is
+        pruned, otherwise ``(batch_idx, head_idx, seq_idx)`` of the lowest-scoring pairs, where
+        ``seq_idx`` is a row index into the packed token dimension.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk and Hq % Hk == 0
+
+    B = cu_seqlens.numel() - 1
+    G = Hq // Hk
+
+    btr = compression_ctx.batch_tokens_to_retain
+    assert btr is not None and btr.numel() == B
+
+    # Read all boundaries and budgets back in one host transfer each, instead of one
+    # ``.item()`` synchronisation per element inside the loops below.
+    bounds = cu_seqlens.tolist()
+    segments = list(zip(bounds[:-1], bounds[1:]))
+    keep_pairs_by_batch = btr.reshape(-1).to(dtype=torch.int32).tolist()
+    max_k_len = max((k_end - k_beg for k_beg, k_end in segments), default=0)
+
+    epsilon = compression_ctx.critical_ada_epsilon
+    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+    alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
+    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
+
+    if wo_weight.dim() == 2:
+        # [hidden, Hq * D] -> [Hq, D, hidden]
+        hidden_size, _ = wo_weight.shape
+        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+    else:
+        wo = wo_weight.contiguous()
+        hidden_size = wo.size(-1)
+
+    # ||Wo @ V||_1 per (token, kv_head); only rows inside non-empty sequences are written.
+    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+    if max_k_len > 0:
+        if _USE_WO_L1_REFERENCE_BACKEND:
+            for k_beg, k_end in segments:
+                if k_end <= k_beg:
+                    continue
+                v_seg = v[k_beg:k_end, :, :].contiguous()
+                wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
+                    v_seg, wo, Hk, G
+                )
+        else:
+
+            def grid_wo(META):
+                return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+            _compute_wo_v_l1_kernel[grid_wo](
+                v,
+                wo,
+                cu_seqlens,
+                wo_v_norm,
+                *v.stride(),
+                *wo.stride(),
+                *wo_v_norm.stride(),
+                Hk=Hk,
+                Hq=Hq,
+                D=D,
+                HIDDEN=hidden_size,
+                QUERY_GROUP_SIZE=G,
+            )
+
+    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
+    # Per-sequence list of per-head budgets (plain ints), or None when the sequence is skipped.
+    head_budgets_by_batch: list[Optional[list[int]]] = []
+
+    for (k_beg, k_end), keep_pairs in zip(segments, keep_pairs_by_batch):
+        k_len = k_end - k_beg
+        if k_len == 0:
+            head_budgets_by_batch.append(None)
+            continue
+        scores_seg = base_scores[k_beg:k_end, :]
+        # Same as kvpress n_kept: tokens kept per head, at least 1 and at most k_len.
+        n_kept_tokens = max(1, keep_pairs // Hk)
+        n_kept_tokens = min(n_kept_tokens, k_len)
+
+        # kvpress: top-k indices are taken on the unmodified scores; the scatter only touches a
+        # copy that feeds the head budgets. Stage 1 below still uses the original scores_seg.
+        working = scores_seg.clone()
+        n_safe = int(n_kept_tokens * alpha_safeguard)
+        if n_safe > 0:
+            nk = min(n_safe, k_len)
+            for hk in range(Hk):
+                top_idx = torch.topk(scores_seg[:, hk], nk, sorted=True).indices
+                working[:, hk].scatter_(0, top_idx, float("inf"))
+
+        top_pairs = min(n_kept_tokens * Hk, working.numel())
+        if top_pairs <= 0:
+            head_budgets_by_batch.append(None)
+            continue
+        # ``working`` is flattened token-major ([k_len, Hk]), so ``flat % Hk`` is the head.
+        top_idx_flat = torch.topk(working.reshape(-1), top_pairs, sorted=False).indices
+        top_head_idx = top_idx_flat % Hk
+        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
+        head_budgets_by_batch.append(head_budgets.tolist())
+
+        # Stage 1 (as in kvpress): top-k with sorted=True up to the largest per-head budget,
+        # then each head keeps only its own first ``phase1_budget`` indices.
+        head_selection_budget_1st = (
+            (head_budgets.to(torch.float32) * float(first_stage_ratio))
+            .to(torch.int64)
+            .tolist()
+        )
+        M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
+        if M1 > 0:
+            mk = min(M1, k_len)
+            for hk, phase1_budget in enumerate(head_selection_budget_1st):
+                if phase1_budget <= 0:
+                    continue
+                full_idx = torch.topk(scores_seg[:, hk], mk, sorted=True).indices
+                take = min(phase1_budget, mk)
+                stage1_mask[k_beg + full_idx[:take], hk] = 1
+
+    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    def grid_fuse(_META):
+        return (B, Hk)
+
+    # final = (base + epsilon) * ||Wo @ V||_1, with Stage-1 protected entries set to +inf.
+    _critical_ada_fuse_kernel[grid_fuse](
+        base_scores,
+        wo_v_norm,
+        stage1_mask,
+        cu_seqlens,
+        final_scores,
+        *base_scores.stride(),
+        *wo_v_norm.stride(),
+        *stage1_mask.stride(),
+        *final_scores.stride(),
+        Hk=Hk,
+        EPSILON=float(epsilon),
+    )
+
+    # Stage 2 (kvpress): top-k on the fused scores up to the largest budget, then each head sets
+    # its own first ``budget`` indices to +inf.
+    for (k_beg, k_end), budgets in zip(segments, head_budgets_by_batch):
+        if budgets is None:
+            continue
+        k_len = k_end - k_beg
+        if k_len <= 0:
+            continue
+        fused_seg = final_scores[k_beg:k_end, :]
+        M2 = max(budgets)
+        if M2 <= 0:
+            continue
+        mk = min(M2, k_len)
+        for hk, budget in enumerate(budgets):
+            if budget <= 0:
+                continue
+            # fused_seg is a view of final_scores; the write below only touches column hk,
+            # so it cannot affect the top-k of the heads processed later.
+            full_idx = torch.topk(fused_seg[:, hk], mk, sorted=True).indices
+            take = min(budget, mk)
+            final_scores[k_beg + full_idx[:take], hk] = float("inf")
+
+    # Collect the lowest-scoring (token, head) pairs of every sequence that exceeds its budget.
+    batch_parts: list[torch.Tensor] = []
+    head_parts: list[torch.Tensor] = []
+    seq_parts: list[torch.Tensor] = []
+    for b, ((k_beg, k_end), keep_pairs) in enumerate(
+        zip(segments, keep_pairs_by_batch)
+    ):
+        k_len = k_end - k_beg
+        if k_len == 0:
+            continue
+        total_pairs = k_len * Hk
+        if keep_pairs >= total_pairs:
+            continue
+        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+        if n_prune_pairs <= 0:
+            continue
+
+        # Token-major flattening: flat = t * Hk + h.
+        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
+        prune_idx = torch.topk(
+            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+        ).indices
+        batch_parts.append(torch.full_like(prune_idx, b, dtype=torch.int64))
+        head_parts.append(prune_idx % Hk)
+        seq_parts.append(prune_idx // Hk + k_beg)
+
+    masked_key_indices = None
+    if batch_parts:
+        # One concatenation at the end instead of re-concatenating after every sequence.
+        masked_key_indices = (
+            torch.cat(batch_parts),
+            torch.cat(head_parts),
+            torch.cat(seq_parts),
+        )
+
+    if store_stream is not None:
+        final_scores.record_stream(store_stream)
+
+    return final_scores, masked_key_indices
+
+
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """
+    CriticalAda two-stage weighting on top of a base scorer.
+
+    The base defaults to ``CompactorCompression`` (pre-RoPE leverage scores fused with post-RoPE
+    non-causal attention); setting ``compression_context.critical_ada_base_scorer`` to
+    ``"snapkv"`` selects ``SnapKVCompression`` instead. Attention must inject
+    ``compression_context.wo_weight`` before post-RoPE scoring, otherwise only the base scores
+    are returned.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Delegate pre-RoPE scoring to the configured base scorer (default: Compactor)."""
+        cc = context.compression_context
+        # The fallback must match ``post_rope_scoring`` ("compactor"). A different fallback
+        # here would skip the pre-RoPE leverage scores while the post-RoPE step still took
+        # the Compactor path with ``accum_scores=None``.
+        base = getattr(cc, "critical_ada_base_scorer", "compactor") if cc is not None else "compactor"
+        if str(base).lower() == "snapkv":
+            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
+        return CompactorCompression.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """
+        Compute base scores, then apply the CriticalAda stage when ``wo_weight`` is available.
+
+        Returns the base scores unchanged if ``compression_context.wo_weight`` is ``None``;
+        otherwise returns the scores produced by ``critical_ada_key_scores``.
+        """
+        compression_context = context.compression_context
+        assert compression_context is not None
+        base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower()
+
+        if base == "snapkv":
+            base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
+        else:
+            # Kept verbatim in line with CompactorCompression.post_rope_scoring in compactor.py:
+            # maybe_execute_in_stream(non_causal_attn_scores, q, k, v, cu_seqlens_q, max_seqlen_q, ...)
+            # Do not swap in another wrapper, or the scores would differ from plain COMPACTOR.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            base_scores = maybe_execute_in_stream(
+                non_causal_attn_scores,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                sm_scale=1.0,
+                normalize=True,
+                accum_scores=pre_rope_scores,
+                context_lens=compression_context.context_lens,
+                protected_first_tokens=compression_context.protected_first_tokens,
+                protected_last_tokens=compression_context.protected_last_tokens,
+                accum_blending=0.5,
+            )
+
+        wo_weight = compression_context.wo_weight
+        if wo_weight is None:
+            # Without Wo there is nothing to re-weight with; fall back to the base scores.
+            return base_scores
+
+        scores, _masked = maybe_execute_in_stream(
+            critical_ada_key_scores,
+            q,
+            k,
+            v,
+            wo_weight,
+            context.cu_seqlens_q,
+            base_scores,
+            compression_context,
+            STORE_STREAM=context.STORE_STREAM,
+            store_stream=context.STORE_STREAM,
+        )
+        return scores
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """
+        Optionally precompute and cache Wo on ``module._critical_ada_wo_weight``.
+
+        Inference itself relies on the ``cc.wo_weight`` injected in ``Attention.forward``.
+        Does nothing when the module lacks ``o_proj`` (or its weight), ``num_heads`` or
+        ``head_dim``. ``dtype`` is accepted but unused: the cache is always float32.
+        """
+        if not hasattr(module, "o_proj") or module.o_proj.weight is None:
+            return
+        if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
+            return
+        wo_raw = module.o_proj.weight.data
+        hidden_size, _ = wo_raw.shape
+        Hq = module.num_heads
+        head_dim = module.head_dim
+        # [hidden, Hq * head_dim] -> [Hq, head_dim, hidden]
+        wo = (
+            wo_raw.transpose(0, 1)
+            .view(Hq, head_dim, hidden_size)
+            .to(device=device, dtype=torch.float32)
+        )
+        module._critical_ada_wo_weight = wo
+
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/snapkv.py b/vllm/compactor-vllm/src/compactor_vllm/compression/snapkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..e79a62660cbf0fb41d89e0e8ada171f0e417a6b6
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/snapkv.py
@@ -0,0 +1,546 @@
+import math
+from typing import Optional
+
+import torch
+import triton
+from triton import language as tl
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
+# Passed as ``w=`` (number of trailing query positions used as the observation window).
+DEFAULT_SNAPKV_WINDOW_SIZE = 64
+# Passed as ``kernel_size=`` (width of the average-pool smoothing applied to the scores).
+DEFAULT_SNAPKV_KERNEL_SIZE = 5
+
+
+class SnapKVCompression(BaseCompressionMethod):
+    """SnapKV scorer: key scores come only from the post-RoPE observation window."""
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """SnapKV has no pre-RoPE component; always returns ``None``."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """
+        Score keys with ``query_aware_key_scores`` using the module-level SnapKV defaults.
+
+        ``v`` and ``pre_rope_scores`` are accepted for interface compatibility and not used.
+        """
+        positional = (q, k, context.cu_seqlens_q, context.cu_seqlens_k)
+        return maybe_execute_in_stream(
+            query_aware_key_scores,
+            *positional,
+            w=DEFAULT_SNAPKV_WINDOW_SIZE,
+            kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
+            STORE_STREAM=context.STORE_STREAM,
+        )
+
+
+
+
+@triton_autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
+        )
+        for bq in [32, 64]
+        for bk in [32, 64]
+        for num_warps in [4, 8]
+        for num_stages in [3, 4]
+    ],
+    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
+    cache_results=True,
+)
+@triton.jit
+def _lse_and_store_logits_kernel(
+    Q,
+    K,
+    cu_q,
+    cu_k,
+    w_b,  # int32 pointers
+    out_m,
+    out_S,  # [B, Hk, ROWS_MAX] float32
+    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
+    sm_scale,  # float
+    QUERY_GROUP_SIZE: tl.constexpr,
+    D: tl.constexpr,
+    STRIDE_Q_NQ,
+    STRIDE_Q_HQ,
+    STRIDE_K_NK,
+    STRIDE_K_HK,
+    STRIDE_M_B,
+    STRIDE_M_H,
+    STRIDE_M_R,
+    STRIDE_S_B,
+    STRIDE_S_H,
+    STRIDE_S_R,
+    STRIDE_LG_NK,
+    STRIDE_LG_HK,
+    STRIDE_LG_R,
+    BLOCK_Q: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    ROWS_MAX,
+):
+    # Pass 1 of the SnapKV scoring: for the last ``win`` query positions of sequence b (times the
+    # QUERY_GROUP_SIZE query heads sharing KV head hk) compute base-2 scaled, causally masked
+    # logits against every key of the sequence. It stores
+    #   * the logits of the prefix keys (keys before the window) into LOGITS, and
+    #   * the running max ``m`` / running sum ``S`` of exp2(logit - m) over ALL keys into
+    #     out_m / out_S, i.e. the softmax normaliser in streaming log-sum-exp form.
+    # Q and K are addressed with an implicit unit stride on the last (D) dimension.
+    # program ids
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+    rid = tl.program_id(2)  # row-tile id
+    # batch segment bounds
+    q_end = tl.load(cu_q + b + 1)
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+    win = tl.load(w_b + b)
+
+    # Window = last ``win`` queries; prefix = keys in [k_beg, k_end - win).
+    q_win_beg = q_end - win
+    k_eff_end = k_end - win
+    # Nothing to do without a window or without any prefix key.
+    if (win <= 0) or (k_eff_end <= k_beg):
+        return
+
+    # rows for this (b,hk)
+    rows_b = win * QUERY_GROUP_SIZE
+    row0 = rid * BLOCK_Q
+    if row0 >= rows_b:
+        return
+
+    # exp(x) = exp2(x * 1/ln2)
+    qk_scale = sm_scale * 1.4426950408889634
+
+    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
+    row_mask = offs_qrow < rows_b
+
+    # map row -> (q_idx, hq_local)
+    hq_local = offs_qrow % QUERY_GROUP_SIZE
+    q_off = offs_qrow // QUERY_GROUP_SIZE
+    q_idx = q_win_beg + q_off
+    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
+
+    offs_d = tl.arange(0, D)
+
+    q_ptrs = (
+        Q
+        + q_idx[:, None] * STRIDE_Q_NQ
+        + hq_glob[:, None] * STRIDE_Q_HQ
+        + offs_d[None, :]
+    )
+    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+    # Streaming log-sum-exp state per row: running max and running sum.
+    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
+    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
+
+    # Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
+    for ks in tl.range(k_beg, k_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_end
+
+        k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
+        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
+
+        s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]
+        s = tl.where(kmask[None, :], s, -float("inf"))
+        # Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
+        # NOTE(review): nk is a packed key row (from cu_k) and q_idx a packed query row (from
+        # cu_q); the comparison presumably assumes both share the same offsets — confirm.
+        causal_ok = nk[None, :] <= q_idx[:, None]
+        s = tl.where(causal_ok, s, -float("inf"))
+
+        # store prefix logits only (for marginal probs on prefix keys)
+        log_ptrs = (
+            LOGITS
+            + nk[:, None] * STRIDE_LG_NK
+            + hk * STRIDE_LG_HK
+            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+        )
+        store_mask = kmask & (nk < k_eff_end)
+        tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
+
+        # log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
+        cur_max = tl.max(s, 1)  # [BQ]
+        n_m = tl.maximum(m, cur_max)
+        # Rescale the sum accumulated so far to the new running max before adding this tile.
+        rescale = tl.math.exp2(m - n_m)
+        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
+        m = n_m
+
+    # store m,S for these rows
+    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
+    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
+
+
+
+
+@triton_autotune(
+    configs=[
+        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
+        for bq in [16, 32, 64]
+        for bk in [32, 64, 128]
+    ],
+    key=["HK", "HQ"],
+    cache_results=True,
+)
+@triton.jit
+def _prefix_probs_kernel(
+    cu_k,
+    w_b,
+    in_m,
+    in_S,  # [B, Hk, ROWS_MAX] f32
+    LOGITS,  # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
+    PROBS,  # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
+    # compile-time group size, strides and tile sizes
+    QUERY_GROUP_SIZE: tl.constexpr,
+    STRIDE_M_B,
+    STRIDE_M_H,
+    STRIDE_M_R,
+    STRIDE_S_B,
+    STRIDE_S_H,
+    STRIDE_S_R,
+    STRIDE_LG_NK,
+    STRIDE_LG_HK,
+    STRIDE_LG_R,
+    STRIDE_PB_NK,
+    STRIDE_PB_HK,
+    STRIDE_PB_R,
+    BLOCK_Q: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    # Pass 2 of the SnapKV scoring: turn the stored base-2 logits of the prefix keys into
+    # probabilities, PROBS = exp2(logit - m) / S, using the per-row running max ``m`` and sum
+    # ``S`` read from in_m / in_S. Grid: axis 0 = sequence b, axis 1 = KV head.
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+    win = tl.load(w_b + b)
+
+    # Prefix = keys in [k_beg, k_end - win); skip sequences without a window or a prefix.
+    k_eff_end = k_end - win
+    if (win <= 0) or (k_eff_end <= k_beg):
+        return
+
+    rows_b = win * QUERY_GROUP_SIZE
+
+    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_eff_end
+
+        for row0 in tl.range(0, rows_b, BLOCK_Q):
+            r_idx = row0 + tl.arange(0, BLOCK_Q)
+            rmask = r_idx < rows_b
+
+            m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+            S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+            m = tl.load(
+                m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
+                mask=rmask,
+                other=-float("inf"),
+            )
+            S = tl.load(
+                S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
+            )
+
+            # Rows with a non-positive sum get neutral m/S so the division below is safe;
+            # their probabilities are forced to 0 afterwards.
+            valid_row = S > 0
+            m = tl.where(valid_row, m, 0.0)
+            S = tl.where(valid_row, S, 1.0)
+
+            log_ptrs = (
+                LOGITS
+                + nk[:, None] * STRIDE_LG_NK
+                + hk * STRIDE_LG_HK
+                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+            )
+            s_T = tl.load(
+                log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
+            )  # [BK, BQ]
+
+            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
+            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
+
+            prob_ptrs = (
+                PROBS
+                + nk[:, None] * STRIDE_PB_NK
+                + hk * STRIDE_PB_HK
+                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
+            )
+            tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
+
+
+
+
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+    key=["HK"],
+    cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue(
+    OUT,  # [Nk, Hk], float32
+    cu_k,
+    w_b,  # [B+1], [B] int32
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    HK: tl.constexpr,  # Hk
+    EPS: tl.constexpr,  # e.g., 1e-12
+    BLOCK_K: tl.constexpr,  # e.g., 128
+):
+    # In-place z-score of OUT, one program per sequence: mean and variance are taken jointly
+    # over all (prefix token, head) entries of the sequence, then every such entry becomes
+    # (x - mean) / sqrt(var + EPS). Rows in the trailing window [k_end - win, k_end) are
+    # neither included in the statistics nor rewritten.
+    b = tl.program_id(0)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+    win = tl.load(w_b + b)
+
+    k_eff_end = k_end - win
+    if k_eff_end <= k_beg:
+        return
+
+    sumv = tl.zeros([], dtype=tl.float32)
+    sumsq = tl.zeros([], dtype=tl.float32)
+    count = ((k_eff_end - k_beg) * HK).to(tl.float32)
+
+    # First sweep: accumulate sum and sum of squares.
+    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_eff_end
+        for h in tl.range(0, HK):
+            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+            sumv += tl.sum(vals, 0)
+            sumsq += tl.sum(vals * vals, 0)
+
+    mean = sumv / count
+    # E[x^2] - E[x]^2, clamped at 0 to absorb negative rounding error.
+    var = tl.maximum(sumsq / count - mean * mean, 0.0)
+    invstd = 1.0 / tl.sqrt(var + EPS)
+
+    # Second sweep: normalise in place.
+    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_eff_end
+        for h in tl.range(0, HK):
+            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+            vals = (vals - mean) * invstd
+            tl.store(ptrs, vals, mask=kmask)
+
+
+
+
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
+    key=["KERNEL_SIZE"],
+    cache_results=True,
+)
+@triton.jit
+def _snapkv_avg_pool1d_kernel(
+    IN,
+    OUT,
+    Lp,
+    STRIDE_IN_C,
+    STRIDE_IN_L,
+    STRIDE_OUT_C,
+    STRIDE_OUT_L,
+    KERNEL_SIZE: tl.constexpr,
+    PAD: tl.constexpr,
+    BLOCK_T: tl.constexpr,
+):
+    """
+    Symmetric 1D average pool on the last dimension, matching
+    `F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
+    (equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
+
+    Grid: axis 0 = channel ``c``, axis 1 = ``BLOCK_T``-wide tile of output positions.
+    Exactly ``Lp`` outputs are written per channel.
+    NOTE(review): ``F.avg_pool1d`` with ``padding=K//2`` presumably yields ``Lp`` outputs only
+    for odd ``K`` — confirm that even kernel sizes are never used.
+    """
+    c = tl.program_id(0)
+    t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
+    mask = t0 < Lp
+
+    acc = tl.zeros([BLOCK_T], dtype=tl.float32)
+    for j in tl.static_range(KERNEL_SIZE):
+        # Input position of tap j; taps outside [0, Lp) load 0.0 (zero padding).
+        idx = t0 - PAD + j
+        valid = (idx >= 0) & (idx < Lp)
+        ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
+        v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
+        acc += v
+    # Always divide by the full kernel size, padded taps included.
+    acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
+
+    out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
+    tl.store(out_ptrs, acc, mask=mask)
+
+
+
+
+def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
+    """
+    Smooth ``x`` along its last dimension with ``_snapkv_avg_pool1d_kernel``.
+
+    kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].
+    `x` must be float32 with shape [Hk, G, Lp]; the result has the same shape. An input with
+    ``Lp == 0`` is returned as is.
+    """
+    assert x.dtype == torch.float32
+    n_heads, n_groups, length = x.shape
+    if length == 0:
+        return x
+
+    # Collapse (Hk, G) into one channel axis; the kernel pools each channel independently.
+    channels = n_heads * n_groups
+    flat_in = x.reshape(channels, length).contiguous()
+    flat_out = torch.empty_like(flat_in)
+
+    def grid(meta):
+        return (channels, triton.cdiv(length, meta["BLOCK_T"]))
+
+    _snapkv_avg_pool1d_kernel[grid](
+        flat_in,
+        flat_out,
+        length,
+        *flat_in.stride(),
+        *flat_out.stride(),
+        KERNEL_SIZE=kernel_size,
+        PAD=kernel_size // 2,
+    )
+    return flat_out.view(n_heads, n_groups, length)
+
+
+
+
+def _snapkv_kvpress_epilogue(
+    probs_buf: torch.Tensor,
+    out: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    w: torch.Tensor,
+    G: int,
+    Hk: int,
+    kernel_size: int,
+) -> None:
+    """
+    Reduce per-row prefix probabilities to per-key scores and write them into ``out``.
+
+    Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
+    → mean over GQA groups → pad tail with global max of prefix scores.
+
+    Args:
+        probs_buf: probabilities indexed as ``[token, kv_head, row]``; for a sequence with
+            window ``win`` the first ``win * G`` rows are used, ordered ``(q_off, g)``.
+        out: ``[token, kv_head]`` buffer, updated in place.
+        cu_seqlens_k: cumulative key boundaries (``B + 1`` entries).
+        w: per-sequence window sizes.
+        G: query heads per KV head.
+        Hk: number of KV heads.
+        kernel_size: pooling width forwarded to ``_snapkv_avg_pool1d_triton``.
+
+    Sequences with ``win <= 0`` or without any prefix key are left untouched.
+    """
+    # One host transfer per tensor instead of three ``.item()`` synchronisations per sequence.
+    bounds = cu_seqlens_k.tolist()
+    windows = w.tolist()
+    for b, (k_beg, k_end) in enumerate(zip(bounds[:-1], bounds[1:])):
+        k_beg = int(k_beg)
+        k_end = int(k_end)
+        win = int(windows[b])
+        k_eff_end = k_end - win
+        if win <= 0 or k_eff_end <= k_beg:
+            continue
+        Lp = k_eff_end - k_beg
+        rows_b = win * G
+        p = probs_buf[k_beg:k_eff_end, :, :rows_b]
+        # [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
+        x = p.view(Lp, Hk, win, G).mean(dim=2)
+        x = x.permute(1, 2, 0).contiguous()  # [Hk, G, Lp]
+        x = _snapkv_avg_pool1d_triton(x, kernel_size)
+        x = x.mean(dim=1)
+        seg = x.permute(1, 0).contiguous()
+        out[k_beg:k_eff_end, :] = seg
+        # The window rows themselves all get the largest prefix score of the sequence.
+        pad_val = seg.max()
+        out[k_eff_end:k_end, :] = pad_val
+
+
+
+
+def query_aware_key_scores(
+    q: torch.Tensor,  # [N_q, Hq, D]
+    k: torch.Tensor,  # [N_k, Hk, D]
+    cu_seqlens_q: torch.Tensor,  # [B+1], int32
+    cu_seqlens_k: torch.Tensor,  # [B+1], int32
+    w: torch.Tensor | int,  # [B], int32
+    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
+    *,
+    kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
+    accum_scores: Optional[torch.Tensor] = None,
+    accum_blending: Optional[float] = None,
+    normalize: bool = False,
+) -> Optional[torch.Tensor]:
+    """
+    SnapKV-style key scores from the attention of the last ``w`` queries of each sequence.
+
+    Pipeline: ``_lse_and_store_logits_kernel`` (causal logits + softmax normaliser) →
+    ``_prefix_probs_kernel`` (probabilities on the prefix keys) → ``_snapkv_kvpress_epilogue``
+    (reduction to ``[N_k, Hk]``) → optional ``_zscore_per_batch_epilogue``.
+
+    Args:
+        q: packed queries ``[N_q, Hq, D]``, last dimension contiguous.
+        k: packed keys ``[N_k, Hk, D]``, last dimension contiguous; ``Hq`` must be a multiple
+            of ``Hk``.
+        cu_seqlens_q: cumulative query boundaries (``B + 1`` entries).
+        cu_seqlens_k: cumulative key boundaries (``B + 1`` entries).
+        w: observation window, either one int for all sequences or a ``[B]`` tensor.
+        sm_scale: softmax scale; ``1 / sqrt(D)`` when ``None``.
+        kernel_size: width of the average-pool smoothing.
+        accum_scores: if given, it is updated **in place** (optionally scaled by
+            ``accum_blending`` first, then the new scores are added) and returned.
+        accum_blending: scale applied to ``accum_scores`` before accumulation.
+        normalize: apply the per-sequence z-score epilogue to the prefix scores.
+
+    Returns:
+        Float32 ``[N_k, Hk]`` scores, or ``accum_scores`` when it was provided.
+        NOTE(review): when the largest window is 0, zeros are returned even if
+        ``accum_scores`` was passed — confirm this is intended.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
+    # The kernel reads K with the head dimension taken from Q.
+    assert D == Dk, "q and k must share the head dimension"
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    B = cu_seqlens_q.numel() - 1
+    assert B == cu_seqlens_k.numel() - 1
+
+    G = Hq // Hk
+    if isinstance(w, int):
+        max_w = w
+        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
+    else:
+        max_w = int(w.max().item())
+    assert w.numel() == B
+    # Scratch buffers are sized for the largest window; each row is one (query, group head).
+    ROWS_MAX = max_w * G
+    if ROWS_MAX == 0:
+        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+
+    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+    m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    probs_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+
+    # strides
+    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
+    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
+    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
+    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
+    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
+    STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
+    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
+
+    def grid(META):
+        return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
+
+    _lse_and_store_logits_kernel[grid](
+        q,
+        k,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        w,
+        m_scratch,
+        S_scratch,
+        logits_buf,
+        sm_scale,
+        QUERY_GROUP_SIZE=G,
+        D=D,
+        STRIDE_Q_NQ=STRIDE_Q_NQ,
+        STRIDE_Q_HQ=STRIDE_Q_HQ,
+        STRIDE_K_NK=STRIDE_K_NK,
+        STRIDE_K_HK=STRIDE_K_HK,
+        STRIDE_M_B=STRIDE_M_B,
+        STRIDE_M_H=STRIDE_M_H,
+        STRIDE_M_R=STRIDE_M_R,
+        STRIDE_S_B=STRIDE_S_B,
+        STRIDE_S_H=STRIDE_S_H,
+        STRIDE_S_R=STRIDE_S_R,
+        STRIDE_LG_NK=STRIDE_LG_NK,
+        STRIDE_LG_HK=STRIDE_LG_HK,
+        STRIDE_LG_R=STRIDE_LG_R,
+        ROWS_MAX=ROWS_MAX,
+    )
+
+    _prefix_probs_kernel[(B, Hk)](
+        cu_seqlens_k,
+        w,
+        m_scratch,
+        S_scratch,
+        logits_buf,
+        probs_buf,
+        QUERY_GROUP_SIZE=G,
+        STRIDE_M_B=STRIDE_M_B,
+        STRIDE_M_H=STRIDE_M_H,
+        STRIDE_M_R=STRIDE_M_R,
+        STRIDE_S_B=STRIDE_S_B,
+        STRIDE_S_H=STRIDE_S_H,
+        STRIDE_S_R=STRIDE_S_R,
+        STRIDE_LG_NK=STRIDE_LG_NK,
+        STRIDE_LG_HK=STRIDE_LG_HK,
+        STRIDE_LG_R=STRIDE_LG_R,
+        STRIDE_PB_NK=STRIDE_PB_NK,
+        STRIDE_PB_HK=STRIDE_PB_HK,
+        STRIDE_PB_R=STRIDE_PB_R,
+    )
+    _snapkv_kvpress_epilogue(
+        probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
+    )
+    if normalize:
+        _zscore_per_batch_epilogue[(B,)](
+            out,
+            cu_seqlens_k,
+            w,
+            STRIDE_OUT_NK,
+            STRIDE_OUT_HK,
+            HK=Hk,
+            EPS=1e-12,
+        )
+    if accum_scores is not None:
+        # In-place on the caller's tensor: accum = accum * blending + out.
+        if accum_blending is not None:
+            accum_scores.mul_(accum_blending)
+        accum_scores.add_(out)
+        return accum_scores
+    else:
+        return out
+
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py b/vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eaaba64a384f4bc74840724168710439eec9a16
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py
@@ -0,0 +1,449 @@
+import math
+from typing import Optional
+
+import torch
+import triton
+from triton import language as tl
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+
+class SnapKVCompression(BaseCompressionMethod):
+    """Compression method whose key scores are produced after RoPE only.
+
+    ``pre_rope_scoring`` contributes nothing; ``post_rope_scoring`` delegates
+    to :func:`query_aware_key_scores` with a fixed window of 32.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Return ``None``: this method computes no pre-RoPE scores."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Score keys against the queries via ``query_aware_key_scores``.
+
+        ``v`` and ``pre_rope_scores`` are accepted for interface parity but
+        are not used. The call is routed through ``maybe_execute_in_stream``
+        with ``context.STORE_STREAM``.
+        """
+        scoring_kwargs = {"w": 32, "STORE_STREAM": context.STORE_STREAM}
+        return maybe_execute_in_stream(
+            query_aware_key_scores,
+            q,
+            k,
+            context.cu_seqlens_q,
+            context.cu_seqlens_k,
+            **scoring_kwargs,
+        )
+
+
+# Pass 1 of the scoring pipeline. One program handles one (batch, kv-head,
+# row-tile) triple. It computes the scaled dot products between the last `win`
+# queries of the batch segment and the keys in [k_beg, k_end - win), stores them
+# transposed in LOGITS, and records the per-row running max (`out_m`) and sum of
+# exp2 (`out_S`) so a later pass can normalise them into softmax probabilities.
+# Logits are kept in base 2: sm_scale is pre-multiplied by log2(e).
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
+ )
+ for bq in [32, 64]
+ for bk in [32, 64]
+ for num_warps in [4, 8]
+ for num_stages in [3, 4]
+ ],
+ key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
+ cache_results=True,
+)
+@triton.jit
+def _lse_and_store_logits_kernel(
+ Q,
+ K,
+ cu_q,
+ cu_k,
+ w_b, # int32 pointers
+ out_m,
+ out_S, # [B, Hk, ROWS_MAX] float32
+ LOGITS, # [Nk, Hk, ROWS_MAX] float32
+ sm_scale, # float
+ QUERY_GROUP_SIZE: tl.constexpr,
+ D: tl.constexpr,
+ STRIDE_Q_NQ,
+ STRIDE_Q_HQ,
+ STRIDE_K_NK,
+ STRIDE_K_HK,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ ROWS_MAX,
+):
+ # NOTE: ROWS_MAX is not read in the body; it only participates in the
+ # autotune key declared above.
+ # program ids
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+ rid = tl.program_id(2) # row-tile id
+ # batch segment bounds
+ q_end = tl.load(cu_q + b + 1)
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ # The window is the last `win` queries; the scored keys are the ones that
+ # come before the last `win` key positions.
+ q_win_beg = q_end - win
+ k_eff_end = k_end - win
+ # Nothing to score when the window is empty or covers the whole sequence.
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ # rows for this (b,hk)
+ rows_b = win * QUERY_GROUP_SIZE
+ row0 = rid * BLOCK_Q
+ # The grid is sized for the largest window, so tiles past this batch's
+ # row count have no work.
+ if row0 >= rows_b:
+ return
+
+ # exp(x) = exp2(x * 1/ln2)
+ qk_scale = sm_scale * 1.4426950408889634
+
+ offs_qrow = row0 + tl.arange(0, BLOCK_Q)
+ row_mask = offs_qrow < rows_b
+
+ # map row -> (q_idx, hq_local)
+ # Rows enumerate (query offset, head-in-group) pairs, head index fastest.
+ hq_local = offs_qrow % QUERY_GROUP_SIZE
+ q_off = offs_qrow // QUERY_GROUP_SIZE
+ q_idx = q_win_beg + q_off
+ hq_glob = hk * QUERY_GROUP_SIZE + hq_local
+
+ offs_d = tl.arange(0, D)
+
+ # The last dimension is addressed with unit stride (offs_d is added as-is).
+ q_ptrs = (
+ Q
+ + q_idx[:, None] * STRIDE_Q_NQ
+ + hq_glob[:, None] * STRIDE_Q_HQ
+ + offs_d[None, :]
+ )
+ q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+ # Running max and running sum of exp2 for the streaming log-sum-exp.
+ m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
+ S = tl.zeros([BLOCK_Q], dtype=tl.float32)
+
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
+ k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
+
+ s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
+ # Out-of-range keys get -inf so they add nothing to the max or the sum.
+ s = tl.where(kmask[None, :], s, -float("inf"))
+
+ # store into LOGITS[nk, hk, row] -> [BK, BQ]
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
+
+ # log2 streaming LSE update
+ # Rescale the previous sum to the new max before adding this block.
+ cur_max = tl.max(s, 1) # [BQ]
+ n_m = tl.maximum(m, cur_max)
+ rescale = tl.math.exp2(m - n_m)
+ S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
+ m = n_m
+
+ # store m,S for these rows
+ m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
+ tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
+
+
+# Pass 2 of the scoring pipeline. One program handles one (batch, kv-head) pair.
+# For every key in [k_beg, k_end - win) it turns the stored base-2 logits into
+# softmax probabilities using the saved (m, S) statistics, sums them over all
+# window rows, optionally average-pools the sums, and writes them to OUT.
+# Positions inside the window, [k_end - win, k_end), are written as +inf.
+# NOTE(review): the autotune key names "HK" and "HQ", but this kernel has no
+# parameters with those names -- confirm how triton_compat.autotune treats them.
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
+ for bq in [16, 32, 64]
+ for bk in [32, 64, 128]
+ ],
+ key=["HK", "HQ"],
+ cache_results=True,
+)
+@triton.jit
+def _scores_from_logits_kernel(
+ cu_k,
+ w_b,
+ in_m,
+ in_S, # [B, Hk, ROWS_MAX] f32
+ LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits
+ OUT, # [Nk, Hk] f32
+ #
+ QUERY_GROUP_SIZE: tl.constexpr,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ #
+ DO_POOL: tl.constexpr, # set True to enable in-place avg pool
+ KPOOL: tl.constexpr, # kernel size for avg pool (stride=1)
+):
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ # Early exit writes nothing at all to OUT for this batch, including the
+ # +inf padding below.
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ rows_b = win * QUERY_GROUP_SIZE
+
+ # === scores over computed region ===
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ scores = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+ for row0 in tl.range(0, rows_b, BLOCK_Q):
+ r_idx = row0 + tl.arange(0, BLOCK_Q)
+ rmask = r_idx < rows_b
+
+ # load m, S for rows
+ m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ m = tl.load(
+ m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
+ mask=rmask,
+ other=-float("inf"),
+ )
+ S = tl.load(
+ S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
+ )
+
+ # Rows with S == 0 (masked-out rows load S = 0) are neutralised so
+ # the exp2 / divide below cannot produce NaN; their probabilities
+ # are zeroed again afterwards.
+ valid_row = S > 0
+ m = tl.where(valid_row, m, 0.0)
+ S = tl.where(valid_row, S, 1.0)
+
+ # load stored logits^T: [BK, BQ]
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ s_T = tl.load(
+ log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
+ ) # [BK, BQ]
+
+ # probs^T = exp2(s_T - m) / S, sum over rows
+ probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
+ probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
+
+ scores += tl.sum(probs_T, 1) # [BK]
+
+ if DO_POOL and (KPOOL > 1):
+ # Trailing average: position i averages the valid scores at
+ # positions i-KPOOL+1 .. i. The band uses block-local indices, so
+ # the pool never reaches into the previous BLOCK_K tile.
+ # NOTE(review): pooled values near tile starts therefore depend on
+ # the autotuned BLOCK_K -- confirm this is intended.
+ i = tl.arange(0, BLOCK_K)[:, None]
+ j = tl.arange(0, BLOCK_K)[None, :]
+ band = (j <= i) & ((i - j) < KPOOL)
+ band = band & kmask[None, :]
+ # sum within band
+ sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1) # [BK]
+ denom = tl.sum(band, 1).to(tl.float32) # [BK]
+ denom = tl.where(denom > 0, denom, 1.0)
+ scores = sums / denom
+
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(out_ptrs, scores, mask=kmask)
+
+ # Positions inside the window itself get +inf (presumably so a later
+ # top-k selection always keeps them -- confirm against the caller).
+ pad_beg = k_eff_end
+ pad_end = k_end
+ if pad_end > pad_beg:
+ for ks in tl.range(pad_beg, pad_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < pad_end
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(
+ out_ptrs, tl.full([BLOCK_K], float("inf"), dtype=tl.float32), mask=kmask
+ )
+
+
+# In-place z-score normalisation of OUT, one program per batch entry. Mean and
+# variance are taken jointly over all HK heads and over the keys in
+# [k_beg, k_end - win); the last `win` positions are neither read nor rewritten.
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+ key=["HK"],
+ cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue(
+ OUT, # [Nk, Hk], float32
+ cu_k,
+ w_b, # [B+1], [B] int32
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ HK: tl.constexpr, # Hk
+ EPS: tl.constexpr, # e.g., 1e-12
+ BLOCK_K: tl.constexpr, # e.g., 128
+):
+ b = tl.program_id(0)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ # Unlike the scoring kernels there is no `win <= 0` check here: only an
+ # empty normalisation range is skipped.
+ if k_eff_end <= k_beg:
+ return
+
+ sumv = tl.zeros([], dtype=tl.float32)
+ sumsq = tl.zeros([], dtype=tl.float32)
+ count = ((k_eff_end - k_beg) * HK).to(tl.float32)
+
+ # First sweep: accumulate sum and sum of squares.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ sumv += tl.sum(vals, 0)
+ sumsq += tl.sum(vals * vals, 0)
+
+ mean = sumv / count
+ # var = E[x^2] - mean^2, clamped at 0 to absorb negative rounding error.
+ var = tl.maximum(sumsq / count - mean * mean, 0.0)
+ invstd = 1.0 / tl.sqrt(var + EPS)
+
+ # Second sweep: rewrite every value as (x - mean) / std.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ vals = (vals - mean) * invstd
+ tl.store(ptrs, vals, mask=kmask)
+
+
+def query_aware_key_scores(
+    q: torch.Tensor,  # [N_q, Hq, D]
+    k: torch.Tensor,  # [N_k, Hk, D]
+    cu_seqlens_q: torch.Tensor,  # [B+1], int32
+    cu_seqlens_k: torch.Tensor,  # [B+1], int32
+    w: torch.Tensor | int,  # [B], int32
+    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
+    *,
+    accum_scores: Optional[torch.Tensor] = None,
+    accum_blending: Optional[float] = None,
+    normalize: bool = False,
+) -> Optional[torch.Tensor]:
+    """Score every key by the attention it receives from the last ``w`` queries.
+
+    For each batch entry the last ``w`` queries attend to the keys that come
+    before the last ``w`` key positions. The softmax probabilities are summed
+    over those queries and over the query heads sharing a KV head, then
+    average-pooled (kernel size 5). Positions inside the window are written as
+    ``+inf`` by the scoring kernel.
+
+    Args:
+        q: Packed queries, ``[N_q, Hq, D]``, last dim contiguous.
+        k: Packed keys, ``[N_k, Hk, D]``, last dim contiguous.
+        cu_seqlens_q: Cumulative query lengths, ``[B + 1]``.
+        cu_seqlens_k: Cumulative key lengths, ``[B + 1]``.
+        w: Window size, either one int for all batches or a ``[B]`` tensor.
+        sm_scale: Softmax scale; ``1 / sqrt(D)`` when ``None``.
+        accum_scores: Optional ``[N_k, Hk]`` tensor updated in place with the
+            new scores and returned instead of a fresh tensor.
+        accum_blending: If given, ``accum_scores`` is multiplied by this factor
+            before the new scores are added.
+        normalize: Z-score the scores per batch entry before accumulation.
+
+    Returns:
+        ``accum_scores`` (mutated) when it was supplied, otherwise a new
+        ``[N_k, Hk]`` float32 tensor. Entries the kernels do not write (batches
+        whose window is empty or covers the whole sequence) are ``0``.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
+    device = q.device
+    N_q, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    # The kernel indexes both q and k with the same D-wide offset vector.
+    assert Dk == D, "q and k must share the head dimension"
+    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    B = cu_seqlens_q.numel() - 1
+    assert B == cu_seqlens_k.numel() - 1
+
+    G = Hq // Hk
+    if isinstance(w, int):
+        max_w = w
+        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
+    else:
+        max_w = int(w.max().item())
+    assert w.numel() == B
+    ROWS_MAX = max_w * G
+
+    # Zero-initialised on purpose: the scoring kernel returns early for batches
+    # with `win <= 0` or `len <= win`, and those rows would otherwise expose
+    # uninitialised memory to the caller.
+    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+
+    if ROWS_MAX > 0:
+        m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+        S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+        logits_buf = torch.empty(
+            (N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device
+        )
+
+        # strides
+        STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
+        STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
+        STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
+        STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
+        STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
+        STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
+
+        # Stride arguments shared by both kernels.
+        scratch_strides = dict(
+            STRIDE_M_B=STRIDE_M_B,
+            STRIDE_M_H=STRIDE_M_H,
+            STRIDE_M_R=STRIDE_M_R,
+            STRIDE_S_B=STRIDE_S_B,
+            STRIDE_S_H=STRIDE_S_H,
+            STRIDE_S_R=STRIDE_S_R,
+            STRIDE_LG_NK=STRIDE_LG_NK,
+            STRIDE_LG_HK=STRIDE_LG_HK,
+            STRIDE_LG_R=STRIDE_LG_R,
+        )
+
+        def grid(META):
+            return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
+
+        # Pass 1: logits plus per-row (max, sum-exp2) statistics.
+        _lse_and_store_logits_kernel[grid](
+            q,
+            k,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            w,
+            m_scratch,
+            S_scratch,
+            logits_buf,
+            sm_scale,
+            QUERY_GROUP_SIZE=G,
+            D=D,
+            STRIDE_Q_NQ=STRIDE_Q_NQ,
+            STRIDE_Q_HQ=STRIDE_Q_HQ,
+            STRIDE_K_NK=STRIDE_K_NK,
+            STRIDE_K_HK=STRIDE_K_HK,
+            ROWS_MAX=ROWS_MAX,
+            **scratch_strides,
+        )
+
+        # Pass 2: probabilities summed over window rows, then pooled.
+        _scores_from_logits_kernel[(B, Hk)](
+            cu_seqlens_k,
+            w,
+            m_scratch,
+            S_scratch,
+            logits_buf,
+            out,
+            QUERY_GROUP_SIZE=G,
+            STRIDE_OUT_NK=STRIDE_OUT_NK,
+            STRIDE_OUT_HK=STRIDE_OUT_HK,
+            DO_POOL=True,
+            KPOOL=5,
+            **scratch_strides,
+        )
+        if normalize:
+            _zscore_per_batch_epilogue[(B,)](
+                out,
+                cu_seqlens_k,
+                w,
+                STRIDE_OUT_NK,
+                STRIDE_OUT_HK,
+                HK=Hk,
+                EPS=1e-12,
+            )
+
+    # Accumulate even when no window rows exist, so a caller that passed
+    # `accum_scores` always gets that tensor back.
+    if accum_scores is not None:
+        if accum_blending is not None:
+            accum_scores.mul_(accum_blending)
+        accum_scores.add_(out)
+        return accum_scores
+    return out
diff --git a/vllm/compactor-vllm/src/compactor_vllm/config/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/config/constants.py b/vllm/compactor-vllm/src/compactor_vllm/config/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac943e40f261c61398ad82cb7eb4f714e0590aad
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/config/constants.py
@@ -0,0 +1,5 @@
+# NOTE(review): the meaning of this id is not visible from this file --
+# presumably a batch slot that is kept out of normal allocation; confirm
+# against the KV-cache / page-table code that consumes it.
+RESERVED_BATCH = 0
+# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
+# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
+# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
+TRITON_RESERVED_BATCH = RESERVED_BATCH
diff --git a/vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py b/vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4aca1e63c9d16de4d1f7b721868c8e0afda5fc
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py
@@ -0,0 +1,100 @@
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import List, Optional
+
+from transformers import AutoConfig
+
+
+class AttentionBackend(Enum):
+ """Attention implementations selectable through ``LLMConfig.attention_backend``."""
+
+ # FlashAttention3 varlen backend (as described in the LLMConfig docstring).
+ FLASH_ATTENTION = auto()
+ # Custom Triton kernels used by Compactor (as described in the LLMConfig docstring).
+ COMPACTOR_TRITON = auto()
+
+
+@dataclass
+class LLMConfig:
+    """Configuration for the :class:`LLM` engine.
+
+    Parameters
+    ----------
+    model : str
+        Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
+        a local model name that can be resolved by
+        :func:`transformers.AutoConfig.from_pretrained`.
+    path : str, optional
+        Local directory containing the model weights. If ``None``, the engine
+        will attempt to resolve a local snapshot for ``model`` using
+        :func:`huggingface_hub.snapshot_download`.
+    nccl_port : int, optional, default 1218
+        Port number; presumably used for the distributed (NCCL) rendezvous
+        between tensor-parallel workers. It is not read in this module --
+        confirm against the model runner.
+    max_num_seqs : int, default 256
+        Upper bound on the number of concurrent batches that the scheduler and
+        KV-cache manager are allowed to handle. This affects the size of the
+        page table and some internal buffers.
+    max_model_len : int, default 40960
+        Maximum context length (in tokens) that the engine will allocate KV cache
+        and CUDA graphs for. During initialization this value is clamped to
+        ``hf_config.max_position_embeddings`` for the chosen model.
+    gpu_memory_utilization : float, default 0.9
+        Fraction of the total GPU memory that may be used for KV cache and model
+        activations. Values should be in ``(0, 1]``. If this budget is too small,
+        the KV-cache manager may raise an error at warmup time due
+        to insufficient memory.
+    tensor_parallel_size : int, default 1
+        Number of tensor-parallel workers to shard the model
+        across. Must be between 1 and 8, and must evenly divide the model's
+        number of key/value heads.
+    enforce_eager : bool, default False
+        If ``True``, disable CUDA graph capture and always run the model in
+        eager mode during decoding. This reduces throughput. When ``False``,
+        the engine will capture and reuse CUDA graphs for supported
+        batch sizes and sequence lengths.
+    hf_config : transformers.AutoConfig, optional
+        Pre-loaded Hugging Face configuration for the model. If ``None``,
+        it will then be populated automatically based on ``model``.
+    eos : int, default -1
+        Primary stop token id (warmup / single-id paths). If ``-1``, the
+        :class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
+        tokenizer.
+    eos_token_ids : list of int, optional
+        All token ids that terminate generation (e.g. HF tokenizers may expose
+        ``eos_token_id`` as a list for chat models). If ``None``, inferred in
+        :class:`LLM` from the tokenizer and model type.
+    kvcache_page_size : int, default 128
+        Number of tokens stored in a single KV-cache page. Smaller pages improve
+        allocation flexibility but increase page-table overhead; larger pages
+        reduce overhead but have coarser granularity.
+    leverage_sketch_size : int, default 48
+        Sketch dimension used by the Compactor leverage-score estimator.
+    attention_backend : AttentionBackend, default AttentionBackend.COMPACTOR_TRITON
+        Attention implementation to use. ``COMPACTOR_TRITON`` selects the custom
+        Triton kernels used by Compactor; ``FLASH_ATTENTION`` selects the
+        FlashAttention3 varlen backend. The COMPACTOR_TRITON tends to be faster
+        for longer sequence lengths, while FA3 is faster at shorter lengths.
+    show_progress_bar : bool, default True
+        Presumably toggles progress-bar output during generation; it is not
+        read in this module -- confirm against the model runner.
+
+    Raises
+    ------
+    NotADirectoryError
+        If ``path`` is given but is not an existing directory.
+    ValueError
+        If ``tensor_parallel_size`` is outside ``[1, 8]``.
+    """
+
+    model: str
+    path: Optional[str] = None
+    nccl_port: Optional[int] = 1218
+    max_num_seqs: int = 256
+    max_model_len: int = 40960
+    gpu_memory_utilization: float = 0.9
+    tensor_parallel_size: int = 1
+    enforce_eager: bool = False
+    hf_config: AutoConfig | None = None
+    eos: int = -1
+    eos_token_ids: Optional[List[int]] = None
+    kvcache_page_size: int = 128
+    leverage_sketch_size: int = 48
+    attention_backend: AttentionBackend = AttentionBackend.COMPACTOR_TRITON
+    show_progress_bar: bool = True
+
+    def __post_init__(self):
+        """Validate the fields and fill in values derived from the HF config."""
+        if self.path is not None and not os.path.isdir(self.path):
+            raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
+        # Raise the documented ValueError directly. An `assert` placed inside
+        # this guard would always fire first (it restates the negated
+        # condition) and surface as AssertionError unless Python runs with -O.
+        if self.tensor_parallel_size <= 0 or self.tensor_parallel_size > 8:
+            raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
+        if self.hf_config is None:
+            self.hf_config = AutoConfig.from_pretrained(self.model)
+        # Never allocate for more positions than the model supports.
+        self.max_model_len = min(
+            self.max_model_len, self.hf_config.max_position_embeddings
+        )
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py b/vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..8202ad67d07ed082822eedcc926e3fa85cf40234
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class SamplingParams:
+    """Per-request sampling settings.
+
+    Attributes:
+        temperature: Sampling temperature; negative values are rejected.
+        max_new_tokens: Presumably the cap on the number of generated tokens
+            (it is not read in this module).
+    """
+
+    temperature: float = 1.0
+    max_new_tokens: int = 256
+
+    def __post_init__(self):
+        """Reject a negative temperature with ``ValueError``."""
+        temperature_is_negative = self.temperature < 0
+        if temperature_is_negative:
+            raise ValueError("Temperature cannot be negative")
diff --git a/vllm/compactor-vllm/src/compactor_vllm/core/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py b/vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b37b77b3dd5abdfade3a9d15f9d6041bfb869a
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
@@ -0,0 +1,404 @@
+import atexit
+import inspect
+import logging
+from typing import Any, List, Optional, Union
+
+import torch.multiprocessing as mp
+from compactor_vllm.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from compactor_vllm.config.engine_config import LLMConfig
+from compactor_vllm.config.sampling_params import SamplingParams
+from compactor_vllm.core.model_runner import ModelRunner
+from compactor_vllm.models import MODEL_REGISTRY
+from compactor_vllm.utils.sequence import Sequence
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+PromptLike = Union[str, List[int]]
+
+
+def _infer_stop_token_ids(tokenizer, hf_config) -> list[int]:
+    """
+    Collect the token ids that should terminate generation.
+
+    ``tokenizer.eos_token_id`` may be a single id or a list/tuple of ids
+    (newer HF chat tokenizers); both forms are accepted. The engine compares
+    generated ids against the whole set, see :attr:`LLMConfig.eos_token_ids`.
+
+    For Qwen model types the assistant-turn end marker (a special token whose
+    text contains ``im_end``, e.g. ``<|im_end|>``) is added as well, looked up
+    in ``added_tokens_encoder`` and ``additional_special_tokens``. Only the
+    ``im_end`` substring is matched; looser matches such as "end" could tag
+    unrelated tokens. Negative ids, non-int ids and the unknown-token id are
+    ignored.
+
+    Raises:
+        ValueError: if no stop token id could be determined.
+    """
+    eos_raw = tokenizer.eos_token_id
+    if isinstance(eos_raw, (list, tuple)):
+        stop_ids: list[int] = [int(tok) for tok in eos_raw]
+    elif eos_raw is not None:
+        stop_ids = [int(eos_raw)]
+    else:
+        stop_ids = []
+    unknown_id = getattr(tokenizer, "unk_token_id", None)
+
+    def _register(candidate) -> None:
+        # Accept only real, non-negative ids that are new and not <unk>.
+        if not (isinstance(candidate, int) and candidate >= 0):
+            return
+        if unknown_id is not None and candidate == unknown_id:
+            return
+        if candidate in stop_ids:
+            return
+        stop_ids.append(candidate)
+
+    qwen_types = ("qwen2", "qwen3", "qwen2_moe", "qwen3_moe")
+    if getattr(hf_config, "model_type", None) in qwen_types:
+        encoder = getattr(tokenizer, "added_tokens_encoder", None)
+        if isinstance(encoder, dict):
+            for token_text, token_id in encoder.items():
+                if isinstance(token_text, str) and "im_end" in token_text:
+                    _register(int(token_id))
+        specials = getattr(tokenizer, "additional_special_tokens", []) or []
+        for special in specials:
+            if isinstance(special, str) and "im_end" in special:
+                try:
+                    resolved = tokenizer.convert_tokens_to_ids(special)
+                except (TypeError, ValueError, KeyError):
+                    continue
+                _register(resolved)
+
+    if stop_ids:
+        return stop_ids
+    raise ValueError(
+        "Could not infer stop token ids from the tokenizer; set "
+        "LLMConfig(eos_token_ids=[...]) explicitly."
+    )
+
+
+def _merge_apply_chat_template_kwargs(
+    tokenizer,
+    user_kwargs: Optional[dict[str, Any]],
+) -> dict[str, Any]:
+    """
+    Return a copy of ``user_kwargs`` with chat-template defaults filled in.
+
+    A default is added only when ``tokenizer.apply_chat_template`` declares
+    the parameter and the caller did not set it:
+
+    * ``add_generation_prompt=True`` -- Qwen3-style instruct models expect it
+      so the first generated token continues the assistant turn; without it
+      the output can repeat punctuation / template fragments.
+    * ``enable_thinking=False`` -- avoids the Qwen3 reasoning channel when the
+      tokenizer supports the switch.
+
+    If the signature cannot be inspected, the user kwargs are returned as-is.
+    """
+    merged = dict(user_kwargs or {})
+    try:
+        accepted = inspect.signature(tokenizer.apply_chat_template).parameters
+    except (TypeError, ValueError):
+        return merged
+    template_defaults = (
+        ("add_generation_prompt", True),
+        ("enable_thinking", False),
+    )
+    for name, value in template_defaults:
+        if name in accepted and name not in merged:
+            merged[name] = value
+    return merged
+
+
+def _runner_entry(config: LLMConfig, rank: int, evt):
+    """Process target for a non-zero rank: build a ModelRunner and run its loop.
+
+    Exceptions from construction or from the loop are logged rather than
+    re-raised. ``runner.exit()`` is called whenever the runner was constructed.
+    """
+    try:
+        runner = ModelRunner(config, rank, evt)
+    except Exception as e:
+        # Construction failed: there is no runner to shut down.
+        logging.exception(f"Rank {rank}: {repr(e)}")
+        return
+    try:
+        runner.loop()
+    except Exception as e:
+        logging.exception(f"Rank {rank}: {repr(e)}")
+    finally:
+        runner.exit()
+
+
+class LLMEngine:
+    """High-level engine coordinating model runners and scheduling.
+
+    Rank 0 runs in the calling process (``master_model_runner``). When
+    ``tensor_parallel_size > 1`` the remaining ranks are started as daemon
+    processes (``spawn`` context) running :func:`_runner_entry`.
+    """
+
+    def __init__(self, config: LLMConfig):
+        """Resolve the model snapshot, tokenizer and stop tokens, then start runners.
+
+        Raises:
+            ValueError: if the model type is not in ``MODEL_REGISTRY`` or no
+                stop token ids can be inferred from the tokenizer.
+        """
+        self.config = config
+        if self.config.hf_config.model_type not in MODEL_REGISTRY:
+            raise ValueError(f"Unknown model {self.config.model}")
+        if config.path is None:
+            from huggingface_hub import snapshot_download
+
+            self.config.path = snapshot_download(
+                repo_id=config.model, local_files_only=True
+            )
+            logger.info(f"Using {self.config.model} snapshot @ {self.config.path}")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model, use_fast=True)
+
+        # Normalise the stop tokens so that `eos_token_ids` is a sorted,
+        # de-duplicated list and `eos` is one of its members.
+        if self.config.eos_token_ids is None:
+            if self.config.eos != -1:
+                self.config.eos_token_ids = [int(self.config.eos)]
+            else:
+                self.config.eos_token_ids = _infer_stop_token_ids(
+                    self.tokenizer, self.config.hf_config
+                )
+        else:
+            self.config.eos_token_ids = [int(x) for x in self.config.eos_token_ids]
+        self.config.eos_token_ids = sorted(set(self.config.eos_token_ids))
+        if self.config.eos == -1:
+            self.config.eos = int(self.config.eos_token_ids[0])
+        else:
+            self.config.eos = int(self.config.eos)
+            if self.config.eos not in self.config.eos_token_ids:
+                self.config.eos_token_ids = sorted(
+                    self.config.eos_token_ids + [self.config.eos]
+                )
+
+        self.ps = []
+        world_size = int(self.config.tensor_parallel_size)
+        self.events = []
+        if world_size > 1:
+            ctx = mp.get_context("spawn")
+            for r in range(1, world_size):
+                event = ctx.Event()
+                p = ctx.Process(
+                    target=_runner_entry,
+                    args=(self.config, r, event),
+                    daemon=True,
+                )
+                p.start()
+                self.ps.append(p)
+                self.events.append(event)
+
+        self.master_model_runner = ModelRunner(
+            self.config, rank=0, peer_events=self.events
+        )
+        atexit.register(self.exit)
+
+    def exit(self):
+        """Shut down the master runner and worker processes; safe to call twice."""
+        if getattr(self, "_exited", False):
+            return
+        self._exited = True
+        runner = getattr(self, "master_model_runner", None)
+        if runner is not None:
+            try:
+                runner.exit()
+            except Exception:
+                # Best effort: still terminate the workers below.
+                logger.exception("Failed to exit master ModelRunner cleanly")
+        for p in self.ps:
+            if p.is_alive():
+                p.terminate()
+            p.join(timeout=1.0)
+        if hasattr(self, "events"):
+            self.events.clear()
+
+    def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
+        """
+        Turn a raw prompt into token IDs.
+
+        Strings are tokenized with ``tokenizer_kwargs``; anything else is
+        treated as already-tokenized ids and copied into a list.
+        """
+        if isinstance(prompt, str):
+            return self.tokenizer(prompt, **tokenizer_kwargs)["input_ids"]
+        else:
+            return list(prompt)
+
+    def detokenize_prompt(
+        self, sequences: List[Sequence], **detokenizer_kwargs
+    ) -> List[str]:
+        """
+        Turn completed Sequences into strings.
+
+        ``skip_special_tokens=True`` is the default and can be overridden
+        through ``detokenizer_kwargs``.
+        """
+        defaults: dict[str, Any] = {"skip_special_tokens": True}
+        merged = {**defaults, **detokenizer_kwargs}
+        return self.tokenizer.batch_decode(
+            [s.completion_token_ids for s in sequences], **merged
+        )
+
+    def _build_sequences(
+        self,
+        prompts: List[PromptLike] | PromptLike,
+        sampling_params: SamplingParams | List[SamplingParams],
+        per_sequence_compression_params: Optional[
+            SequenceCompressionParams | List[SequenceCompressionParams]
+        ] = None,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+    ) -> List[Sequence]:
+        """
+        Build Sequence objects from prompts, sampling params, and optional
+        per-sequence compression parameters.
+
+        The caller's ``SequenceCompressionParams`` objects are never mutated:
+        when a prompt is too short to compress, the ratio override is applied
+        to a copy attached to that sequence only.
+        """
+        import copy  # local import keeps this block self-contained
+
+        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
+
+        if not isinstance(prompts, list):
+            prompts = [prompts]
+
+        if isinstance(sampling_params, SamplingParams):
+            sampling_params_list: List[SamplingParams] = [sampling_params] * len(
+                prompts
+            )
+        else:
+            sampling_params_list = sampling_params
+            assert len(sampling_params_list) == len(prompts), (
+                "sampling_params list must match prompts length"
+            )
+        if per_sequence_compression_params is None:
+            compression_params_list: List[SequenceCompressionParams] = [
+                SequenceCompressionParams(1.0) for _ in prompts
+            ]
+        elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
+            compression_params_list = [per_sequence_compression_params] * len(prompts)
+        else:
+            # list-like
+            assert len(per_sequence_compression_params) == len(prompts), (
+                "per_sequence_compression_params list must match prompts length"
+            )
+            compression_params_list = list(per_sequence_compression_params)
+
+        seqs: List[Sequence] = []
+        for prompt, sparams, cparams in zip(
+            prompts, sampling_params_list, compression_params_list
+        ):
+            token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
+            protected = cparams.protected_first_tokens + cparams.protected_last_tokens
+            if protected >= len(token_ids):
+                # The protected head and tail already cover the whole prompt,
+                # so fall back to ratio 1.0 (the value used when no params are
+                # given). Override on a copy: the same object may be shared by
+                # every prompt in the batch and belongs to the caller.
+                cparams = copy.copy(cparams)
+                cparams.compression_ratio = 1.0
+            seqs.append(
+                Sequence(
+                    prompt_token_ids=token_ids,
+                    sampling_params=sparams,
+                    compression_params=cparams,
+                )
+            )
+        return seqs
+
+    def generate(
+        self,
+        prompts: List[PromptLike] | PromptLike,
+        sampling_params: SamplingParams | List[SamplingParams],
+        batch_compression_params: BatchCompressionParams,
+        *,
+        per_sequence_compression_params: Union[
+            List[SequenceCompressionParams], SequenceCompressionParams
+        ] = None,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        detokenizer_kwargs: Optional[dict[str, Any]] = None,
+        return_sequences: bool = False,
+    ) -> List[str] | tuple[List[str], List[Sequence]]:
+        """
+        Accept prompts, run generation, and return the decoded completions.
+        Args:
+            :param prompts:
+                Single prompt or list of prompts, each either a raw text prompt,
+                or pre-tokenized input IDs.
+            :param sampling_params:
+                A single SamplingParams for all prompts in this batch or a list of
+                SamplingParams with the same length as ``prompts``.
+            :param batch_compression_params:
+                Compression settings for this batch.
+            :param per_sequence_compression_params:
+                Per-sequence compression parameters, including the compression
+                ratio to be applied and the size of the protected regions of the
+                sequence (how many start tokens and end tokens to keep uncompressed).
+                If a SequenceCompressionParams instance, the same params will be
+                applied to all sequences in this batch; if a list is provided,
+                each SequenceCompressionParams will be attached to the corresponding
+                prompt in the batch.
+            :param tokenizer_kwargs:
+                Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
+                string prompts.
+            :param detokenizer_kwargs:
+                Passed through to `tokenizer.batch_decode`.
+            :param return_sequences:
+                Whether to return sequence objects or not
+        Returns:
+            :return List[str] or tuple[List[str], List[Sequence]]:
+                One decoded completion per input prompt. When
+                ``return_sequences`` is true, also the Sequence objects with
+                `completion_token_ids` filled in after generation.
+        """
+        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
+        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
+        seqs = self._build_sequences(
+            prompts,
+            sampling_params=sampling_params,
+            per_sequence_compression_params=per_sequence_compression_params,
+            tokenizer_kwargs=tokenizer_kwargs,
+        )
+        self.master_model_runner.generate(seqs, batch_compression_params)
+        output_strings = self.detokenize_prompt(seqs, **detokenizer_kwargs)
+        if return_sequences:
+            return output_strings, seqs
+        return output_strings
+
+    def generate_chat(
+        self,
+        messages_batch: List[List[dict]],
+        sampling_params: SamplingParams | List[SamplingParams],
+        batch_compression_params: BatchCompressionParams,
+        per_sequence_compression_params: Union[
+            SequenceCompressionParams, List[SequenceCompressionParams]
+        ],
+        *,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        detokenizer_kwargs: Optional[dict[str, Any]] = None,
+        return_sequences: bool = False,
+    ) -> List[str] | tuple[List[str], List[Sequence]]:
+        """
+        Convenience API for chat-style prompts using HF `apply_chat_template`.
+        Args:
+            :param messages_batch:
+                List of conversations, where each conversation is a list of
+                message dicts like:
+                    {"role": "system" | "user" | "assistant", "content": str}
+            :param sampling_params:
+                A single SamplingParams for all prompts in this batch or a list of
+                SamplingParams with the same length as ``prompts``.
+            :param batch_compression_params:
+                Batch Level compression settings. Can set compression_method.
+            :param per_sequence_compression_params:
+                Per-sequence compression parameters, including the compression
+                ratio to be applied and the size of the protected regions of the
+                sequence (how many start tokens and end tokens to keep uncompressed).
+                If a SequenceCompressionParams instance, the same params will be
+                applied to all sequences in this batch; if a list is provided,
+                each SequenceCompressionParams will be attached to the corresponding
+                conversation in the batch.
+            :param tokenizer_kwargs:
+                Passed through to `tokenizer.apply_chat_template`.
+            :param detokenizer_kwargs:
+                Passed through to `tokenizer.batch_decode`.
+            :param return_sequences:
+                Whether to return sequence objects or not
+        Returns:
+            :return List[str] or tuple[List[str], List[Sequence]]:
+                One string per conversation.
+        """
+        prompts_token_ids: List[List[int]] = []
+        tokenizer_kwargs = _merge_apply_chat_template_kwargs(
+            self.tokenizer, tokenizer_kwargs
+        )
+        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
+        for messages in messages_batch:
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                **tokenizer_kwargs,
+            )
+            # Tensor / array results are converted to plain Python lists.
+            if hasattr(input_ids, "tolist"):
+                input_ids = input_ids.tolist()
+            prompts_token_ids.append(input_ids)
+
+        return self.generate(
+            prompts_token_ids,
+            sampling_params=sampling_params,
+            batch_compression_params=batch_compression_params,
+            per_sequence_compression_params=per_sequence_compression_params,
+            tokenizer_kwargs=tokenizer_kwargs,
+            detokenizer_kwargs=detokenizer_kwargs,
+            return_sequences=return_sequences,
+        )
+
+    def generate_from_sequences(
+        self,
+        seqs: List[Sequence],
+        batch_compression_params: BatchCompressionParams,
+    ) -> List[Sequence]:
+        """
+        Run generation on pre-built Sequence objects.
+        Args:
+            :param seqs:
+                List of Sequence instances
+            :param batch_compression_params:
+                Compression settings.
+
+        Returns:
+            :return List[Sequence]:
+                Same list, mutated in-place with completions.
+        """
+        self.master_model_runner.generate(seqs, batch_compression_params)
+        return seqs
+
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py b/vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..474e1ffac98dec05156d4aee6a4d51322abb9c3c
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py
@@ -0,0 +1,182 @@
+import logging
+from typing import Iterable, List, Optional
+
+import torch
+import torch.distributed as dist
+from compactor_vllm.config.engine_config import LLMConfig
+from compactor_vllm.kv_cache.page_table import KVAllocationStatus, PagedKVCache
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class KVCacheManager:
+    """
+    Per-rank owner of the paged KV cache.
+
+    Sizes the cache from the GPU memory that is currently free, attaches the
+    per-layer cache slices to the model's attention modules, and maps engine
+    sequence IDs onto batch slots of the underlying :class:`PagedKVCache`.
+    """
+
+    def __init__(self, rank: int, config: LLMConfig):
+        """
+        Args:
+            :param rank:
+                Rank of this process; the cache is created on ``cuda:{rank}``.
+            :param config:
+                Engine configuration (page size, memory fraction, limits and
+                the Hugging Face model config).
+        """
+        super().__init__()
+        hf_config = config.hf_config
+        self.rank = rank
+        self.gpu_frac = config.gpu_memory_utilization
+        self.page_size = config.kvcache_page_size
+        self.world_size = config.tensor_parallel_size
+        self.max_num_batches = config.max_num_seqs
+        self.max_model_len = config.max_model_len
+        self.num_layers = hf_config.num_hidden_layers
+        self.model_dtype = hf_config.torch_dtype
+        self.head_dim = getattr(hf_config, "head_dim", None)
+        if self.head_dim is None:
+            # Many Hugging Face configs omit ``head_dim`` (or set it to None);
+            # derive it the conventional way when the inputs are available.
+            hidden_size = getattr(hf_config, "hidden_size", None)
+            num_attention_heads = getattr(hf_config, "num_attention_heads", None)
+            if hidden_size is not None and num_attention_heads:
+                self.head_dim = hidden_size // num_attention_heads
+        self.max_pages_per_batch = (
+            self.max_model_len + self.page_size - 1
+        ) // self.page_size
+        # Validate divisibility *before* the floor division so a bad TP size
+        # can never yield a silently truncated head count.
+        dist_world_size = dist.get_world_size()
+        assert hf_config.num_key_value_heads % dist_world_size == 0, (
+            "world size needs to divide num_kv_heads"
+        )
+        self.num_kv_heads = hf_config.num_key_value_heads // dist_world_size
+
+        # Populated by init_cache() / estimate_max_batched_tokens().
+        self.num_pages = None
+        self.paged_cache: Optional[PagedKVCache] = None
+        self.max_batched_tokens = None
+
+        # engine sequence id -> batch slot index in the paged cache
+        self.seq_id_to_batch = {}
+
+    def allocate_sequences(
+        self, seq_ids: List[int], max_positions: List[int]
+    ) -> tuple[bool, Optional[torch.Tensor]]:
+        """
+        Reserve cache space for each sequence, creating batch slots as needed.
+
+        Args:
+            :param seq_ids:
+                Engine sequence IDs.
+            :param max_positions:
+                Token count to reserve for the sequence at the same index.
+
+        Returns:
+            :return tuple[bool, Optional[torch.Tensor]]:
+                ``(True, batch_mapping)`` on success, where ``batch_mapping``
+                is an int32 CUDA tensor of batch slots aligned with
+                ``seq_ids``; ``(False, None)`` as soon as a slot or page
+                reservation fails.
+        """
+        batch_mapping = []
+        for seq_id, len_to_alloc in zip(seq_ids, max_positions):
+            if seq_id not in self.seq_id_to_batch:
+                batch_id = self.paged_cache.new_batch()
+                if batch_id is None:
+                    logger.warning("Failed to allocate batch!")
+                    return False, None
+                self.seq_id_to_batch[seq_id] = int(batch_id)
+            batch_mapping.append(self.seq_id_to_batch[seq_id])
+            if (
+                alloc_status := self.paged_cache.reserve_tokens(
+                    self.seq_id_to_batch[seq_id], len_to_alloc
+                )
+            ) != KVAllocationStatus.SUCCESS:
+                # NOTE(review): slots registered earlier in this call stay
+                # registered on failure; every visible caller asserts success.
+                logger.warning(f"Failed to allocate pages ({alloc_status})!")
+                return False, None
+        batch_mapping = torch.as_tensor(batch_mapping, dtype=torch.int32, device="cuda")
+        return True, batch_mapping
+
+    def free_sequences(self, seq_ids: Iterable[int]):
+        """
+        Release the batch slot of every known sequence in ``seq_ids``.
+
+        IDs that were never allocated (or were already freed) are skipped
+        instead of handing ``None`` to the paged cache.
+        """
+        for seq_id in seq_ids:
+            global_batch_id = self.seq_id_to_batch.pop(seq_id, None)
+            if global_batch_id is None:
+                continue
+            self.paged_cache.free_batch(global_batch_id)
+
+    def init_cache(self, model: nn.Module):
+        """Size and create the paged cache, then attach it to ``model``'s layers."""
+        self.num_pages = self.get_num_pages(self.gpu_frac, self.max_pages_per_batch)
+        self.paged_cache = PagedKVCache(
+            num_layers=self.num_layers,
+            H_kv=self.num_kv_heads,
+            head_dim=self.head_dim,
+            page_size=self.page_size,
+            num_pages=int(self.num_pages),
+            max_num_batches=self.max_num_batches,
+            device=f"cuda:{self.rank}",
+            dtype=self.model_dtype,
+            max_logical_pages_per_head=int(self.max_pages_per_batch),
+        )
+        self._assign_cache_to_layers(model)
+
+    def _assign_cache_to_layers(self, model) -> None:
+        """Give each ``layer.self_attn.attn`` module its slices of the cache."""
+        for layer_index, layer in enumerate(model.model.layers):
+            attn = layer.self_attn.attn
+            k, v, pt, bh = self.paged_cache.layer_slices(layer_index)
+            attn.k_cache = k
+            attn.v_cache = v
+            attn.page_table = pt
+            attn.bh_seq_lens = bh
+            attn.page_size = self.page_size
+
+    def get_num_pages(self, frac: float, n_logical_pages_max: int):
+        """
+        Compute how many physical pages fit into the KV-cache memory budget.
+
+        Args:
+            :param frac:
+                Fraction of total device memory the engine may use.
+            :param n_logical_pages_max:
+                Maximum logical pages per (batch, kv-head), used to size the
+                page table.
+
+        Returns:
+            Number of pages (at least 1).
+
+        Raises:
+            RuntimeError: if the budget is non-positive, is consumed entirely
+                by the page table, or ``head_dim`` is unknown.
+        """
+        if self.head_dim is None:
+            raise RuntimeError(
+                "head_dim could not be determined from the model config; "
+                "cannot size the KV cache."
+            )
+        free, total = torch.cuda.mem_get_info()
+        used = total - free
+        stats = torch.cuda.memory_stats()
+        peak = int(stats["allocated_bytes.all.peak"])
+        current = int(stats["allocated_bytes.all.current"])
+        # Budget = 90% of the allowed share, minus memory in use, minus the
+        # headroom between the allocator's peak and its current usage.
+        bytes_for_kv_budget = int(total * frac * 0.9) - used - peak + current
+
+        if bytes_for_kv_budget <= 0:
+            raise RuntimeError(
+                "Insufficient memory for KV cache. "
+                f"Try increasing gpu_memory_utilization (currently {frac:.2f})."
+            )
+        # page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
+        int32_sz = torch.empty((), dtype=torch.int32).element_size()  # 4
+        page_table_bytes_per_layer = (
+            self.max_num_batches
+            * self.num_kv_heads
+            * n_logical_pages_max
+            * int32_sz  # page_table
+            + self.max_num_batches * self.num_kv_heads * int32_sz
+        )
+        total_page_table_bytes = self.num_layers * page_table_bytes_per_layer
+        kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
+        if kv_bytes_net <= 0:
+            raise RuntimeError(
+                "page-table footprint exceeds KV cache budget. "
+                f"reduce max_num_seqs ({self.max_num_batches}) "
+                f"or increase gpu_memory_utilization (currently {frac:.2f})."
+            )
+        dtype_sz = torch.empty((), dtype=self.model_dtype).element_size()
+        # One page holds page_size tokens of K and of V (factor 2) in every layer.
+        bytes_per_page_across_layers = self.num_layers * (
+            2 * self.page_size * self.head_dim * dtype_sz
+        )
+        return max(1, kv_bytes_net // bytes_per_page_across_layers)
+
+    def estimate_max_batched_tokens(
+        self,
+        warmup_tokens: int,
+        bytes_used_before_warmup: int,
+        bytes_peak_after_warmup: int,
+    ) -> int:
+        """
+        Estimate the max total number of tokens that can be processed concurrently
+        without OOM.
+
+        The result is the smaller of an activation-memory bound (derived from
+        the warmup measurements) and a cache-capacity bound, both rounded down
+        to a multiple of the page size. It is also stored on
+        ``self.max_batched_tokens``.
+        """
+        assert warmup_tokens > 0, "warmup_tokens must be > 0"
+        # activation bytes per token
+        warmup_delta = max(
+            0, int(bytes_peak_after_warmup) - int(bytes_used_before_warmup)
+        )
+        bytes_per_token = max(1, (warmup_delta + warmup_tokens - 1) // warmup_tokens)
+
+        free, total = torch.cuda.mem_get_info()
+        target = int(total * self.gpu_frac)
+        used_now = int(total - free)
+        # reserve headroom equal to the gap between peak and current allocations seen so far
+        stats = torch.cuda.memory_stats()
+        peak_cur = int(stats.get("allocated_bytes.all.peak", 0))
+        cur_now = int(stats.get("allocated_bytes.all.current", 0))
+        cushion = max(0, peak_cur - cur_now)
+
+        activation_budget = int(max(0, target - used_now - cushion) * 0.95)
+        max_tokens_per_batch = activation_budget // bytes_per_token
+        max_tokens_in_cache = (self.num_pages * self.page_size) // self.num_kv_heads
+        # round to lower multiple of page size
+        max_tokens_per_batch = (max_tokens_per_batch // self.page_size) * self.page_size
+        max_tokens_in_cache = (max_tokens_in_cache // self.page_size) * self.page_size
+        self.max_batched_tokens = min(max_tokens_in_cache, max_tokens_per_batch)
+        return self.max_batched_tokens
+
+    @property
+    def num_free_batches(self) -> int:
+        """Number of batch slots currently free in the paged cache."""
+        return len(self.paged_cache.free_batches)
+
+    @property
+    def num_free_pages(self) -> int:
+        """Smallest free-page count over the cache's ``free_pages`` lists."""
+        return min(len(fp) for fp in self.paged_cache.free_pages)
+
+    def reclaim_pages(
+        self,
+        seq_ids_to_reclaim: Iterable[int],
+        future_reserved_buffer: List[int] | torch.Tensor,
+    ) -> int:
+        """
+        Ask the paged cache to reclaim pages for each listed sequence.
+
+        Args:
+            :param seq_ids_to_reclaim:
+                Sequence IDs that already own a batch slot.
+            :param future_reserved_buffer:
+                Per-sequence value (indexed in step with the IDs) passed
+                through to ``PagedKVCache.reclaim_pages``.
+
+        Returns:
+            Sum of the values returned by the paged cache (an approximate
+            byte count, going by the accumulator's name).
+        """
+        approximate_bytes_freed = 0
+        for i, seq_id in enumerate(seq_ids_to_reclaim):
+            batch_idx = self.seq_id_to_batch[seq_id]
+            approximate_bytes_freed += self.paged_cache.reclaim_pages(
+                batch_idx, future_reserved_buffer[i]
+            )
+        return approximate_bytes_freed
diff --git a/vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py b/vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0633ea3125e8f381d6f2ab7dfcd3d2dd6c09e36
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
@@ -0,0 +1,584 @@
+import atexit
+import logging
+import inspect
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+from compactor_vllm.attention.sparse_decode_kernel import num_splits_heuristic
+from compactor_vllm.compression.compression_config import BatchCompressionParams
+from compactor_vllm.config.constants import RESERVED_BATCH
+from compactor_vllm.config.engine_config import AttentionBackend, LLMConfig
+from compactor_vllm.core.memory_manager import KVCacheManager
+from compactor_vllm.core.scheduler import Scheduler
+from compactor_vllm.layers.sampler import Sampler
+from compactor_vllm.models import MODEL_REGISTRY
+from compactor_vllm.utils.arguments import (
+ DecodeBatchArguments,
+ DecodeBatchOutput,
+ PackedTensorArguments,
+ PrefillBatchArguments,
+)
+from compactor_vllm.utils.context import CompressionContext, reset_context, set_context
+from compactor_vllm.utils.sequence import Sequence
+from torch.multiprocessing import Event
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+class ModelRunner:
+ """Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        rank: int,
+        batch_ready: Optional[Event] = None,
+        peer_events: Optional[List[Event]] = None,
+    ):
+        """
+        Initialise the process group, load the model, size the KV cache and
+        (unless ``enforce_eager``) capture decode CUDA graphs.
+
+        Args:
+            :param config:
+                Engine configuration; ``eos_token_ids`` must be non-empty.
+            :param rank:
+                Rank of this process, also used as the CUDA device index.
+            :param batch_ready:
+                Event polled by :meth:`loop` on this rank.
+            :param peer_events:
+                Events the master sets in :meth:`generate` and clears when
+                releasing peers. ``None`` is treated as an empty list.
+        """
+        self.rank = rank
+        self.config = config
+        _dev = torch.device(f"cuda:{rank}")
+        assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
+            "LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
+        )
+        self._stop_token_ids = torch.tensor(
+            config.eos_token_ids, dtype=torch.int64, device=_dev
+        )
+        hf_config = config.hf_config
+        self.enforce_eager = config.enforce_eager
+        self.world_size = config.tensor_parallel_size
+        self.leverage_sketch_size = config.leverage_sketch_size
+        self.show_progress_bar = config.show_progress_bar
+        self.max_num_batches = config.max_num_seqs
+        self.max_model_len = config.max_model_len
+        self.num_layers = hf_config.num_hidden_layers
+        self.model_dtype = hf_config.torch_dtype
+        self.head_dim = getattr(hf_config, "head_dim", None)
+
+        # Only pass ``device_id`` when this torch version's
+        # init_process_group accepts it.
+        init_kwargs = {}
+        if "device_id" in inspect.signature(dist.init_process_group).parameters:
+            init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
+        dist.init_process_group(
+            "nccl",
+            f"tcp://localhost:{config.nccl_port}",
+            world_size=self.world_size,
+            rank=rank,
+            **init_kwargs,
+        )
+        torch.cuda.set_device(rank)
+        # Build the model with the checkpoint dtype on the GPU; the previous
+        # defaults are restored below once the cache exists.
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(hf_config.torch_dtype)
+        torch.set_default_device("cuda")
+        model_type = hf_config.model_type
+        self.model = MODEL_REGISTRY[model_type](hf_config)
+        self.model.load_model(
+            config.path, use_tqdm=self.is_master and self.show_progress_bar
+        )
+        self.sampler = Sampler()
+
+        # Measure allocator usage around a flash-attention warmup; the two
+        # readings feed estimate_max_batched_tokens() below. This warmup runs
+        # before self.kv_manager exists.
+        pre_warmup_mem = torch.cuda.memory_stats().get("allocated_bytes.all.current", 0)
+        self.warmup(
+            num_warmup_tokens=self.max_model_len,
+            attention_backend=AttentionBackend.FLASH_ATTENTION,
+        )
+        post_warmup_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+
+        self.kv_manager = KVCacheManager(rank, config)
+        self.kv_manager.init_cache(self.model)
+
+        self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
+        torch.set_default_device("cpu")
+        torch.set_default_dtype(default_dtype)
+
+        self.batch_ready = batch_ready
+        self.peer_events = peer_events if peer_events is not None else []
+        # batch size -> {captured seq len -> graph bundle}
+        self.captured_graphs = {}
+        # batch size -> smallest captured seq len
+        self.min_captured_len = {}
+        self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
+            self.max_model_len, pre_warmup_mem, post_warmup_peak
+        )
+        if self.is_master:
+            logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")
+        if self.config.attention_backend == AttentionBackend.COMPACTOR_TRITON:
+            self.warmup(
+                num_warmup_tokens=self.max_model_len,
+                attention_backend=AttentionBackend.COMPACTOR_TRITON,
+            )
+
+        if not self.enforce_eager:
+            # Capture one graph family per power-of-two batch size.
+            capture_batch_sizes = [
+                1 << i for i in range(self.max_num_batches.bit_length())
+            ]
+            # get_cuda_graph() rounds a decode batch up to the next captured
+            # size; when max_num_batches is not a power of two, batches above
+            # the largest power would find no graph, so capture the max too.
+            if capture_batch_sizes and capture_batch_sizes[-1] < self.max_num_batches:
+                capture_batch_sizes.append(self.max_num_batches)
+            batch_size_iter = (
+                tqdm(capture_batch_sizes, desc="Capturing CUDA Graphs")
+                if self.is_master and self.show_progress_bar
+                else capture_batch_sizes
+            )
+            for batch_size in batch_size_iter:
+                for seq_len in [1024, 4096, 8192, 16384]:
+                    self.capture_cudagraph(batch_size, seq_len)
+
+        self.packed_args = PackedTensorArguments(
+            rank=self.rank,
+            max_batched_tokens=self.max_batched_tokens,
+            config=self.config,
+        )
+        atexit.register(self.exit)
+
+ @torch.inference_mode()
+ def warmup(self, num_warmup_tokens: int, attention_backend: AttentionBackend):
+ """
+ Run dummy prefill passes over a single synthetic sequence.
+
+ Args:
+ :param num_warmup_tokens:
+ Length of the synthetic sequence.
+ :param attention_backend:
+ Backend placed in the context for the passes. Only
+ COMPACTOR_TRITON touches ``self.kv_manager``.
+ """
+ # Only rank 0 logs which backend is being warmed up.
+ if self.rank == 0:
+ if attention_backend == AttentionBackend.COMPACTOR_TRITON:
+ backend_name = "Compactor Triton"
+ else:
+ backend_name = "Flash"
+ logger.info(f"Warming up with {backend_name} Attention Backend")
+ device = torch.device(f"cuda:{self.rank}")
+ # NOTE(review): __init__ validates ``config.eos_token_ids``; confirm that
+ # LLMConfig also exposes a scalar ``eos`` attribute.
+ input_ids = torch.tensor(
+ [self.config.eos] * num_warmup_tokens, device=device, dtype=torch.int64
+ )
+ positions = torch.arange(num_warmup_tokens, device=device, dtype=torch.int64)
+ # One sequence covering all warmup tokens: boundaries are [0, N].
+ cu_seqlens_q = torch.tensor(
+ [0, num_warmup_tokens], device=device, dtype=torch.int32
+ )
+ cu_seqlens_k = torch.tensor(
+ [0, num_warmup_tokens], device=device, dtype=torch.int32
+ )
+ # The Triton backend gets real cache space, reserved under the
+ # placeholder sequence id -1; the other path runs without a batch mapping.
+ if attention_backend == AttentionBackend.COMPACTOR_TRITON:
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ [-1], [num_warmup_tokens]
+ )
+ assert success
+ else:
+ batch_mapping = None
+ set_context(
+ is_prefill=True,
+ do_compression=False,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=num_warmup_tokens,
+ max_seqlen_k=num_warmup_tokens,
+ batch_mapping=batch_mapping,
+ attention_backend=attention_backend,
+ )
+ # Two passes; peak-memory statistics are reset before a forward pass so
+ # the caller can read the peak afterwards.
+ for _ in range(2):
+ torch.cuda.reset_peak_memory_stats()
+ self.model.compute_logits(self.model(input_ids, positions))
+ dist.barrier()
+ # Zero the recorded lengths of the warmup batch slot(s).
+ if attention_backend == AttentionBackend.COMPACTOR_TRITON:
+ self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
+ 1, batch_mapping.to(torch.long), 0
+ )
+ reset_context()
+ # Give the placeholder sequence's slot back.
+ if attention_backend == AttentionBackend.COMPACTOR_TRITON:
+ self.kv_manager.free_sequences([-1])
+
+    def exit(self):
+        """
+        Tear the runner down once: drop captured CUDA graphs, then destroy
+        the process group if one is initialised. Repeated calls are no-ops.
+        """
+        if not getattr(self, "_exited", False):
+            self._exited = True
+            try:
+                if hasattr(self, "captured_graphs"):
+                    self.captured_graphs.clear()
+            finally:
+                # Runs even if clearing the graphs raised.
+                if dist.is_initialized():
+                    dist.destroy_process_group()
+
+    def loop(self):
+        """
+        Poll ``batch_ready`` forever (1 second timeout per wait) and run the
+        peer batch loop each time the wait reports the event as set.
+        """
+        while True:
+            signalled = self.batch_ready.wait(1.0)
+            if not signalled:
+                continue
+            self._process_batches_peer()
+
+ @torch.inference_mode()
+ def run_prefill(
+ self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
+ ):
+ """
+ Run one prefill forward pass and return the logits.
+
+ Args:
+ :param prefill_args:
+ Packed prefill inputs; must describe at least one sequence
+ and one token (``B > 0`` and ``N > 0``).
+ :param batch_mapping:
+ Batch slots of the sequences, as returned by
+ ``KVCacheManager.allocate_sequences``.
+
+ Returns:
+ Output of ``model.compute_logits`` for this batch.
+ """
+ assert prefill_args.B > 0 and prefill_args.N > 0
+ # Largest recorded length over the selected batch slots; ``.item()``
+ # forces a host sync.
+ max_bh_len = (
+ self.kv_manager.paged_cache.bh_seq_lens.index_select(1, index=batch_mapping)
+ .max()
+ .item()
+ )
+ compression_context = CompressionContext(
+ compression_method=prefill_args.compression_method,
+ compression_chunk_size=prefill_args.compression_chunk_size,
+ batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
+ max_tokens_to_retain=prefill_args.max_tokens_to_retain,
+ context_lens=prefill_args.context_lens.tolist(),
+ PHI=prefill_args.PHI,
+ sketch_dimension=self.leverage_sketch_size,
+ protected_first_tokens=prefill_args.protected_first,
+ protected_last_tokens=prefill_args.protected_last,
+ compression_ratio=prefill_args.compression_ratio,
+ )
+ # The context must be in place before the forward pass and is reset
+ # right after it.
+ set_context(
+ is_prefill=True,
+ do_compression=prefill_args.do_compression,
+ cu_seqlens_q=prefill_args.cu_seqlens_q,
+ cu_seqlens_k=prefill_args.cu_seqlens_k,
+ max_seqlen_q=prefill_args.max_seqlen_q,
+ max_seqlen_k=prefill_args.max_seqlen_k,
+ batch_mapping=batch_mapping,
+ max_bh_len=max_bh_len,
+ compression_context=compression_context,
+ STORE_STREAM=self.store_stream,
+ attention_backend=self.config.attention_backend,
+ )
+ logits = self.model.compute_logits(
+ self.model(prefill_args.input_ids, prefill_args.positions)
+ )
+ reset_context()
+ return logits
+
+    def maybe_broadcast(self, tensor: torch.Tensor):
+        """
+        Broadcast ``tensor`` from rank 0 when running with more than one
+        rank; do nothing (and return ``None``) in the single-rank case.
+        """
+        if self.world_size <= 1:
+            return None
+        return dist.broadcast(tensor, src=0)
+
+    def maybe_release_peers(self, do_release=False):
+        """
+        Barrier across ranks (multi-rank runs only).
+
+        On the master, when ``do_release`` is true, every peer event is
+        cleared before the barrier. With a single rank nothing happens.
+        """
+        if self.world_size <= 1:
+            return
+        if self.is_master and do_release:
+            for peer_event in self.peer_events:
+                peer_event.clear()
+        # Master and peers meet at the same barrier.
+        dist.barrier()
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        all_sequences: List[Sequence],
+        batch_compression_params: Optional[BatchCompressionParams] = None,
+    ):
+        """
+        Master-only entry point: signal the peers, then process every sequence.
+
+        Args:
+            :param all_sequences:
+                Sequences to run; handed to the master batch loop.
+            :param batch_compression_params:
+                Compression settings. A default ``BatchCompressionParams()``
+                is used when omitted.
+        """
+        assert self.is_master, "generate can only be called on the master process"
+        # Set each peer's event before starting the master loop.
+        for peer_start_event in self.peer_events:
+            peer_start_event.set()
+        params = (
+            batch_compression_params
+            if batch_compression_params is not None
+            else BatchCompressionParams()
+        )
+        self._process_batches_master(all_sequences, params)
+
+    @property
+    def is_master(self):
+        """Whether this runner is the master, i.e. has rank 0."""
+        master_rank = 0
+        return self.rank == master_rank
+
+ @torch.inference_mode()
+ def _process_batches_master(
+ self,
+ all_sequences: List[Sequence],
+ batch_compression_params: BatchCompressionParams,
+ ):
+ """
+ Master scheduling loop: alternate prefill and decode until the
+ scheduler reports that every sequence is finished.
+
+ Args:
+ :param all_sequences:
+ Sequences to run; tokens are appended to them through the
+ scheduler.
+ :param batch_compression_params:
+ Compression settings used when building prefill arguments.
+ """
+ assert self.is_master
+ compression_details = f"Applying Compression Method: {batch_compression_params.compression_method}"
+ # Only log the method when at least one sequence actually compresses.
+ if any(seq.compression_params.compression_ratio < 1.0 for seq in all_sequences):
+ logger.info(compression_details)
+ scheduler = Scheduler(
+ all_sequences=all_sequences,
+ kv_manager=self.kv_manager,
+ use_tqdm=self.show_progress_bar,
+ )
+ decode_batch = DecodeBatchArguments()
+ # [0] = run_decode flag, [1] = desired occupancy; broadcast to peers.
+ decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
+ while not scheduler.is_finished():
+ # NOTE(review): if no pending sequence fits, this list is empty and
+ # run_prefill's ``B > 0`` assert would presumably fire — confirm how
+ # build_prefill_args handles an empty batch.
+ sequences = scheduler.get_prefill_batch()
+ seq_ids_cpu = [seq.seq_id for seq in sequences]
+ scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
+ temps = torch.tensor(
+ [s.sampling_params.temperature for s in sequences],
+ dtype=torch.float32,
+ pin_memory=True,
+ ).cuda(non_blocking=True)
+ prefill_arguments = self.packed_args.build_prefill_args(
+ sequences, batch_compression_params=batch_compression_params
+ )
+ # Reserve room for the prompt plus everything that may be generated.
+ max_ctx_lens = (
+ prefill_arguments.max_new_tokens + prefill_arguments.context_lens
+ )
+
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ seq_ids_cpu, max_ctx_lens.tolist()
+ )
+ assert success, "failed to allocate pages for sequences"
+
+ logits = self.run_prefill(prefill_arguments, batch_mapping)
+ # Must match prefill `positions` dtype (int64). `context_lens` is int32
+ # from the packed buffer; using int32 here breaks RoPE indexing
+ # (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
+ positions = prefill_arguments.context_lens.to(dtype=torch.int64)
+ token_ids = self.sampler(logits, temps)
+ # Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
+ # reads bh_seq_lens on the default stream and must not race.
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+ # TODO: synchronize page counts across dist
+ # Until then, pages are only reclaimed in single-rank runs.
+ if self.world_size == 1:
+ self.kv_manager.reclaim_pages(
+ seq_ids_cpu, prefill_arguments.max_new_tokens
+ )
+ # with logging_redirect_tqdm():
+ # logger.info(
+ # f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
+ # )
+
+ # While sequences are still pending, ask the decode loop to stop once
+ # the live batch shrinks to 66% of (stashed + newly prefilled), so
+ # another prefill can be attempted; -1 means decode until empty.
+ if scheduler.any_pending_sequences():
+ num_pending_batches = (
+ 0
+ if decode_batch.token_ids is None
+ else decode_batch.token_ids.shape[0]
+ )
+ occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
+ else:
+ occupancy = -1
+ # Decode only when no further prefill batch can be formed right now.
+ run_decode = not scheduler.can_prefill_another_batch()
+ decode_batch = decode_batch.update(
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ prefill_arguments.seq_ids,
+ temps,
+ occupancy,
+ )
+ # Tell the peers whether to decode and with which occupancy target.
+ if self.world_size > 1:
+ decode_flags[0] = int(run_decode)
+ decode_flags[1] = occupancy
+ self.maybe_broadcast(decode_flags)
+ if not run_decode:
+ continue
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+
+ decode_output, decode_batch = self.run_decode_loop(decode_batch)
+ # Whatever is active but no longer in the decode batch has finished.
+ finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
+ decode_batch.seq_ids.tolist()
+ )
+ scheduler.record_finished_sequence_ids(
+ finished_sequence_ids, update_status=True
+ )
+ self.kv_manager.free_sequences(finished_sequence_ids)
+ # Peer events are cleared only once everything is finished.
+ self.maybe_release_peers(scheduler.is_finished())
+ scheduler.update_sequences(
+ decode_output.output_tokens.tolist(),
+ decode_output.output_seq_ids.tolist(),
+ )
+ scheduler.close()
+
+ @torch.inference_mode()
+ def _process_batches_peer(self):
+ """
+ Peer-side counterpart of the master batch loop.
+
+ Runs while ``batch_ready`` is set, repeating the master's prefill /
+ decode steps without sampling; decode decisions arrive through the
+ broadcast ``decode_flags`` tensor.
+ """
+ assert not self.is_master
+ # Peers schedule nothing themselves; this scheduler only tracks IDs.
+ scheduler = Scheduler([], kv_manager=self.kv_manager)
+ decode_batch = DecodeBatchArguments()
+ decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
+ while self.batch_ready.is_set():
+ # No sequences are passed here, unlike on the master.
+ # NOTE(review): presumably the packed arguments are received from the
+ # master inside build_prefill_args — confirm.
+ prefill_arguments = self.packed_args.build_prefill_args()
+
+ B = prefill_arguments.B
+ max_ctx_lens = (
+ prefill_arguments.max_new_tokens + prefill_arguments.context_lens
+ )
+
+ seq_ids_cpu = prefill_arguments.seq_ids.tolist()
+ scheduler.add_running_sequence_ids(seq_ids_cpu)
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ seq_ids_cpu, max_ctx_lens.tolist()
+ )
+ assert success, "failed to allocate pages for sequences"
+
+ # Logits are discarded: sampling happens on the master only.
+ self.run_prefill(prefill_arguments, batch_mapping)
+ positions = prefill_arguments.context_lens.to(dtype=torch.int64)
+ # Receive [run_decode, occupancy] from the master.
+ self.maybe_broadcast(decode_flags)
+ run_decode = bool(decode_flags[0].item())
+ occupancy = int(decode_flags[1].item())
+ # Placeholder buffer; run_decode_loop broadcasts the master's token ids
+ # into ``decode_batch.token_ids`` at the start of each step.
+ token_ids = torch.empty(B, dtype=torch.int64, device="cuda")
+ decode_batch = decode_batch.update(
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ prefill_arguments.seq_ids,
+ None, # temps not used in peer process
+ occupancy,
+ )
+
+ if not run_decode:
+ continue
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+
+ _, decode_batch = self.run_decode_loop(decode_batch)
+ finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
+ decode_batch.seq_ids.tolist()
+ )
+ scheduler.record_finished_sequence_ids(finished_sequence_ids)
+ self.kv_manager.free_sequences(finished_sequence_ids)
+ # Matches the barrier the master enters in maybe_release_peers().
+ self.maybe_release_peers()
+ scheduler.close()
+
+ @torch.inference_mode()
+ def run_decode_loop(
+ self,
+ decode_batch: DecodeBatchArguments,
+ ) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
+ """
+ Decode token by token until the batch is empty or has shrunk to
+ ``decode_batch.desired_batch_occupancy``.
+
+ Args:
+ :param decode_batch:
+ Current decode state; its tensors are filtered and advanced
+ in place.
+
+ Returns:
+ ``(output, decode_batch)``. On the master ``output`` holds the
+ concatenated generated tokens and their sequence ids (on CPU);
+ on peers both fields are ``None``. ``decode_batch`` contains only
+ the sequences that are still running.
+ """
+ if self.is_master:
+ # Entries before ``num_stashed_batches`` are skipped here; only the
+ # tokens after that offset are queued for output.
+ num_stashed_batches = decode_batch.num_stashed_batches
+ tok_buffer = [
+ decode_batch.token_ids[num_stashed_batches:].to(
+ "cpu", non_blocking=True
+ )
+ ]
+ seq_buffer = [
+ decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
+ ]
+ while True:
+ # Every rank receives rank 0's token ids (no-op with a single rank).
+ self.maybe_broadcast(decode_batch.token_ids)
+ # A sequence keeps running while it has not emitted a stop token and
+ # is still below its maximum context length.
+ not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
+ running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
+ not_stopped
+ )
+ # Drop finished sequences from every per-sequence tensor.
+ decode_batch.token_ids = torch.masked_select(
+ decode_batch.token_ids, running_batches
+ )
+ decode_batch.positions = torch.masked_select(
+ decode_batch.positions, running_batches
+ )
+ decode_batch.batch_mapping = torch.masked_select(
+ decode_batch.batch_mapping, running_batches
+ )
+ decode_batch.max_ctx_lens = torch.masked_select(
+ decode_batch.max_ctx_lens, running_batches
+ )
+ decode_batch.seq_ids = torch.masked_select(
+ decode_batch.seq_ids, running_batches
+ )
+ # Temperatures exist on the master only (peers pass None).
+ if self.is_master:
+ decode_batch.temps = torch.masked_select(
+ decode_batch.temps, running_batches
+ )
+ num_remaining = decode_batch.token_ids.numel()
+ # Stop when nothing is left, or when the batch is small enough that
+ # the caller wants to try another prefill; the survivors are
+ # recorded as stashed.
+ if (
+ num_remaining == 0
+ or num_remaining <= decode_batch.desired_batch_occupancy
+ ):
+ decode_batch.num_stashed_batches = num_remaining
+ break
+ if self.enforce_eager:
+ set_context(
+ is_prefill=False,
+ do_compression=False,
+ batch_mapping=decode_batch.batch_mapping,
+ )
+ logits = self.model.compute_logits(
+ self.model(decode_batch.token_ids, decode_batch.positions)
+ )
+ else:
+ logits = self.run_graph_decode(
+ decode_batch.token_ids,
+ decode_batch.positions,
+ decode_batch.batch_mapping,
+ )
+
+ # Only the master samples and records output.
+ if self.is_master:
+ decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
+ tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
+ seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
+ decode_batch.positions += 1
+
+ if self.is_master:
+ # non_blocking D2H copies must finish before cat/tolist read CPU data.
+ torch.cuda.synchronize()
+ output = DecodeBatchOutput(
+ output_tokens=torch.cat(tok_buffer),
+ output_seq_ids=torch.cat(seq_buffer),
+ )
+ else:
+ output = DecodeBatchOutput(None, None)
+ return output, decode_batch
+
+    @torch.inference_mode()
+    def run_graph_decode(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        batch_mapping: torch.Tensor,
+    ):
+        """
+        Run one decode step by replaying a captured CUDA graph.
+
+        The inputs are copied into the leading entries of the graph's static
+        buffers; unused batch-mapping entries are set to ``RESERVED_BATCH``.
+
+        Returns:
+            The first ``input_ids.shape[0]`` rows of the graph's logits
+            buffer, or ``None`` when that buffer is ``None``.
+        """
+        set_context(
+            is_prefill=False,
+            do_compression=False,
+            batch_mapping=batch_mapping,
+        )
+        num_live = input_ids.shape[0]
+        captured = self.get_cuda_graph(num_live, int(positions.max()))
+        captured["input_ids"][:num_live] = input_ids
+        captured["positions"][:num_live] = positions
+        # Point every slot at the reserved batch first, then overwrite the
+        # live prefix with the real mapping.
+        captured["batch_mapping"].fill_(RESERVED_BATCH)
+        captured["batch_mapping"][:num_live] = batch_mapping
+        captured["graph"].replay()
+        if captured["logits"] is None:
+            return captured["logits"]
+        return captured["logits"][:num_live]
+
+ @torch.inference_mode()
+ def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
+ """
+ Capture a decode CUDA graph for one (batch size, sequence length) pair
+ and register it in ``self.captured_graphs``.
+
+ Args:
+ :param batch_size:
+ Number of sequences the graph's static buffers hold.
+ :param max_seqlen_k:
+ Sequence length passed to the split heuristic; also the key
+ the graph is stored under.
+ """
+ dist.barrier()
+ device = torch.device("cuda")
+ logger.debug(
+ f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
+ )
+ # Static buffers: run_graph_decode() writes into these before replay.
+ _g_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ _g_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
+ _g_logits = None
+ key_split = num_splits_heuristic(
+ batch_size * self.kv_manager.num_kv_heads,
+ max_seq_len=max_seqlen_k,
+ num_sms=torch.cuda.get_device_properties(device).multi_processor_count,
+ max_splits=12,
+ )
+
+ # Temporary sequences 0..batch_size-1 with 256 reserved tokens each;
+ # they are freed again at the end of this method.
+ success, _g_batch_mapping = self.kv_manager.allocate_sequences(
+ list(range(batch_size)), [256] * batch_size
+ )
+ assert success
+
+ set_context(
+ is_prefill=False,
+ do_compression=False,
+ batch_mapping=_g_batch_mapping,
+ key_split=key_split,
+ )
+ # warmup
+ self.model.compute_logits(self.model(_g_input_ids, _g_positions))
+ dist.barrier()
+ decode_graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(decode_graph):
+ _g_logits = self.model.compute_logits(
+ self.model(_g_input_ids, _g_positions)
+ )
+ graph_vars = {
+ "graph": decode_graph,
+ "input_ids": _g_input_ids,
+ "positions": _g_positions,
+ "batch_mapping": _g_batch_mapping,
+ "logits": _g_logits,
+ "key_split": key_split,
+ }
+ if batch_size not in self.captured_graphs:
+ self.captured_graphs[batch_size] = {}
+ self.min_captured_len[batch_size] = float("inf")
+
+ self.captured_graphs[batch_size][max_seqlen_k] = graph_vars
+ # Track the smallest captured length; get_cuda_graph() falls back to it.
+ self.min_captured_len[batch_size] = min(
+ max_seqlen_k, self.min_captured_len[batch_size]
+ )
+ self.kv_manager.free_sequences(list(range(batch_size)))
+
+    def get_cuda_graph(self, batch_size: int, max_seqlen_k: int):
+        """
+        Pick the captured graph bundle to replay for a decode step.
+
+        Args:
+            :param batch_size:
+                Number of live sequences; rounded up to the smallest captured
+                batch size that can hold it.
+            :param max_seqlen_k:
+                Current maximum sequence length.
+
+        Returns:
+            The bundle captured for the largest sequence length that is
+            ``<= max_seqlen_k``, or the one for the smallest captured length
+            when every captured length is larger.
+
+        Raises:
+            RuntimeError: if no captured batch size is ``>= batch_size``.
+        """
+        # min() rather than "first match" so the result does not depend on
+        # the order in which graphs were captured.
+        captured_batch_size = min(
+            (size for size in self.captured_graphs if size >= batch_size),
+            default=None,
+        )
+        if captured_batch_size is None:
+            raise RuntimeError(
+                f"No CUDA graph captured for batch size >= {batch_size} "
+                f"(captured sizes: {sorted(self.captured_graphs)})"
+            )
+        batch_size_graphs = self.captured_graphs[captured_batch_size]
+        # we want largest seq_len that is smaller than max_seqlen_k
+        best = self.min_captured_len[captured_batch_size]
+        for seq_len in batch_size_graphs:
+            if seq_len <= max_seqlen_k:
+                best = max(best, seq_len)
+        return batch_size_graphs[best]
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/core/scheduler.py b/vllm/compactor-vllm/src/compactor_vllm/core/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab80beb43279032438dbe3a668e2af724bb10ec
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/core/scheduler.py
@@ -0,0 +1,215 @@
+import time
+from typing import Iterable, List
+
+from compactor_vllm.core.memory_manager import KVCacheManager
+from compactor_vllm.utils.sequence import Sequence, SequenceStatus
+from tqdm import tqdm
+
+
+def cdiv(a, b):
+    """ceiling division"""
+    padded = a + b - 1
+    return padded // b
+
+
+class Scheduler:
+    """
+    Simple sequence scheduler for prefill + decode with a paged KV cache.
+    The scheduler tracks three disjoint sets of sequence IDs:
+
+    * ``pending_sequence_ids``  – sequences that have not yet been started.
+    * ``active_sequence_ids``   – sequences currently running.
+    * ``finished_sequence_ids`` – sequences that have generated all tokens.
+
+    At prefill time, :meth:`get_prefill_batch` selects a subset of pending
+    sequences that can fit into the available KV cache and per-step token
+    budget, given the constraints from the associated :class:`KVCacheManager`.
+
+    The class also handles basic bookkeeping of sequence statuses.
+
+    Args:
+        :param all_sequences:
+            Iterable of :class:`Sequence` objects to be scheduled. Each
+            sequence must have a unique ``seq_id``. The iterable is consumed
+            exactly once, so generators are fine.
+        :param kv_manager:
+            A :class:`KVCacheManager` instance that this scheduler will use
+            to determine whether additional batches can be scheduled.
+        :param use_tqdm:
+            If True, a single progress bar is created. It advances by one
+            each time a sequence is recorded as finished with
+            ``update_status=True``, and its description is replaced by a
+            throughput figure in :meth:`update_sequences`.
+    """
+
+    def __init__(
+        self,
+        all_sequences: Iterable[Sequence],
+        kv_manager: KVCacheManager,
+        *,
+        use_tqdm=False,
+    ):
+        self.allseq_mapping: dict[int, Sequence] = {s.seq_id: s for s in all_sequences}
+        # Derive the pending set from the mapping instead of iterating
+        # ``all_sequences`` a second time: a one-shot iterable would already
+        # be exhausted and leave nothing pending.
+        self.pending_sequence_ids: set[int] = set(self.allseq_mapping)
+        self.active_sequence_ids: set[int] = set()
+        self.finished_sequence_ids: set[int] = set()
+        self.manager = kv_manager
+        self.use_tqdm = use_tqdm
+        self.start_time = time.perf_counter()
+        self.total_tokens_generated = 0
+        self.total_tokens_input = 0
+        self.pbar = None
+        if use_tqdm:
+            self.pbar = tqdm(
+                total=len(self.pending_sequence_ids),
+                desc="Completed Batches",
+            )
+
+    def get_prefill_batch(self) -> List[Sequence]:
+        """
+        Select a batch of pending sequences to prefill under KV/memory constraints.
+
+        The selection is greedy over ``pending_sequence_ids`` in iteration order.
+        A sequence is added to the batch if:
+
+        * The sum of its prompt length and the total prompt tokens selected so
+          far does not exceed ``manager.max_batched_tokens``, and
+        * There is at least one free KV "batch slot" left
+          (``manager.num_free_batches``), and
+        * The total number of KV pages required by the sequence's prompt +
+          max_new_tokens (across all KV heads) is strictly smaller than the
+          remaining free pages.
+
+        This method does not modify the scheduler; it only reads the
+        manager's current free counts.
+
+        Returns:
+            :return List[Sequence]:
+                The list of :class:`Sequence` objects chosen for prefill in
+                this step. The caller is responsible for marking them as
+                active via :meth:`add_running_sequence_ids`.
+        """
+        total_tok, sequences = 0, []
+        num_free_batches, num_free_pages = (
+            self.manager.num_free_batches,
+            self.manager.num_free_pages,
+        )
+        for seq_id in self.pending_sequence_ids:
+            seq = self.allseq_mapping[seq_id]
+            prompt_length = seq.prompt_len
+            pages_needed = (
+                cdiv(
+                    prompt_length + seq.sampling_params.max_new_tokens,
+                    self.manager.page_size,
+                )
+                * self.manager.num_kv_heads
+            )
+            if (
+                prompt_length + total_tok <= self.manager.max_batched_tokens
+                and num_free_batches > 0
+                and pages_needed < num_free_pages
+            ):
+                sequences.append(seq)
+                total_tok += prompt_length
+                num_free_pages -= pages_needed
+                num_free_batches -= 1
+        return sequences
+
+    def is_finished(self) -> bool:
+        """
+        Check whether all sequences have completed.
+        """
+        return not self.pending_sequence_ids and not self.active_sequence_ids
+
+    def any_pending_sequences(self) -> bool:
+        """
+        Check whether any sequences are still pending (not yet started).
+        """
+        return bool(self.pending_sequence_ids)
+
+    def add_running_sequence_ids(
+        self, active_sequence_ids: Iterable[int], *, update_status: bool = False
+    ):
+        """
+        Mark a set of sequences as active / running. This moves sequence IDs
+        from ``pending_sequence_ids`` into ``active_sequence_ids``. Optionally,
+        it also updates the per-sequence status and the input-token counter.
+
+        Args:
+            :param active_sequence_ids:
+                Iterable of sequence IDs that have been scheduled for prefill
+                or decode and should now be considered running.
+            :param update_status:
+                If True, set each corresponding :class:`Sequence`'s
+                ``status = SequenceStatus.RUNNING`` and add its prompt length
+                to ``total_tokens_input``.
+        """
+        # Materialise once so a one-shot iterable works for both passes.
+        newly_active = list(active_sequence_ids)
+        self.active_sequence_ids.update(newly_active)
+        self.pending_sequence_ids.difference_update(self.active_sequence_ids)
+        if update_status:
+            for seq_id in newly_active:
+                seq = self.allseq_mapping[seq_id]
+                seq.status = SequenceStatus.RUNNING
+                self.total_tokens_input += seq.prompt_len
+
+    def get_finished_sequence_ids_from_unfinished(
+        self, unfinished_sequence_ids: Iterable[int]
+    ) -> set[int]:
+        """
+        Infer which active sequences have finished given the
+        unfinished set (for decode steps where the caller knows
+        which sequences are still generating but not necessarily
+        which have just completed).
+        Args:
+            :param unfinished_sequence_ids:
+                Iterable of sequence IDs that are still running
+        Returns:
+            :return set[int]:
+                The inferred set of sequence IDs that transitioned from active
+                to finished.
+        """
+        return self.active_sequence_ids.difference(unfinished_sequence_ids)
+
+    def record_finished_sequence_ids(
+        self, finished_sequence_ids: Iterable[int], *, update_status: bool = False
+    ):
+        """
+        Record that a set of sequences has finished generation.
+
+        This moves IDs from ``active_sequence_ids`` into
+        ``finished_sequence_ids``.
+
+        Args:
+            :param finished_sequence_ids:
+                Iterable of sequence IDs that have completed generation and
+                no longer require KV cache.
+            :param update_status:
+                If True, set each corresponding :class:`Sequence`'s
+                ``status = SequenceStatus.FINISHED`` and advance the progress
+                bar (when enabled) by one per sequence.
+        """
+        # Materialise once: the IDs are walked up to three times below.
+        newly_finished = list(finished_sequence_ids)
+        self.active_sequence_ids.difference_update(newly_finished)
+        self.finished_sequence_ids.update(newly_finished)
+        if update_status:
+            for seq_id in newly_finished:
+                self.allseq_mapping[seq_id].status = SequenceStatus.FINISHED
+                if self.pbar is not None:
+                    self.pbar.update(1)
+
+    def update_sequences(self, tokens: Iterable[int], seq_ids: Iterable[int]):
+        """
+        Append newly generated tokens to their corresponding sequences.
+        Args:
+            :param tokens:
+                Iterable of generated token IDs, one per sequence.
+            :param seq_ids:
+                Iterable of sequence IDs aligned with ``tokens``.
+        """
+        cur_time = time.perf_counter()
+        for tok, seq_id in zip(tokens, seq_ids):
+            self.allseq_mapping[seq_id].add_new_token(tok)
+            self.total_tokens_generated += 1
+        if self.pbar is not None:
+            # Throughput counts prompt and generated tokens since construction;
+            # guard against a zero elapsed time.
+            elapsed = max(cur_time - self.start_time, 1e-9)
+            self.pbar.set_description(
+                f"Throughput: {(self.total_tokens_generated + self.total_tokens_input) / elapsed:.2f} tok/s"
+            )
+
+    def close(self):
+        """Close the progress bar, if one was created."""
+        if self.pbar is not None:
+            self.pbar.close()
+
+    def can_prefill_another_batch(self) -> bool:
+        """Whether :meth:`get_prefill_batch` would currently select anything."""
+        return len(self.get_prefill_batch()) > 0
diff --git a/vllm/compactor-vllm/src/compactor_vllm/kv_cache/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/kv_cache/page_table.py b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/page_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b6ca46c7c941ddb94c03acd9e28166c9363975a
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/page_table.py
@@ -0,0 +1,313 @@
+import heapq
+import logging
+from enum import Enum, auto
+from typing import List, Optional, Union
+
+import torch
+from compactor_vllm.config.constants import RESERVED_BATCH
+from compactor_vllm.kv_cache.write_page_table import scatter_to_page_table
+
+logger = logging.getLogger(__name__)
+
+
+def cdiv(a, b):
+    """
+    Ceiling division, computed as ``(a + b - 1) // b``.
+
+    Gives ``ceil(a / b)`` for non-negative ``a`` and positive ``b``. Works on
+    anything supporting ``+``, ``-`` and ``//`` (callers in this file pass
+    both Python ints and integer tensors).
+    """
+    padded = a + b - 1
+    return padded // b
+
+
+def next_multiple(a, b):
+    """Round ``a`` up to a multiple of ``b`` (``cdiv(a, b) * b``)."""
+    num_chunks = cdiv(a, b)
+    return num_chunks * b
+
+
+class KVAllocationStatus(Enum):
+    """Outcome of a KV-cache reservation request (see ``PagedKVCache.reserve_tokens``)."""
+
+    # Some (layer, head) would need more than ``max_pages_per_head`` pages.
+    EXCEEDS_MAX_SEQUENCE_LENGTH = 1
+    # At least one layer has fewer free physical pages than requested.
+    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = 2
+    # Not produced by the code in this module; presumably for callers that
+    # run out of batch slots -- TODO confirm.
+    EXCEEDS_MAX_NUM_BATCHES = 3
+    # Capacity is available (already present or newly allocated).
+    SUCCESS = 4
+
+
+class PagedKVCache(torch.nn.Module):
+    """
+    Global paged KV cache.
+
+    This module manages:
+    * A global K/V backing buffer for all layers:
+      ``kv_cache[2, num_layers, num_pages * page_size, head_dim]``,
+      where the first dimension indexes K vs V.
+    * A per-layer page table:
+      ``page_table[num_layers, max_num_batches + 1, H_kv, max_pages_per_head]``,
+      mapping logical (batch, kv-head, logical_page) to a physical page ID
+      in the global K/V buffer.
+    * Per-layer, per-(batch, kv-head) logical sequence lengths
+      ``bh_seq_lens[num_layers, max_num_batches + 1, H_kv]`` (in tokens), and
+      the number of allocated pages ``bh_num_pages`` for each (layer, batch,
+      head).
+    * A page allocator implemented as a min-heap of free physical pages
+      per layer, plus a list of free batch indices.
+
+    Pages are of fixed size ``page_size`` tokens.
+
+    Args:
+        :param num_layers:
+            Number of transformer layers that will use this cache.
+        :param max_logical_pages_per_head:
+            Maximum number of logical pages that can be assigned to a single
+            (batch, kv-head) pair.
+        :param num_pages:
+            Total number of physical pages available in the global cache per
+            layer. The global K/V buffers are of length
+            ``num_pages * page_size`` along the token dimension.
+        :param page_size:
+            Number of tokens stored per page.
+        :param H_kv:
+            Number of KV heads per layer.
+        :param head_dim:
+            Head dimension for K/V.
+        :param max_num_batches:
+            Maximum number of concurrent batches / sequences supported. One
+            additional batch index is allocated and reserved for internal use
+            (``RESERVED_BATCH``).
+        :param dtype:
+            Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
+        :param device:
+            Device on which to allocate the cache (string, torch.device, or
+            int; defaults to ``"cuda"``).
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        max_logical_pages_per_head: int,
+        num_pages: int,
+        page_size: int,  # tokens per page
+        H_kv: int,
+        head_dim: int,
+        max_num_batches: int,
+        dtype: torch.dtype,
+        device: Union[str, torch.device, int] = "cuda",
+    ):
+        super().__init__()
+        self.n_pages = num_pages
+        self.num_layers = num_layers
+        self.page_size: int = int(page_size)
+        self.H_kv = int(H_kv)
+        self.max_pages_per_head = max_logical_pages_per_head
+        # One extra slot so that RESERVED_BATCH does not eat into the
+        # caller-visible capacity.
+        max_num_batches += 1
+        self.max_num_batches = max_num_batches
+        self.head_dim = head_dim
+        cache_shape = (2, num_layers, num_pages * page_size, head_dim)
+        self.kv_cache = torch.empty(cache_shape, dtype=dtype, device=device)
+
+        # Left uninitialised: entries only become meaningful once
+        # reserve_tokens() has written physical page ids into them.
+        self.page_table = torch.empty(
+            (num_layers, max_num_batches, H_kv, self.max_pages_per_head),
+            device=device,
+            dtype=torch.int32,
+        )
+
+        # Per-(batch, head) logical seq length (tokens)
+        self.bh_seq_lens = torch.zeros(
+            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
+        )
+        # Per-(batch, head) number of allocated logical pages
+        self.bh_num_pages = torch.zeros(
+            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
+        )
+
+        # Page allocator (min-heap of free physical pages), one heap per layer
+        self.free_pages: List[List[int]] = [
+            list(range(num_pages)) for _ in range(num_layers)
+        ]
+        for free_pages in self.free_pages:
+            heapq.heapify(free_pages)
+        # RESERVED_BATCH is never handed out by new_batch()
+        self.free_batches: List[int] = list(reversed(range(max_num_batches)))
+        self.free_batches.remove(RESERVED_BATCH)
+        # Record of physical page ids owned by a batch (for freeing),
+        # indexed as [batch_index][layer_index]
+        self.pages_indices_per_batch: List[List[set[int]]] = [
+            [set() for _ in range(num_layers)] for _ in range(max_num_batches)
+        ]
+
+    def new_batch(self) -> Optional[int]:
+        """
+        Reserve a new batch slot.
+
+        A batch slot corresponds to a row in ``bh_seq_lens`` /
+        ``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
+        heads. This method checks whether a free batch index is available, and
+        whether each layer has at least ``H_kv`` free pages remaining.
+        If both checks pass, it returns a batch index and removes it from
+        ``free_batches``. Otherwise, it returns ``None``. No pages are
+        allocated here; that happens in :meth:`reserve_tokens`.
+
+        Returns:
+            :return Optional[int]:
+                Newly reserved batch index, or ``None`` if no capacity is
+                available.
+        """
+        if self.free_batches and all(self.H_kv <= len(fp) for fp in self.free_pages):
+            return self.free_batches.pop()
+        return None
+
+    def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
+        """
+        Ensure enough pages are allocated to handle ``add_tokens`` new tokens.
+
+        Args:
+            :param batch_index:
+                Batch index to reserve space for.
+            :param add_tokens:
+                Number of additional tokens to reserve capacity for.
+                All heads in this batch and all layers reserve
+                the same number of extra tokens.
+        Returns:
+            :return KVAllocationStatus:
+                ``SUCCESS`` if capacity is (now) available,
+                ``EXCEEDS_MAX_SEQUENCE_LENGTH`` if some (layer, head) would
+                need more than ``max_pages_per_head`` pages, or
+                ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` if some layer does not
+                have enough free physical pages. On failure nothing is
+                allocated.
+        """
+        cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
+        curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
+        curr_cap_tokens = curr_pages * self.page_size  # [L, H]
+        need_tokens = cur_bh_lens + add_tokens  # [L, H]
+        if (need_tokens <= curr_cap_tokens).all():
+            return KVAllocationStatus.SUCCESS
+        # Heads that still have spare capacity have a non-positive shortfall.
+        # Clamp at zero so they never yield a negative page count, which
+        # would undercount the pages popped per layer and shrink bh_num_pages.
+        missing_tokens = (need_tokens - curr_cap_tokens).clamp(min=0)
+        add_pages = cdiv(missing_tokens, self.page_size)
+        new_total_pages = curr_pages + add_pages
+        if (new_total_pages > self.max_pages_per_head).any():
+            return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH
+        # CPU work
+        pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
+        # Check every layer before popping anything so that a failure leaves
+        # the allocator untouched.
+        for layer_index in range(self.num_layers):
+            if pages_per_layer_cpu[layer_index] > len(self.free_pages[layer_index]):
+                return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES
+        popped_pages: List[int] = []
+        for layer_index in range(self.num_layers):
+            this_layer_pages = [
+                heapq.heappop(self.free_pages[layer_index])
+                for _ in range(pages_per_layer_cpu[layer_index])
+            ]
+            self.pages_indices_per_batch[batch_index][layer_index].update(
+                this_layer_pages
+            )
+            popped_pages.extend(this_layer_pages)
+
+        # Allocate on the cache's own device rather than a hard-coded "cuda",
+        # which would pick the current device instead of the one given to
+        # __init__.
+        new_phys_pages = torch.tensor(
+            popped_pages, dtype=torch.int32, device=self.page_table.device
+        )
+
+        scatter_to_page_table(
+            add_pages=add_pages,
+            new_phys_pages=new_phys_pages,
+            curr_pages=curr_pages,
+            page_table=self.page_table[:, batch_index],
+            max_pages_per_head=self.max_pages_per_head,
+        )
+
+        self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
+            self.bh_num_pages.dtype
+        )
+        return KVAllocationStatus.SUCCESS
+
+    def reclaim_pages(
+        self,
+        batch_index: int,
+        future_reserve_tokens: int = 0,
+    ) -> int:
+        """
+        Reclaim unused pages for a single batch index. This shrinks the KV
+        allocation for the batch down to the minimum number of pages needed
+        to hold the current (plus optional future) sequence length.
+
+        Args:
+            :param batch_index:
+                Batch index whose pages should be compacted.
+            :param future_reserve_tokens:
+                Optional number of extra tokens to keep capacity for, beyond
+                the current sequence length. This can reduce churn when
+                sequences are expected to grow slightly in the near future.
+
+        Returns:
+            :return int:
+                Approximate number of bytes freed across both K and V
+                (``0`` when nothing could be reclaimed).
+        """
+        device = self.bh_seq_lens.device
+        _, B, H = self.bh_seq_lens.shape
+        assert 0 <= batch_index < B
+
+        seq = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens  # [L, H]
+        alloc = self.bh_num_pages[:, batch_index, :]  # [L, H]
+        # This batch's page table flattened to [L * H * P]
+        pt = self.page_table[:, batch_index, :, :].reshape(-1)
+
+        # Compute used pages: ceil_div(seq, page_size), clamped into [0, alloc]
+        used_pages = cdiv(seq, self.page_size)
+        used_pages = torch.minimum(used_pages, alloc)
+
+        # page indices [0..P-1], broadcasted over [L, H, P]
+        p = torch.arange(
+            self.max_pages_per_head, device=device, dtype=torch.int32
+        ).view(1, 1, self.max_pages_per_head)
+
+        # allocated: p < alloc
+        alloc_mask = p < alloc.unsqueeze(-1)  # [L, H, P]
+        # to free: allocated and p in [used_pages, alloc)
+        free_mask = alloc_mask & (p >= used_pages.unsqueeze(-1))
+        free_mask_flat = free_mask.view(-1)  # [L*H*P]
+        if not free_mask_flat.any():
+            return 0
+
+        idx = free_mask_flat.nonzero(as_tuple=False).squeeze(
+            -1
+        )  # indices of freed slots
+
+        # Freed physical page ids
+        freed_pages = pt[idx]
+        # Compute layer index for each freed slot:
+        # layout is [L, H, P] -> flat index = ((l * H) + h) * P + p
+        freed_layers = (idx // (H * self.max_pages_per_head)).to(torch.int32)
+        freed_pages = freed_pages.tolist()
+        layer_mapping = freed_layers.tolist()
+        self.bh_num_pages[:, batch_index, :] = used_pages
+        for page, layer in zip(freed_pages, layer_mapping):
+            self.pages_indices_per_batch[batch_index][layer].remove(page)
+            heapq.heappush(self.free_pages[layer], page)
+        approximate_bytes_freed = (
+            len(freed_pages)
+            * (self.page_size * self.head_dim * self.kv_cache.element_size())
+            * 2
+        )  # multiply by two for K + V
+        return approximate_bytes_freed
+
+    def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
+        """
+        Return every page owned by ``batch_index`` in ``layer_index`` to that
+        layer's free heap and clear the ownership record.
+        """
+        # Return pages to the global heap
+        for phys in self.pages_indices_per_batch[batch_index][layer_index]:
+            heapq.heappush(self.free_pages[layer_index], int(phys))
+
+        self.pages_indices_per_batch[batch_index][layer_index] = set()
+
+    def free_batch(self, batch_index: int) -> None:
+        """
+        Free all resources associated with a batch index.
+
+        Releasing an index that is already in the free list is ignored (with
+        a warning), because appending it a second time would let
+        :meth:`new_batch` hand the same slot to two sequences.
+
+        Args:
+            :param batch_index:
+                Batch index to release. Must have been previously allocated
+                via :meth:`new_batch`.
+        """
+        if batch_index in self.free_batches:
+            logger.warning(
+                "free_batch(%d) ignored: batch index is not currently allocated",
+                batch_index,
+            )
+            return
+        for layer in range(self.num_layers):
+            self._free_batch_layer(layer, batch_index)
+        self.bh_seq_lens[:, batch_index].zero_()
+        self.bh_num_pages[:, batch_index].zero_()
+        self.free_batches.append(batch_index)
+
+    def layer_slices(self, layer: int):
+        """
+        Return layer-local views needed by the attention module.
+
+        Args:
+            :param layer:
+                Layer index ``l`` in ``[0, num_layers)``.
+
+        Returns:
+            :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+                ``(k, v, pt, bh)`` where ``k = kv_cache[0, layer]`` and
+                ``v = kv_cache[1, layer]`` are the layer's K and V buffers,
+                ``pt = page_table[layer]`` is its page table and
+                ``bh = bh_seq_lens[layer]`` holds its per-(batch, head)
+                sequence lengths.
+        """
+        assert 0 <= layer < self.num_layers
+        k = self.kv_cache[0, layer]
+        v = self.kv_cache[1, layer]
+        pt = self.page_table[layer]
+        bh = self.bh_seq_lens[layer]
+        return k, v, pt, bh
diff --git a/vllm/compactor-vllm/src/compactor_vllm/kv_cache/store_kv_cache.py b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/store_kv_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..405b90dfe34b9d0b47f1d3aebfc8fa1ee3dc6495
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/store_kv_cache.py
@@ -0,0 +1,468 @@
+import torch
+import triton
+import triton.language as tl
+from compactor_vllm.config.constants import (
+ TRITON_RESERVED_BATCH as _TRITON_RESERVED_BATCH,
+)
+
+
+# Scatter the top-ranked (token, head) K/V vectors of a packed prefill batch
+# into the paged cache. Grid: (local batch row, tile of K_TILE selections).
+# Each stored vector atomically claims the next free position of its
+# (batch, head) via bh_lens, so the order of tokens inside the cache is not
+# the original token order.
+@triton.jit
+def _prefill_store_topk_kv_kernel(
+    key,
+    value,  # [N_total, H, D] (D stride assumed 1)
+    batch_mapping,  # [B] int32 (local b -> true batch)
+    num_tokens_to_retain,  # [B] int32
+    indices_topk,  # [B, MAX_SEL] int32 (across all heads)
+    # Lengths & page table:
+    bh_lens,  # [B, H] int32 (contiguous), incremented atomically
+    page_table,  # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
+    k_cache,
+    v_cache,  # [N_PAGES * PAGE_SIZE, D]
+    sk_n,
+    sk_h,  # strides for key,value. D stride assumed 1
+    sv_n,
+    sv_h,
+    # Runtime ints
+    MAX_SEL,  # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
+    HKV: tl.constexpr,
+    N_LOGICAL_PAGES_MAX: tl.constexpr,
+    D: tl.constexpr,
+    PAGE_SIZE: tl.constexpr,
+    K_TILE: tl.constexpr,  # how many selected tokens each program processes
+    TRITON_RESERVED_BATCH: tl.constexpr,
+):
+    b_local = tl.program_id(0)
+    tile_id = tl.program_id(1)
+    offs = tl.arange(0, D)
+    # how many tokens we actually keep for this batch
+    k_total = tl.load(num_tokens_to_retain + b_local)
+    if k_total == 0:
+        return
+    # map to true batch row in the page table
+    b_true = tl.load(batch_mapping + b_local)
+    if b_true == TRITON_RESERVED_BATCH:
+        return
+    base = tile_id * K_TILE
+    # process up to K_TILE tokens
+    for j in tl.range(0, K_TILE):
+        sel_idx = base + j
+        if sel_idx < k_total and sel_idx < MAX_SEL:
+            # flattened selection: sel = token * H + head
+            sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
+            tok = sel // HKV
+            head = sel - (tok * HKV)
+            # atomically reserve one position in (b_local, head),
+            # i.e. the KV cache is scrambled when storing
+            len_ptr = bh_lens + b_local * HKV + head
+            pos = tl.atomic_add(len_ptr, 1)  # old length (int32)
+            lp = pos // PAGE_SIZE
+            off = pos - lp * PAGE_SIZE
+            # translate logical page to physical page
+            pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
+            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
+            # destination row and element offset
+            dst_row = phys * PAGE_SIZE + off
+            dst_off = dst_row * D + offs
+            # Widen the token index before multiplying by the row stride:
+            # tok * sk_n in int32 can overflow for large packed batches
+            # (the all-KV store kernel widens its token index the same way).
+            tok_i64 = tok.to(tl.int64)
+            # load one vector from [N_total, H, D]
+            k_src = key + tok_i64 * sk_n + head * sk_h + offs
+            v_src = value + tok_i64 * sv_n + head * sv_h + offs
+            tl.store(
+                k_cache + dst_off,
+                tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
+                eviction_policy="evict_first",
+            )
+            tl.store(
+                v_cache + dst_off,
+                tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
+                eviction_policy="evict_first",
+            )
+
+
+def prefill_store_topk_kv(
+    *,
+    new_keys: torch.Tensor,  # [N_total, H, D]
+    new_vals: torch.Tensor,  # [N_total, H, D]
+    indices_topk: torch.Tensor,  # [B, MAX_SEL] int32 (global flattened token*H + head)
+    num_tokens_to_retain: torch.Tensor,  # [B] int32
+    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
+    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
+    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
+    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
+    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
+    PAGE_SIZE: int,
+    PAD_TO_PAGE_SIZE: bool = True,
+    cu_seqlens_k: torch.Tensor | None = None,
+    K_TILE: int = 16,
+    TRITON_RESERVED_BATCH: int | None = None,
+):
+    """
+    Store the selected (token, head) K/V vectors of a prefill batch into the
+    paged cache and advance ``bh_lens`` in place.
+
+    The first kernel writes the first ``num_tokens_to_retain[b]`` entries of
+    ``indices_topk[b]``. When ``PAD_TO_PAGE_SIZE`` is set, a second kernel
+    then keeps appending lower-ranked candidates per (batch, head) until that
+    head's length reaches a page boundary, the candidates run out, or the
+    batch's context length is reached.
+
+    Args:
+        :param new_keys / new_vals:
+            ``[N_total, H, D]`` tensors; the last dimension must have stride 1
+            and ``D`` must be a power of two.
+        :param indices_topk:
+            ``[B, MAX_SEL]`` ranked selections encoded as ``token * H + head``.
+        :param num_tokens_to_retain:
+            ``[B]`` number of leading entries of ``indices_topk`` to store.
+        :param page_table:
+            ``[B_total, H, N_LOGICAL_PAGES_MAX]`` contiguous page table.
+        :param batch_mapping:
+            ``[B]`` mapping from local batch row to page-table row. Rows
+            equal to ``TRITON_RESERVED_BATCH`` are skipped.
+        :param bh_lens:
+            ``[B, H]`` contiguous per-(batch, head) lengths; updated in place.
+        :param k_cache / v_cache:
+            Contiguous ``[N_PAGES * PAGE_SIZE, D]`` cache buffers.
+        :param PAGE_SIZE:
+            Tokens per page.
+        :param PAD_TO_PAGE_SIZE:
+            Whether to run the page-padding pass; requires ``cu_seqlens_k``.
+        :param cu_seqlens_k:
+            ``[B + 1]`` cumulative sequence lengths (only read when padding).
+        :param K_TILE:
+            Selections handled per program in the first kernel.
+        :param TRITON_RESERVED_BATCH:
+            Override for the reserved batch id; defaults to the package
+            constant.
+    """
+    assert new_keys.shape == new_vals.shape
+    _, H, D = new_keys.shape
+    B = indices_topk.shape[0]
+    assert page_table.shape[1] == H
+    assert bh_lens.shape == (B, H)
+    assert new_keys.device == k_cache.device == v_cache.device
+    assert page_table.is_contiguous(), "page table must be contiguous."
+    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
+    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
+    assert k_cache.is_contiguous() and v_cache.is_contiguous()
+    assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
+        "new_keys/new_vals last dim must be contiguous."
+    )
+    assert (D & (D - 1)) == 0, "D must be a power of 2"
+    # Validate the padding inputs before any kernel has written to the cache.
+    if PAD_TO_PAGE_SIZE:
+        assert cu_seqlens_k is not None
+    page_table = page_table.to(torch.int32)
+    # ``.to`` returns the same tensor when the dtype already matches;
+    # otherwise the kernels would update a private copy, so the result is
+    # copied back into the caller's tensor at the end.
+    bh_lens_i32 = bh_lens.to(torch.int32)
+    batch_mapping = batch_mapping.to(torch.int32)
+    indices_topk = indices_topk.to(torch.int32)
+    num_tokens_to_retain = num_tokens_to_retain.to(torch.int32)
+    # Both kernels index indices_topk as a flat [B * MAX_SEL] buffer.
+    assert indices_topk.is_contiguous()
+    assert page_table.is_contiguous()
+
+    # strides (elements) for [N_total, H, D]
+    sk_n, sk_h, _ = new_keys.stride()
+    sv_n, sv_h, _ = new_vals.stride()
+
+    # tile second grid dim
+    MAX_SEL = indices_topk.shape[-1]
+    N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
+    grid = (B, max(1, N_TILES))
+    if TRITON_RESERVED_BATCH is None:
+        TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH
+    _prefill_store_topk_kv_kernel[grid](
+        key=new_keys,
+        value=new_vals,
+        batch_mapping=batch_mapping,
+        num_tokens_to_retain=num_tokens_to_retain,
+        indices_topk=indices_topk,
+        bh_lens=bh_lens_i32,
+        page_table=page_table,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        sk_n=sk_n,
+        sk_h=sk_h,
+        sv_n=sv_n,
+        sv_h=sv_h,
+        MAX_SEL=int(MAX_SEL),
+        HKV=H,
+        N_LOGICAL_PAGES_MAX=page_table.shape[2],
+        D=D,
+        PAGE_SIZE=PAGE_SIZE,
+        K_TILE=K_TILE,
+        TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
+    )
+    if PAD_TO_PAGE_SIZE:
+        _prefill_store_topk_pad_kernel[(B, H)](
+            key=new_keys,
+            value=new_vals,
+            batch_mapping=batch_mapping,
+            num_tokens_to_retain=num_tokens_to_retain,
+            indices=indices_topk,
+            local_lens=bh_lens_i32,
+            page_table_flat=page_table,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            cu_seqlens_k=cu_seqlens_k,
+            sk_n=sk_n,
+            sk_h=sk_h,
+            sv_n=sv_n,
+            sv_h=sv_h,
+            MAX_SEL=int(MAX_SEL),
+            H=H,  # type: ignore
+            N_LOGICAL_PAGES_MAX=page_table.shape[2],  # type: ignore
+            D=D,  # type: ignore
+            PAGE_SIZE=PAGE_SIZE,  # type: ignore
+            TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
+        )
+    if bh_lens_i32 is not bh_lens:
+        # The caller's tensor was not int32: propagate the new lengths.
+        bh_lens.copy_(bh_lens_i32)
+
+
+# Page-padding pass that runs after the top-k store. Grid: (local batch row,
+# KV head). If a head's length is not a multiple of PAGE_SIZE, keep scanning
+# the ranked candidates past num_tokens_to_retain and append those belonging
+# to this head until the page is full, the candidates are exhausted, or the
+# head holds as many tokens as the batch's context length.
+@triton.jit
+def _prefill_store_topk_pad_kernel(
+    key,  # [N_total, H, D]
+    value,  # [N_total, H, D]
+    batch_mapping,  # [B] int32 (local b -> true batch)
+    num_tokens_to_retain,  # [B] int32
+    indices,  # [B, MAX_SEL] int32 (across all heads)
+    local_lens,  # [B, H] int32 (contiguous), rewritten with the padded length
+    page_table_flat,  # [B_total*H*N_LOGICAL_PAGES_MAX] int32
+    k_cache,
+    v_cache,  # [N_PAGES*PAGE_SIZE, D]
+    cu_seqlens_k,
+    sk_n,
+    sk_h,
+    sv_n,
+    sv_h,
+    MAX_SEL,
+    # Constexprs
+    H: tl.constexpr,  # number of KV heads
+    N_LOGICAL_PAGES_MAX: tl.constexpr,
+    D: tl.constexpr,
+    PAGE_SIZE: tl.constexpr,
+    TRITON_RESERVED_BATCH: tl.constexpr,
+):
+    b_local = tl.program_id(0)
+    h = tl.program_id(1)
+    offs_d = tl.arange(0, D)
+    L = tl.load(local_lens + b_local * H + h)
+    modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
+    # Already page-aligned: nothing to pad.
+    if modulo_page_size == 0:
+        return
+    need = PAGE_SIZE - modulo_page_size
+    b_true = tl.load(batch_mapping + b_local)
+    if b_true == TRITON_RESERVED_BATCH:
+        return
+    pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
+    written_tokens = 0
+    # Start right after the entries already stored by the top-k kernel.
+    idx = tl.load(num_tokens_to_retain + b_local)
+    this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
+        cu_seqlens_k + b_local
+    )
+    max_additional = this_batch_ctx_len - L
+    while (written_tokens < need and idx < MAX_SEL) and (
+        written_tokens < max_additional
+    ):
+        # candidate head
+        cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
+        cand_h = cand_idx % H
+        if cand_h == h:
+            # Widen before multiplying by the row stride: tok * sk_n in
+            # int32 can overflow for large packed batches.
+            tok = (cand_idx // H).to(tl.int64)
+            pos = L + written_tokens
+            lp = pos // PAGE_SIZE
+            off = pos - lp * PAGE_SIZE
+            phys = tl.load(page_table_flat + pt_base + lp).to(tl.int32)
+
+            dst_row = phys * PAGE_SIZE + off
+            dst_off = dst_row.to(tl.int64) * D + offs_d
+
+            k_src = key + tok * sk_n + h * sk_h + offs_d
+            v_src = value + tok * sv_n + h * sv_h + offs_d
+
+            tl.store(
+                k_cache + dst_off,
+                tl.load(k_src),
+            )
+            tl.store(
+                v_cache + dst_off,
+                tl.load(v_src),
+            )
+
+            written_tokens += 1
+        idx += 1
+    tl.store(local_lens + b_local * H + h, L + written_tokens)
+
+
+ # Copy every (token, head) K/V vector of a packed prefill batch into the
+ # paged cache, without any selection. Grid: (local batch row, tile); each
+ # program walks K_TILE consecutive (token, head) pairs of its batch row.
+ # The kernel only reads bh_lens; the host wrapper prefill_store_all_kv
+ # advances the lengths after the launch.
+ # NOTE(review): unlike the top-k and decode store kernels, this kernel has
+ # no reserved-batch check on batch_mapping -- confirm callers never pass
+ # reserved rows here.
+ @triton.jit
+ def _prefill_store_all_kv_kernel(
+ key,
+ value, # [N, H, D] (D contiguous)
+ cu_seqlens_k, # [B + 1] int32
+ batch_mapping, # [B] int32 (local b -> true batch index)
+ bh_lens, # [B * HKV] int32 (read here; advanced by the host wrapper)
+ pt_flat, # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
+ k_cache,
+ v_cache, # [N_PAGES * PAGE_SIZE, D]
+ # source strides (elements)
+ sk_n,
+ sk_h,
+ sv_n,
+ sv_h,
+ # constexpr
+ HKV: tl.constexpr,
+ N_LOGICAL_PAGES_MAX: tl.constexpr,
+ D: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ K_TILE: tl.constexpr, # number of (token, head) pairs processed per program
+ ):
+ pid_b = tl.program_id(0)
+ pid_blk = tl.program_id(1)
+
+ # Token range of this batch row inside the packed [N, H, D] input.
+ start = tl.load(cu_seqlens_k + pid_b)
+ end = tl.load(cu_seqlens_k + pid_b + 1)
+ num_toks_this_batch = end - start
+ if num_toks_this_batch <= 0:
+ return
+
+ total_elems = num_toks_this_batch * HKV
+
+ # base linear index in (token, head) grid for this program
+ base = pid_blk * K_TILE
+
+ offs_d = tl.arange(0, D)
+
+ # Iterate K_TILE elements in this tile
+ for i in tl.range(0, K_TILE):
+ idx = base + i
+ if idx < total_elems:
+ # map linear idx -> (t, h)
+ t = idx // HKV
+ h = idx - t * HKV
+
+ # Length of this (batch, head) before the prefill; token t lands at L0 + t.
+ len_idx = pid_b * HKV + h
+ L0 = tl.load(bh_lens + len_idx)
+
+ token_idx_in_cache = L0 + t
+ lp = token_idx_in_cache // PAGE_SIZE # logical page
+ off_in_pg = token_idx_in_cache - lp * PAGE_SIZE # pos in page
+
+ # physical page
+ b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
+ pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
+ phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)
+
+ row = phys * PAGE_SIZE + off_in_pg
+ dst_off = row * D + offs_d
+
+ # Widened to int64 before it is multiplied by the row stride.
+ n_global = (start + t).to(tl.int64)
+
+ # Use strides for non-contiguous [N, H, D] (D stride == 1)
+ k_src = key + n_global * sk_n + h * sk_h + offs_d
+ v_src = value + n_global * sv_n + h * sv_h + offs_d
+
+ tl.store(k_cache + dst_off, tl.load(k_src))
+ tl.store(v_cache + dst_off, tl.load(v_src))
+
+
+def prefill_store_all_kv(
+    *,
+    new_keys: torch.Tensor,
+    new_values: torch.Tensor,  # [N, H_kv, D]
+    cu_seqlens_k: torch.Tensor,  # [B + 1] int32
+    max_seqlen_k: int,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    page_table: torch.Tensor,  # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
+    bh_lens: torch.Tensor,  # [B, H_kv] int32 (UPDATED)
+    batch_mapping: torch.Tensor,  # [B] int32 (local->true)
+    PAGE_SIZE: int,
+    K_TILE: int = 32,  # how many (token, head) pairs per program
+):
+    """
+    Store every K/V vector of a packed prefill batch into the paged cache.
+
+    Launches ``_prefill_store_all_kv_kernel`` on a ``(B, n_tiles)`` grid,
+    where ``n_tiles`` covers ``max_seqlen_k * H_kv`` (token, head) pairs in
+    chunks of ``K_TILE``, then adds each batch row's token count
+    (``cu_seqlens_k.diff()``) to every head of ``bh_lens`` in place.
+    """
+    assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
+        "last dim must be contiguous"
+    )
+    assert page_table.is_contiguous(), "page table must be contiguous"
+    assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
+    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
+    assert k_cache.is_contiguous() and v_cache.is_contiguous()
+
+    _, HKV, D = new_keys.shape
+    num_rows = batch_mapping.shape[0]
+    assert (D & (D - 1)) == 0, "D must be a power of 2"
+
+    key_strides = new_keys.stride()
+    val_strides = new_values.stride()
+    # Enough tiles for the longest sequence; shorter rows exit early in-kernel.
+    max_pairs = max_seqlen_k * HKV
+    n_tiles = (max_pairs + K_TILE - 1) // K_TILE
+    launch_grid = (num_rows, n_tiles)
+    _prefill_store_all_kv_kernel[launch_grid](
+        new_keys,
+        new_values,
+        cu_seqlens_k,
+        batch_mapping,
+        bh_lens,
+        page_table,
+        k_cache,
+        v_cache,
+        sk_n=key_strides[0],
+        sk_h=key_strides[1],
+        sv_n=val_strides[0],
+        sv_h=val_strides[1],
+        HKV=HKV,
+        N_LOGICAL_PAGES_MAX=page_table.shape[-1],
+        D=D,
+        PAGE_SIZE=PAGE_SIZE,
+        K_TILE=K_TILE,
+    )
+    # Every head of a row grows by that row's number of new tokens.
+    tokens_per_row = cu_seqlens_k.diff()
+    bh_lens += tokens_per_row[:, None]
+
+
+ # Append one decode-step K/V vector per (batch row, KV head) to the paged
+ # cache. Grid: (batch row, head). The vector is written at position
+ # bh_lens[b, h], and bh_lens[b, h] is then incremented by one. Rows whose
+ # mapped batch equals TRITON_RESERVED_BATCH are skipped entirely.
+ @triton.jit
+ def _decode_store_kv_kernel(
+ key,
+ value,
+ batch_mapping, # [B] int32
+ bh_lens, # [B*HKV] int32, incremented by this kernel
+ page_table, # [B_total*HKV*N_LOGICAL_PAGES_MAX]
+ k_cache,
+ v_cache, # [N_PAGES*PAGE_SIZE, D]
+ sk_b,
+ sk_h,
+ sv_b,
+ sv_h,
+ HKV: tl.constexpr,
+ N_LOGICAL_PAGES_MAX: tl.constexpr,
+ D: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ TRITON_RESERVED_BATCH: tl.constexpr,
+ ):
+ pid_b = tl.program_id(0)
+ h = tl.program_id(1)
+ mapped_b = tl.load(batch_mapping + pid_b)
+ if mapped_b == TRITON_RESERVED_BATCH:
+ return
+ offs_d = tl.arange(0, D)
+
+ # The current length doubles as the write position of the new token.
+ length = tl.load(bh_lens + pid_b * HKV + h)
+ logical_page = length // PAGE_SIZE
+ internal_offset = length - logical_page * PAGE_SIZE
+
+ # bh_lens is indexed by the local row, the page table by the mapped row.
+ pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
+ physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)
+
+ dst_row = physical_page * PAGE_SIZE + internal_offset
+
+ # Source addressing using strides (D stride == 1)
+ k_src = key + pid_b * sk_b + h * sk_h + offs_d
+ v_src = value + pid_b * sv_b + h * sv_h + offs_d
+
+ dst_off = dst_row * D + offs_d
+ tl.store(k_cache + dst_off, tl.load(k_src))
+ tl.store(v_cache + dst_off, tl.load(v_src))
+ tl.store(bh_lens + pid_b * HKV + h, length + 1)
+
+
+def decode_store_kv(
+    *,
+    key: torch.Tensor,  # [B, HKV, D]
+    value: torch.Tensor,  # [B, HKV, D]
+    batch_mapping: torch.Tensor,  # [B] int32
+    bh_lens: torch.Tensor,  # [B, HKV] or flattened [B*HKV] int32
+    page_table: torch.Tensor,  # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,  # [N_PAGES*PAGE_SIZE, D]
+    PAGE_SIZE: int,
+    TRITON_RESERVED_BATCH: int = None,
+):
+    """
+    Append one decode-step K/V vector per (batch row, head) to the paged cache.
+
+    Launches ``_decode_store_kv_kernel`` on a ``(len(batch_mapping), HKV)``
+    grid. The kernel increments ``bh_lens`` in place and skips rows mapped to
+    the reserved batch id (``TRITON_RESERVED_BATCH``, defaulting to the
+    package constant when ``None``).
+    """
+    assert key.shape == value.shape and key.ndim == 3, "key/value must be [B, HKV, D]"
+    _, HKV, D = key.shape
+    assert key.stride(-1) == 1 and value.stride(-1) == 1, (
+        "key/value last dim must be contiguous."
+    )
+    assert page_table.is_contiguous(), "page table must be contiguous."
+    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
+    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
+    assert k_cache.is_contiguous() and v_cache.is_contiguous()
+    assert (D & (D - 1)) == 0, "D must be a power of 2"
+
+    key_strides = key.stride()
+    value_strides = value.stride()
+    # Resolve the reserved-batch id up front instead of inline in the launch.
+    if TRITON_RESERVED_BATCH is not None:
+        reserved_batch = TRITON_RESERVED_BATCH
+    else:
+        reserved_batch = _TRITON_RESERVED_BATCH
+    num_rows = int(batch_mapping.shape[0])
+    _decode_store_kv_kernel[(num_rows, HKV)](
+        key=key,
+        value=value,
+        batch_mapping=batch_mapping,
+        bh_lens=bh_lens,
+        page_table=page_table,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        sk_b=key_strides[0],
+        sk_h=key_strides[1],
+        sv_b=value_strides[0],
+        sv_h=value_strides[1],
+        HKV=HKV,
+        N_LOGICAL_PAGES_MAX=page_table.shape[2],
+        D=D,
+        PAGE_SIZE=PAGE_SIZE,
+        TRITON_RESERVED_BATCH=reserved_batch,
+    )
diff --git a/vllm/compactor-vllm/src/compactor_vllm/kv_cache/write_page_table.py b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/write_page_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99c4e1f566af65c4586c47c727ae671f9c801d7
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/kv_cache/write_page_table.py
@@ -0,0 +1,110 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def scatter_to_page_table(
+    add_pages: torch.Tensor,  # [L, H] int32
+    new_phys_pages: torch.Tensor,  # [N]
+    curr_pages: torch.Tensor,  # [L, H] int32
+    page_table: torch.Tensor,  # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
+    max_pages_per_head: int,
+):
+    """
+    Append newly allocated physical pages into a layered page table via Triton.
+
+    For each (layer ``l``, head ``h``), the ``add_pages[l, h]`` page ids that
+    belong to that pair are written to
+    ``page_table[l, h, curr_pages[l, h] + i]`` for
+    ``i in range(add_pages[l, h])``. The pair's ids start at offset
+    ``sum(add_pages.flatten()[: l * H + h])`` inside ``new_phys_pages``.
+
+    Args:
+        :param add_pages:
+            Tensor of shape ``[L, H]`` (int32) indicating how many pages to
+            append for each (layer, head).
+        :param new_phys_pages:
+            1D tensor of shape ``[N]`` (int32) containing physical page IDs
+            for all (layer, head) pairs, concatenated in row-major (L, H)
+            order. ``N`` must equal ``add_pages.sum()``.
+        :param curr_pages:
+            Tensor of shape ``[L, H]`` (int32) with the current logical page
+            counts per (layer, head) before this update.
+        :param page_table:
+            Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
+            the logical to physical page mapping. It is addressed through its
+            strides, so it may be a non-contiguous view.
+        :param max_pages_per_head:
+            Maximum number of logical pages permitted per (layer, head). The
+            kernel skips writes beyond this bound.
+    Returns:
+        None. The function updates ``page_table`` in-place.
+    """
+    L, H = add_pages.shape
+    if L == 0 or H == 0:
+        return
+    add_flat = add_pages.to(torch.int32).contiguous().view(-1)
+    curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)
+    # Exclusive prefix sum of add_flat (entry 0 stays 0). Allocate it next to
+    # the inputs rather than on a hard-coded "cuda", which would select the
+    # current CUDA device even when the cache lives on another one.
+    cum_page_heads = torch.zeros(
+        L * H + 1, device=add_flat.device, dtype=torch.int32
+    )
+    torch.cumsum(add_flat, 0, out=cum_page_heads[1:])
+    stride_pl, stride_ph, stride_pp = page_table.stride()
+    grid = (L, H)
+    _scatter_pages_kernel_lh[grid](
+        add_flat,
+        cum_page_heads,
+        new_phys_pages,
+        curr_flat,
+        page_table,
+        stride_pl,
+        stride_ph,
+        stride_pp,
+        L=L,
+        H=H,
+        max_pages_per_head=max_pages_per_head,
+    )
+
+
+ # One program per (layer, head): copies that pair's newly allocated physical
+ # page ids from flat_new_phys into the page table, starting at the pair's
+ # current logical page count. Writes at or beyond max_pages_per_head are
+ # skipped.
+ @triton.jit
+ def _scatter_pages_kernel_lh(
+ add_pages, # int32 [L*H]
+ cum_page_heads, # int32 [L*H + 1], exclusive prefix sum of add_pages: base offset in flat_new_phys per (l,h)
+ flat_new_phys, # int32 [total_pages]
+ curr_pages, # int32 [L*H], existing logical pages per (l,h)
+ page_table_ptr, # int32* base pointer to page_table
+ stride_pl, # int, stride for layer dim
+ stride_ph, # int, stride for head dim
+ stride_pp, # int, stride for page dim
+ L: tl.constexpr,
+ H: tl.constexpr,
+ max_pages_per_head: tl.constexpr,
+ ):
+ layer_idx = tl.program_id(0)
+ h = tl.program_id(1)
+ if layer_idx >= L or h >= H:
+ return
+
+ lh = layer_idx * H + h
+ ap = tl.load(add_pages + lh)
+ # Nothing to append for this (layer, head).
+ if ap <= 0:
+ return
+
+ base = tl.load(cum_page_heads + lh)
+ cp = tl.load(curr_pages + lh)
+
+ # Append ap pages: logical pages [cp .. cp+ap)
+ for i in tl.range(0, ap):
+ phys = tl.load(flat_new_phys + base + i)
+ lp = cp + i
+ if lp < max_pages_per_head:
+ offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
+ tl.store(page_table_ptr + offset, phys)
+
+
+ # TODO: write reclaim kernel
+ # Placeholder only: the kernel takes no arguments and has an empty body.
+ @triton.jit
+ def reclaim_page_kernel():
+ pass
+
+
+ def reclaim_pages(
+ batch_index: int,
+ bh_seq_lens: torch.Tensor,
+ bh_num_pages: torch.Tensor,
+ page_table: torch.Tensor,
+ ):
+ """
+ Unimplemented placeholder: does nothing and returns ``None``.
+
+ Presumably the intended host wrapper for ``reclaim_page_kernel`` -- TODO
+ confirm. ``PagedKVCache.reclaim_pages`` in ``page_table.py`` has an
+ implemented host-side version.
+ """
+ pass
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/activation.py b/vllm/compactor-vllm/src/compactor_vllm/layers/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a19e488cf3f5d25670fcdc8f4a17161ca64e1010
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/activation.py
@@ -0,0 +1,13 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SiluAndMul(nn.Module):
+    """
+    Gated SiLU activation.
+
+    Splits the last dimension of the input into two equal halves and returns
+    ``silu(first_half) * second_half``, so the output's last dimension is
+    half the input's.
+    """
+
+    # @torch.compile
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate, up = torch.chunk(x, 2, dim=-1)
+        activated = F.silu(gate)
+        return activated * up
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/attention.py b/vllm/compactor-vllm/src/compactor_vllm/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..adb13677531520e4ba53a61dd41fc711eaa6d5b0
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/attention.py
@@ -0,0 +1,170 @@
+from typing import Optional
+
+import torch
+from compactor_vllm.attention.sparse_decode_kernel import head_sparse_decode_attention
+from compactor_vllm.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+from compactor_vllm.compression.common import extract_and_store_top_kv
+from compactor_vllm.config.engine_config import AttentionBackend
+from compactor_vllm.kv_cache.store_kv_cache import decode_store_kv, prefill_store_all_kv
+from compactor_vllm.utils.context import Context, get_context
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from torch import nn
+
+
+class Attention(nn.Module):
+    """Attention layer backed by a paged KV cache.
+
+    ``forward`` stores the incoming keys/values into the cache (when one is
+    set) and then runs either a prefill or a decode attention path, selected
+    by the context returned from ``get_context()``.
+
+    The cache-related attributes are initialised to ``None`` and nothing in
+    this class assigns them; presumably whoever allocates the KV cache injects
+    them afterwards -- TODO confirm.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim,
+        scale: float,
+        num_kv_heads,
+    ) -> None:
+        super().__init__()
+        self.num_heads: int = num_heads
+        self.head_dim = head_dim
+        self.scale: float = scale
+        self.num_kv_heads = int(num_kv_heads)
+
+        # Paged-cache state; everything starts unset (see class docstring).
+        self.k_cache: Optional[torch.Tensor] = None
+        self.v_cache: Optional[torch.Tensor] = None
+        self.page_table: Optional[torch.Tensor] = None
+        self.bh_seq_lens: Optional[torch.Tensor] = None
+        self.page_size: Optional[int] = None
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        scores: Optional[torch.Tensor] = None,
+    ):
+        """Store ``k``/``v`` in the cache and return the attention output.
+
+        Args:
+            q: Query tensor.
+            k: Key tensor for the new tokens.
+            v: Value tensor for the new tokens.
+            scores: Optional token scores. When given during a prefill with
+                compression enabled (and a cache is set), they are passed to
+                ``extract_and_store_top_kv`` instead of storing every token.
+
+        Returns:
+            The output ``o`` of whichever attention kernel was selected.
+
+        Raises:
+            NotImplementedError: On prefill when the attention backend is
+                neither ``FLASH_ATTENTION`` nor ``COMPACTOR_TRITON``.
+            AssertionError: On decode when ``k_cache`` has not been set.
+        """
+        context: Context = get_context()
+        batch_mapping = context.batch_mapping
+        # Rows of ``bh_seq_lens`` for the sequences in this step, as a fresh
+        # contiguous tensor (``None`` when no length table is attached).
+        seq_lens = (
+            None
+            if self.bh_seq_lens is None
+            else self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
+        )
+        if context.is_prefill:
+            # Snapshot taken before any store call; the Triton prefill kernel
+            # below receives this copy rather than ``seq_lens``.
+            # NOTE(review): presumably the store helpers advance ``seq_lens``
+            # in place, which is why it is written back at the end -- confirm.
+            seq_lens_copy = seq_lens.clone() if seq_lens is not None else None
+            if (
+                self.k_cache is not None
+                and context.do_compression
+                and scores is not None
+            ):
+                # Compressed prefill: hand the scores to the top-k store helper.
+                compression_context = context.compression_context
+                assert scores is not None
+                assert compression_context is not None
+                maybe_execute_in_stream(
+                    extract_and_store_top_kv,
+                    scores=scores,
+                    cu_seqlens_k=context.cu_seqlens_k,
+                    max_k_len=context.max_seqlen_k,
+                    top_k=compression_context.max_tokens_to_retain,
+                    H=int(self.num_kv_heads),
+                    new_keys=k,
+                    new_vals=v,
+                    num_tokens_to_retain=compression_context.batch_tokens_to_retain,
+                    page_table=self.page_table,
+                    batch_mapping=batch_mapping,
+                    bh_lens=seq_lens,
+                    k_cache=self.k_cache,
+                    v_cache=self.v_cache,
+                    PAGE_SIZE=self.page_size,
+                    PAD_TO_PAGE_SIZE=True,
+                    STORE_STREAM=context.STORE_STREAM,
+                )
+            elif self.k_cache is not None:
+                # Uncompressed prefill with a cache: store all new keys/values.
+                maybe_execute_in_stream(
+                    prefill_store_all_kv,
+                    new_keys=k,
+                    new_values=v,
+                    cu_seqlens_k=context.cu_seqlens_k,
+                    max_seqlen_k=context.max_seqlen_k,
+                    k_cache=self.k_cache,
+                    v_cache=self.v_cache,
+                    page_table=self.page_table,
+                    bh_lens=seq_lens,
+                    batch_mapping=batch_mapping,
+                    PAGE_SIZE=self.page_size,
+                    STORE_STREAM=context.STORE_STREAM,
+                )
+
+            # No compression: FA varlen on q,k,v (matches HF). Compressed: Triton reads paged KV.
+            use_flash_prefill = context.attention_backend == AttentionBackend.FLASH_ATTENTION or (
+                context.attention_backend == AttentionBackend.COMPACTOR_TRITON
+                and not context.do_compression
+            )
+            if use_flash_prefill:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    max_seqlen_q=context.max_seqlen_q,
+                    cu_seqlens_q=context.cu_seqlens_q,
+                    max_seqlen_k=context.max_seqlen_k,
+                    cu_seqlens_k=context.cu_seqlens_k,
+                    softmax_scale=self.scale,
+                    causal=True,
+                )
+            elif context.attention_backend == AttentionBackend.COMPACTOR_TRITON:
+                # Top-k KV writes on STORE_STREAM; Triton prefill must see finished writes.
+                if context.do_compression and context.STORE_STREAM is not None:
+                    torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+                o = causal_sparse_varlen_with_cache(
+                    q,
+                    k,
+                    v,
+                    self.k_cache,
+                    self.v_cache,
+                    seq_lens_bh=seq_lens_copy,
+                    global_page_table=self.page_table,
+                    batch_mapping=batch_mapping,
+                    cu_seqlens_q=context.cu_seqlens_q,
+                    max_seqlen_q=context.max_seqlen_q,
+                    max_seqlen_k_cache=context.max_bh_len,
+                    HKV=int(self.num_kv_heads),
+                    PAGE_SIZE=self.page_size,
+                    sm_scale=self.scale,
+                )
+            else:
+                raise NotImplementedError
+        else:
+            # Decode: store the new key/value first (called directly, not via
+            # maybe_execute_in_stream), then attend over the cache.
+            assert self.k_cache is not None, "KV Cache must be initialized for decoding"
+            decode_store_kv(
+                key=k,
+                value=v,
+                batch_mapping=batch_mapping,
+                bh_lens=seq_lens,
+                page_table=self.page_table,
+                k_cache=self.k_cache,
+                v_cache=self.v_cache,
+                PAGE_SIZE=self.page_size,
+            )
+
+            o = head_sparse_decode_attention(
+                q,
+                self.k_cache,
+                self.v_cache,
+                seq_lens,
+                self.page_table,
+                batch_mapping,
+                int(self.num_kv_heads),
+                self.page_size,
+                self.scale,
+                key_split=context.key_split,
+            )
+        if self.bh_seq_lens is not None:
+            # Write this step's rows back into the persistent length table.
+            # The store stream is only passed along for prefill.
+            longbm = batch_mapping.to(torch.long)
+            maybe_execute_in_stream(
+                self.bh_seq_lens.index_copy_,
+                0,
+                longbm,
+                seq_lens,
+                STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
+            )
+        return o
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/embed_head.py b/vllm/compactor-vllm/src/compactor_vllm/layers/embed_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f2bcf414bf7e4bf06ee4ad94128d997c91a3ca5
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/embed_head.py
@@ -0,0 +1,69 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from compactor_vllm.utils.context import get_context
+from torch import nn
+
+
+class VocabParallelEmbedding(nn.Module):
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ ):
+ super().__init__()
+ self.tp_rank = dist.get_rank()
+ self.tp_size = dist.get_world_size()
+ assert num_embeddings % self.tp_size == 0
+ self.num_embeddings = num_embeddings
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
+ self.weight = nn.Parameter(
+ torch.empty(self.num_embeddings_per_partition, embedding_dim)
+ )
+ self.weight.weight_loader = self.weight_loader
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ param_data = param.data
+ shard_size = param_data.size(0)
+ start_idx = self.tp_rank * shard_size
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+ param_data.copy_(loaded_weight)
+
+ def forward(self, x: torch.Tensor):
+ if self.tp_size > 1:
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
+ x = mask * (x - self.vocab_start_idx)
+ y = F.embedding(x, self.weight)
+ if self.tp_size > 1:
+ y = mask.unsqueeze(1) * y
+ dist.all_reduce(y)
+ return y
+
+
+class ParallelLMHead(VocabParallelEmbedding):
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ bias: bool = False,
+ ):
+ assert not bias
+ super().__init__(num_embeddings, embedding_dim)
+
+ def forward(self, x: torch.Tensor):
+ context = get_context()
+ if context.is_prefill:
+ last_indices = context.cu_seqlens_q[1:] - 1
+ x = x[last_indices].contiguous()
+ logits = F.linear(x, self.weight)
+ if self.tp_size > 1:
+ all_logits = (
+ [torch.empty_like(logits) for _ in range(self.tp_size)]
+ if self.tp_rank == 0
+ else None
+ )
+ dist.gather(logits, all_logits, 0)
+ logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+ return logits
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/layernorm.py b/vllm/compactor-vllm/src/compactor_vllm/layers/layernorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dabaad38ce9dec79b9e7c40a1405809c9235f3c
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/layernorm.py
@@ -0,0 +1,49 @@
+import torch
+from torch import nn
+
+
+class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ eps: float = 1e-6,
+ ) -> None:
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+
+ # @torch.compile
+ def rms_forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ orig_dtype = x.dtype
+ x = x.float()
+ var = x.pow(2).mean(dim=-1, keepdim=True)
+ x.mul_(torch.rsqrt(var + self.eps))
+ x = x.to(orig_dtype).mul_(self.weight)
+ return x
+
+ # @torch.compile
+ def add_rms_forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ orig_dtype = x.dtype
+ x = x.float().add_(residual.float())
+ residual = x.to(orig_dtype)
+ var = x.pow(2).mean(dim=-1, keepdim=True)
+ x.mul_(torch.rsqrt(var + self.eps))
+ x = x.to(orig_dtype).mul_(self.weight)
+ return x, residual
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ return self.rms_forward(x)
+ else:
+ return self.add_rms_forward(x, residual)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/linear.py b/vllm/compactor-vllm/src/compactor_vllm/layers/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..cded48352c1290709f27db645af5ab7558f6610f
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/linear.py
@@ -0,0 +1,153 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn
+
+
+def divide(numerator, denominator):
+ assert numerator % denominator == 0
+ return numerator // denominator
+
+
+class LinearBase(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ bias: bool = False,
+ tp_dim: int | None = None,
+ ):
+ super().__init__()
+ self.tp_dim = tp_dim
+ self.tp_rank = dist.get_rank()
+ self.tp_size = dist.get_world_size()
+ self.weight = nn.Parameter(torch.empty(output_size, input_size))
+ self.weight.weight_loader = self.weight_loader
+ if bias:
+ self.bias = nn.Parameter(torch.empty(output_size))
+ self.bias.weight_loader = self.weight_loader
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ bias: bool = False,
+ ):
+ super().__init__(input_size, output_size, bias)
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ param.data.copy_(loaded_weight)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.linear(x, self.weight, self.bias)
+
+
+class ColumnParallelLinear(LinearBase):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ bias: bool = False,
+ ):
+ tp_size = dist.get_world_size()
+ super().__init__(input_size, divide(output_size, tp_size), bias, 0)
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ param_data = param.data
+ shard_size = param_data.size(self.tp_dim)
+ start_idx = self.tp_rank * shard_size
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+ param_data.copy_(loaded_weight)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.linear(x, self.weight, self.bias)
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+ def __init__(
+ self,
+ input_size: int,
+ output_sizes: list[int],
+ bias: bool = False,
+ ):
+ self.output_sizes = output_sizes
+ super().__init__(input_size, sum(output_sizes), bias)
+
+ def weight_loader(
+ self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int
+ ):
+ param_data = param.data
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+ shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+ param_data.copy_(loaded_weight)
+
+
+class QKVParallelLinear(ColumnParallelLinear):
+ def __init__(
+ self,
+ hidden_size: int,
+ head_size: int,
+ total_num_heads: int,
+ total_num_kv_heads: int | None = None,
+ bias: bool = False,
+ ):
+ tp_size = dist.get_world_size()
+ total_num_kv_heads = total_num_kv_heads or total_num_heads
+ self.head_size = head_size
+ self.num_heads = divide(total_num_heads, tp_size)
+ self.num_kv_heads = divide(total_num_kv_heads, tp_size)
+ output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
+ super().__init__(hidden_size, output_size, bias)
+
+ def weight_loader(
+ self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str
+ ):
+ param_data = param.data
+ assert loaded_shard_id in ["q", "k", "v"]
+ if loaded_shard_id == "q":
+ shard_size = self.num_heads * self.head_size
+ shard_offset = 0
+ elif loaded_shard_id == "k":
+ shard_size = self.num_kv_heads * self.head_size
+ shard_offset = self.num_heads * self.head_size
+ else:
+ shard_size = self.num_kv_heads * self.head_size
+ shard_offset = (
+ self.num_heads * self.head_size + self.num_kv_heads * self.head_size
+ )
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+ param_data.copy_(loaded_weight)
+
+
+class RowParallelLinear(LinearBase):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ bias: bool = False,
+ ):
+ tp_size = dist.get_world_size()
+ super().__init__(divide(input_size, tp_size), output_size, bias, 1)
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ param_data = param.data
+ shard_size = param_data.size(self.tp_dim)
+ start_idx = self.tp_rank * shard_size
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+ param_data.copy_(loaded_weight)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
+ if self.tp_size > 1:
+ dist.all_reduce(y)
+ return y
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/moe.py b/vllm/compactor-vllm/src/compactor_vllm/layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..c71b99df34bd10a140b6d7c7a6bb873cb521009b
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/moe.py
@@ -0,0 +1,164 @@
+import torch
+import torch.distributed as dist
+from compactor_vllm.triton_kernels.matmul_ogs import matmul_ogs
+from torch import nn
+
+
+def divide(numerator, denominator):
+ assert numerator % denominator == 0
+ return numerator // denominator
+
+
+class TritonFusedMoeLinearBase(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ num_experts: int,
+ bias: bool = False,
+ tp_dim: int | None = None,
+ ) -> None:
+ super().__init__()
+ self.tp_dim = tp_dim
+ self.tp_rank = dist.get_rank()
+ self.tp_size = dist.get_world_size()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.num_experts = num_experts
+
+ self.weight = nn.Parameter(
+ torch.empty((num_experts, in_features, out_features)).transpose(-1, -2)
+ )
+ self.weight.weight_loader = self.weight_loader
+
+ if bias:
+ self.bias = nn.Parameter(torch.empty((num_experts, out_features)))
+ self.bias.weight_loader = self.weight_loader
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ raise NotImplementedError
+
+
+class ReplicatedTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ num_experts: int,
+ bias: bool = False,
+ ) -> None:
+ super().__init__(in_features, out_features, num_experts, bias)
+
+ def weight_loader(
+ self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+ ):
+ param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ w = self.weight.transpose(-1, -2)
+ assert w.is_contiguous()
+ return matmul_ogs(
+ x,
+ self.weight,
+ self.bias,
+ **kwargs,
+ )
+
+
+class RowParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ num_experts: int,
+ bias: bool = False,
+ ) -> None:
+ tp_size = dist.get_world_size() if dist.is_initialized() else 1
+ super().__init__(
+ divide(in_features, tp_size), out_features, num_experts, bias, 2
+ )
+
+ def weight_loader(
+ self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+ ):
+ shard_size = param.size(2)
+ start_idx = self.tp_rank * shard_size
+ local_shard = loaded_weight[:, start_idx : start_idx + shard_size]
+ param.data[expert_idx].copy_(local_shard, non_blocking=True)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ w = self.weight.transpose(-1, -2)
+ assert w.is_contiguous()
+ y = matmul_ogs(
+ x,
+ w,
+ self.bias,
+ **kwargs,
+ )
+ if self.tp_size > 1:
+ dist.all_reduce(y)
+ return y
+
+
+class ColumnParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ num_experts: int,
+ bias: bool = False,
+ ) -> None:
+ tp_size = dist.get_world_size() if dist.is_initialized() else 1
+ super().__init__(
+ in_features, divide(out_features, tp_size), num_experts, bias, 1
+ )
+
+ def weight_loader(
+ self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+ ):
+ shard_size = param.size(1)
+ start_idx = self.tp_rank * shard_size
+ local_shard = loaded_weight[start_idx : start_idx + shard_size, :]
+ param.data[expert_idx].copy_(local_shard, non_blocking=True)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ w = self.weight.transpose(-1, -2)
+ assert w.is_contiguous()
+ y = matmul_ogs(
+ x,
+ w,
+ self.bias,
+ **kwargs,
+ )
+ return y
+
+
+class MergedColumnParallelTritonFusedMoeLinear(ColumnParallelTritonFusedMoeLinear):
+ def __init__(
+ self,
+ in_features: int,
+ out_feature_list: list[int],
+ num_experts: int,
+ bias: bool = False,
+ ):
+ self.out_feature_list = out_feature_list
+ super().__init__(in_features, sum(out_feature_list), num_experts, bias)
+
+ def weight_loader(
+ self,
+ param: nn.Parameter,
+ loaded_weight: torch.Tensor,
+ expert_idx: int,
+ shard_id: int,
+ ):
+ param_data = param.data
+ shard_offset = sum(self.out_feature_list[:shard_id]) // self.tp_size
+ shard_size = self.out_feature_list[shard_id] // self.tp_size
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+ local_weight = loaded_weight.chunk(self.tp_size, dim=self.tp_dim - 1)[
+ self.tp_rank
+ ]
+ param_data[expert_idx].copy_(local_weight, non_blocking=True)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/rotary_embedding.py b/vllm/compactor-vllm/src/compactor_vllm/layers/rotary_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..506616f912a57ff1dcf2543d62ec096b258e31d6
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/rotary_embedding.py
@@ -0,0 +1,94 @@
+import math
+from functools import lru_cache
+
+import torch
+from torch import nn
+
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> torch.Tensor:
+ x1, x2 = torch.chunk(x.float(), 2, dim=-1)
+ y1 = x1 * cos - x2 * sin
+ y2 = x2 * cos + x1 * sin
+ return torch.cat((y1, y2), dim=-1).to(x.dtype)
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: float,
+ rope_scaling: tuple,
+ ) -> None:
+ super().__init__()
+ self.head_size = head_size
+ assert rotary_dim == head_size
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
+ )
+ if rope_scaling is not None:
+ (
+ rope_type,
+ factor,
+ low_freq_factor,
+ high_freq_factor,
+ original_max_position_embeddings,
+ ) = rope_scaling
+ assert rope_type == "llama3"
+ old_context_len = original_max_position_embeddings
+ low_freq_wavelen = old_context_len / low_freq_factor
+ high_freq_wavelen = old_context_len / high_freq_factor
+ wavelen = 2 * math.pi / inv_freq
+
+ inv_freq_llama = torch.where(
+ wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
+ )
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
+ high_freq_factor - low_freq_factor
+ )
+ smoothed_inv_freq = (
+ 1 - smooth_factor
+ ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(
+ wavelen > low_freq_wavelen
+ )
+ inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
+ self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+ # @torch.compile
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ query = apply_rotary_emb(query, cos, sin)
+ key = apply_rotary_emb(key, cos, sin)
+ return query, key
+
+
+@lru_cache(1)
+def get_rope(
+ head_size: int,
+ rotary_dim: int,
+ max_position: int,
+ base: float,
+ rope_scaling: tuple | None = None,
+):
+ rotary_emb = RotaryEmbedding(
+ head_size, rotary_dim, max_position, base, rope_scaling
+ )
+ return rotary_emb
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/sampler.py b/vllm/compactor-vllm/src/compactor_vllm/layers/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0761b7c79bc7dc511180078c4d059c3423b3f8f
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/sampler.py
@@ -0,0 +1,27 @@
+import torch
+from torch import nn
+
+
+class Sampler(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ # @torch.compile
+ def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+ temps = temperatures.view(-1)
+ scaled = logits.float()
+
+ greedy_mask = temps == 0.0
+ sample_mask = ~greedy_mask
+
+ if sample_mask.any():
+ temps_sample = temps[sample_mask].unsqueeze(-1) # [B_sample, 1]
+ scaled_sample = scaled[sample_mask].div(temps_sample) # temperature scaling
+
+ E = torch.empty_like(scaled_sample).exponential_(1).clamp_min_(1e-10).log()
+ scaled_sample = scaled_sample - E
+
+ scaled = scaled.clone()
+ scaled[sample_mask] = scaled_sample
+
+ return scaled.argmax(dim=-1)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/layers/triton_helpers.py b/vllm/compactor-vllm/src/compactor_vllm/layers/triton_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c1a31669bac31c9fcef53259f1211e6de19bc37
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/layers/triton_helpers.py
@@ -0,0 +1,101 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _masked_index_select_kernel(
+    X_ptr,
+    IDX_ptr,
+    OUT_ptr,
+    N,
+    stride_xn,
+    stride_xh,
+    stride_ob,
+    stride_oh,
+):
+    # Each program handles one output element: grid axis 0 selects the output
+    # row ``b`` and grid axis 1 selects the column ``h``.
+    b = tl.program_id(0)  # which output row (0..B-1)
+    h = tl.program_id(1)
+    idx = tl.load(IDX_ptr + b)  # source row requested for output row ``b``
+    valid = (idx >= 0) & (idx < N)
+    out_ptrs = OUT_ptr + b * stride_ob + h * stride_oh
+
+    if not valid:
+        # Index outside [0, N): write zero instead of reading from X.
+        tl.store(out_ptrs, 0)
+    else:
+        x_ptrs = X_ptr + idx * stride_xn + h * stride_xh
+        vals = tl.load(x_ptrs)
+        tl.store(out_ptrs, vals)
+
+
+def masked_index_select_triton_dim0(
+ input: torch.Tensor, index: torch.Tensor
+) -> torch.Tensor:
+ """
+ X: [N, H] : contiguous in the H dimension
+ b_m: [B] int32/int64 on same device; out-of-range -> zeros)
+ Returns: [B, H]
+ """
+ assert input.ndim == 2 and index.ndim == 1
+ N, H = input.shape
+ B = index.numel()
+ out = torch.empty((B, H), dtype=input.dtype, device=input.device)
+ _masked_index_select_kernel[(B, H)](
+ input,
+ index,
+ out,
+ N,
+ input.stride(0),
+ input.stride(1),
+ out.stride(0),
+ out.stride(1),
+ )
+ return out
+
+
+@triton.jit
+def _masked_index_copy_kernel(
+    DST_ptr,
+    IDX_ptr,
+    SRC_ptr,
+    N,
+    stride_dn,
+    stride_dh,
+    stride_sb,
+    stride_sh,
+):
+    # Each program handles one source element: grid axis 0 selects the source
+    # row ``b`` and grid axis 1 selects the column ``h``.
+    b = tl.program_id(0)
+    h = tl.program_id(1)
+    idx = tl.load(IDX_ptr + b)  # destination row for source row ``b``
+    valid = (idx >= 0) & (idx < N)
+    if valid:
+        # Indices outside [0, N) are skipped, leaving DST untouched.
+        src_ptrs = SRC_ptr + b * stride_sb + h * stride_sh
+        dst_ptrs = DST_ptr + idx * stride_dn + h * stride_dh
+        tl.store(dst_ptrs, tl.load(src_ptrs))
+
+
+def masked_index_copy_triton_dim0(
+ dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor
+):
+ """
+ In-place: dst.index_copy_(0, index, src) but masked:
+ - rows with index[b] < 0 or >= dst.shape[0] are skipped (no write).
+ Shapes:
+ dst: [N, H]
+ src: [B, H]
+ index: [B]
+ """
+ assert dst.ndim == 2 and src.ndim == 2 and index.ndim == 1
+ N, H = dst.shape
+ B, Hs = src.shape
+ assert Hs == H and index.numel() == B
+ _masked_index_copy_kernel[(B, H)](
+ dst,
+ index,
+ src,
+ N,
+ dst.stride(0),
+ dst.stride(1),
+ src.stride(0),
+ src.stride(1),
+ )
diff --git a/vllm/compactor-vllm/src/compactor_vllm/models/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae530206ed6b941656eb95a7adae2fc6c1638a3c
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/models/__init__.py
@@ -0,0 +1,18 @@
+import logging
+
+from compactor_vllm.models.llama3 import LlamaForCausalLM
+from compactor_vllm.models.qwen3 import Qwen3ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+# Registry of supported architectures: string key -> model class.
+# NOTE(review): the keys presumably match the Hugging Face ``model_type``
+# string -- confirm against the code that looks models up.
+MODEL_REGISTRY = {
+    "llama": LlamaForCausalLM,
+    "qwen3": Qwen3ForCausalLM,
+}
+
+# Qwen3-MoE is registered only when its module imports cleanly. Any exception
+# raised by the import is logged at DEBUG level and otherwise ignored, so the
+# package still imports without that model.
+try:
+    from compactor_vllm.models.qwen3_moe import Qwen3MoeForCausalLM
+except Exception as exc:
+    logger.debug("Skipping qwen3_moe registration due to import error: %s", exc)
+else:
+    MODEL_REGISTRY["qwen3_moe"] = Qwen3MoeForCausalLM
diff --git a/vllm/compactor-vllm/src/compactor_vllm/models/llama3.py b/vllm/compactor-vllm/src/compactor_vllm/models/llama3.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e5587f6881bf039fa951c6fc8d0b0c16f0ba61
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/models/llama3.py
@@ -0,0 +1,281 @@
+import os
+from glob import glob
+
+import torch
+import torch.distributed as dist
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import LlamaConfig
+
+from compactor_vllm.compression import (
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from compactor_vllm.layers.activation import SiluAndMul
+from compactor_vllm.layers.attention import Attention
+from compactor_vllm.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from compactor_vllm.layers.layernorm import RMSNorm
+from compactor_vllm.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from compactor_vllm.layers.rotary_embedding import get_rope
+from compactor_vllm.utils.context import get_context
+
+
+class LlamaAttention(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ max_position: int = 4096 * 32,
+ head_dim: int | None = None,
+ qkv_bias: bool = False,
+ rope_theta: float = 10000,
+ rope_scaling: dict | None = None,
+ ) -> None:
+ super().__init__()
+ tp_size = dist.get_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ assert self.total_num_kv_heads % tp_size == 0
+ self.num_kv_heads = self.total_num_kv_heads // tp_size
+ self.head_dim = head_dim or hidden_size // self.total_num_heads
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=qkv_bias,
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ )
+ if rope_scaling is not None:
+ rope_scaling_tuple = (
+ rope_scaling["rope_type"],
+ rope_scaling["factor"],
+ rope_scaling["low_freq_factor"],
+ rope_scaling["high_freq_factor"],
+ rope_scaling["original_max_position_embeddings"],
+ )
+ else:
+ rope_scaling_tuple = None
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position,
+ base=rope_theta,
+ rope_scaling=rope_scaling_tuple,
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ self.num_kv_heads,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ context = get_context()
+ qkv = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q = q.view(-1, self.num_heads, self.head_dim)
+ k = k.view(-1, self.num_kv_heads, self.head_dim)
+ v = v.view(-1, self.num_kv_heads, self.head_dim)
+ scores = None
+ if context.is_prefill and context.do_compression:
+ scores = apply_prerope_compression(q, k, v, context)
+
+ q, k = self.rotary_emb(positions, q, k)
+
+ if context.is_prefill and context.do_compression:
+ scores = apply_postrope_compression(q, k, v, scores, context)
+
+ o = self.attn(q, k, v, scores)
+ output = self.o_proj(o.flatten(1, -1))
+ return output
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ mlp_bias: bool,
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=mlp_bias,
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=mlp_bias,
+ )
+ assert hidden_act == "silu"
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x = self.down_proj(x)
+ return x
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ config: LlamaConfig,
+ ) -> None:
+ super().__init__()
+ self.self_attn = LlamaAttention(
+ hidden_size=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ max_position=config.max_position_embeddings,
+ qkv_bias=getattr(config, "attention_bias", False),
+ head_dim=getattr(config, "head_dim", None),
+ rope_theta=getattr(config, "rope_theta", 500000.0),
+ rope_scaling=getattr(config, "rope_scaling", None),
+ )
+ self.mlp = LlamaMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ mlp_bias=config.mlp_bias,
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+ hidden_states = self.self_attn(positions, hidden_states)
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+class LlamaModel(nn.Module):
+ def __init__(
+ self,
+ config: LlamaConfig,
+ ) -> None:
+ super().__init__()
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size, config.hidden_size
+ )
+ self.layers = nn.ModuleList(
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+ residual = None
+ for layer in self.layers:
+ hidden_states, residual = layer(positions, hidden_states, residual)
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class LlamaForCausalLM(nn.Module):
+ packed_modules_mapping = {
+ "q_proj": ("qkv_proj", "q"),
+ "k_proj": ("qkv_proj", "k"),
+ "v_proj": ("qkv_proj", "v"),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
+ def __init__(self, config: LlamaConfig) -> None:
+ super().__init__()
+ self.model = LlamaModel(config)
+ self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+ if config.tie_word_embeddings:
+ self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.model(input_ids, positions)
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.lm_head(hidden_states)
+
+ def load_model(
+ self,
+ path: str,
+ *,
+ use_tqdm: bool = False,
+ ) -> None:
+ all_shards = glob(os.path.join(path, "*.safetensors"))
+ for file in (
+ tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
+ ):
+ with safe_open(file, "pt", "cpu") as f:
+ for weight_name in f.keys():
+ weight_tensor = f.get_tensor(weight_name)
+ is_loaded = False
+
+ # Load packed modules
+ for k in self.packed_modules_mapping:
+ if k in weight_name:
+ v, shard_id = self.packed_modules_mapping[k]
+ param_name = weight_name.replace(k, v)
+ param = self.get_parameter(param_name)
+ weight_loader = getattr(param, "weight_loader")
+ weight_loader(param, weight_tensor, shard_id)
+ is_loaded = True
+ break
+
+ # Load other modules
+
+ if not is_loaded:
+ param = self.get_parameter(weight_name)
+ weight_loader = getattr(
+ param,
+ "weight_loader",
+ lambda p, loaded_weight: p.data.copy_(loaded_weight),
+ )
+ weight_loader(param, weight_tensor)
+ is_loaded = True
+
+ assert is_loaded, f"Weight {weight_name} not loaded"
diff --git a/vllm/compactor-vllm/src/compactor_vllm/models/qwen3.py b/vllm/compactor-vllm/src/compactor_vllm/models/qwen3.py
new file mode 100644
index 0000000000000000000000000000000000000000..37000fa1752cef5cad0bd79140ded42d6b57b59f
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/models/qwen3.py
@@ -0,0 +1,286 @@
+import os
+from glob import glob
+
+import torch
+import torch.distributed as dist
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import Qwen3Config
+
+from compactor_vllm.compression import (
+ CompressionMethod,
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from compactor_vllm.layers.activation import SiluAndMul
+from compactor_vllm.layers.attention import Attention
+from compactor_vllm.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from compactor_vllm.layers.layernorm import RMSNorm
+from compactor_vllm.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from compactor_vllm.layers.rotary_embedding import get_rope
+from compactor_vllm.utils.context import get_context
+
+
+class Qwen3Attention(nn.Module):
+    """Qwen3 self-attention layer.
+
+    Pipeline: fused QKV projection, RMSNorm on the per-head Q and K, rotary
+    embedding, ``Attention``, output projection. During prefill with
+    compression enabled, compression scores are computed before and after the
+    rotary embedding and passed on to ``Attention``.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+    ) -> None:
+        super().__init__()
+        world_size = dist.get_world_size()
+
+        # Query heads and KV heads must each divide evenly across the ranks.
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % world_size == 0
+        self.num_heads = self.total_num_heads // world_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % world_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // world_size
+
+        # An explicit head_dim wins; otherwise derive it from the hidden size.
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def _publish_wo_weight(self, cc) -> None:
+        """Store a per-head float32 view of the output projection on ``cc``.
+
+        Only done when ``cc`` exists and its method is CRITICALADAKV.
+        """
+        if cc is None or cc.compression_method != CompressionMethod.CRITICALADAKV:
+            return
+        # Key step: inject wo_weight into the compression context.
+        wo_raw = self.o_proj.weight
+        out_features, _ = wo_raw.shape
+        per_head = wo_raw.transpose(0, 1).contiguous()
+        per_head = per_head.view(self.num_heads, self.head_dim, out_features)
+        cc.wo_weight = per_head.to(dtype=torch.float32)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Attend over ``hidden_states`` and return the output projection."""
+        ctx = get_context()
+        q, k, v = self.qkv_proj(hidden_states).split(
+            [self.q_size, self.kv_size, self.kv_size], dim=-1
+        )
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+
+        # Pre-RoPE scoring receives v before it is reshaped per head.
+        if ctx.is_prefill and ctx.do_compression:
+            scores = apply_prerope_compression(q, k, v, ctx)
+        else:
+            scores = None
+
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        q, k = self.rotary_emb(positions, q, k)
+
+        if ctx.is_prefill and ctx.do_compression:
+            self._publish_wo_weight(ctx.compression_context)
+            scores = apply_postrope_compression(q, k, v, scores, ctx)
+
+        attn_out = self.attn(q, k, v, scores)
+        return self.o_proj(attn_out.flatten(1, -1))
+
+
+class Qwen3MLP(nn.Module):
+    """Gated feed-forward block: merged gate/up projection, ``SiluAndMul``, down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        # Gate and up projections share one merged linear layer.
+        merged_sizes = [intermediate_size, intermediate_size]
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, merged_sizes, bias=False
+        )
+        self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False)
+        # Only SiLU gating is implemented.
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
+        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
+class Qwen3DecoderLayer(nn.Module):
+    """One Qwen3 layer: attention and MLP, each preceded by an RMSNorm that also carries the residual."""
+
+    def __init__(self, config: Qwen3Config) -> None:
+        super().__init__()
+        hidden_size = config.hidden_size
+        eps = config.rms_norm_eps
+        self.self_attn = Qwen3Attention(
+            hidden_size=hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            rms_norm_eps=eps,
+            # Optional config fields fall back to the defaults below.
+            qkv_bias=getattr(config, "attention_bias", False),
+            head_dim=getattr(config, "head_dim", None),
+            rope_theta=getattr(config, "rope_theta", 1000000),
+            rope_scaling=getattr(config, "rope_scaling", None),
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(hidden_size, eps=eps)
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Apply the layer and return ``(hidden_states, residual)``.
+
+        ``residual`` is ``None`` for the first layer; the incoming hidden
+        states then become the residual.
+        """
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            # Two-argument call: the norm returns the updated residual as well.
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        attn_out = self.self_attn(positions, hidden_states)
+        normed, residual = self.post_attention_layernorm(attn_out, residual)
+        return self.mlp(normed), residual
+
+
+class Qwen3Model(nn.Module):
+    """Qwen3 backbone: token embedding, a stack of decoder layers and a final RMSNorm."""
+
+    def __init__(self, config: Qwen3Config) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        decoder_layers = []
+        for _ in range(config.num_hidden_layers):
+            decoder_layers.append(Qwen3DecoderLayer(config))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embed ``input_ids``, run every layer, and return the normalized hidden states."""
+        hidden_states = self.embed_tokens(input_ids)
+        # The first layer receives no residual and creates it.
+        residual = None
+        for decoder_layer in self.layers:
+            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
+        # The final norm returns a pair; only the first element is used.
+        normed, _ = self.norm(hidden_states, residual)
+        return normed
+
+
+class Qwen3ForCausalLM(nn.Module):
+    """Qwen3 backbone plus LM head, with a safetensors checkpoint loader."""
+
+    # Checkpoint sub-name -> (fused parameter sub-name, shard id given to the
+    # fused parameter's ``weight_loader``).
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(self, config: Qwen3Config) -> None:
+        super().__init__()
+        self.model = Qwen3Model(config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            # Tied embeddings: make the head's weight use the embedding table's data.
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the backbone's hidden states; logits come from ``compute_logits``."""
+        return self.model(input_ids, positions)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply ``self.lm_head`` to ``hidden_states``."""
+        return self.lm_head(hidden_states)
+
+    def load_model(
+        self,
+        path: str,
+        *,
+        use_tqdm: bool = False,
+    ) -> None:
+        """Copy every tensor stored in the ``*.safetensors`` files under ``path`` into this model.
+
+        A checkpoint name that contains a key of ``packed_modules_mapping`` is
+        redirected to the mapped (fused) parameter and passed to that
+        parameter's ``weight_loader`` together with the shard id. Any other
+        name is looked up unchanged and loaded with the parameter's
+        ``weight_loader`` when it has one, otherwise with a plain ``copy_``.
+
+        Args:
+            path: Directory that holds the checkpoint shards.
+            use_tqdm: Wrap the shard loop in a ``tqdm`` progress bar.
+
+        Raises:
+            FileNotFoundError: ``path`` contains no ``*.safetensors`` file.
+            AttributeError: a (remapped) checkpoint name does not resolve to a
+                parameter of this module (raised by ``get_parameter``).
+        """
+        # glob() returns files in arbitrary order; sort for reproducible loading.
+        all_shards = sorted(glob(os.path.join(path, "*.safetensors")))
+        if not all_shards:
+            # An empty directory used to be a silent no-op that left the model
+            # with whatever weights it was constructed with.
+            raise FileNotFoundError(f"No *.safetensors files found in {path!r}")
+
+        shard_iter = (
+            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
+        )
+        for file in shard_iter:
+            with safe_open(file, "pt", "cpu") as f:
+                for weight_name in f.keys():
+                    weight_tensor = f.get_tensor(weight_name)
+
+                    # Packed modules: the first mapping key found in the name wins.
+                    for k, (v, shard_id) in self.packed_modules_mapping.items():
+                        if k in weight_name:
+                            param = self.get_parameter(weight_name.replace(k, v))
+                            param.weight_loader(param, weight_tensor, shard_id)
+                            break
+                    else:
+                        # Everything else is loaded under its checkpoint name.
+                        param = self.get_parameter(weight_name)
+                        weight_loader = getattr(
+                            param,
+                            "weight_loader",
+                            lambda p, loaded_weight: p.data.copy_(loaded_weight),
+                        )
+                        weight_loader(param, weight_tensor)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/models/qwen3_moe.py b/vllm/compactor-vllm/src/compactor_vllm/models/qwen3_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e56637e9ccc71f310be33ed2f53a0069aa5015
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/models/qwen3_moe.py
@@ -0,0 +1,378 @@
+import os
+from glob import glob
+
+import torch
+import torch.distributed as dist
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import Qwen3MoeConfig
+
+from compactor_vllm.compression import (
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from compactor_vllm.layers.activation import SiluAndMul
+from compactor_vllm.layers.attention import Attention
+from compactor_vllm.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from compactor_vllm.layers.layernorm import RMSNorm
+from compactor_vllm.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from compactor_vllm.layers.moe import (
+ MergedColumnParallelTritonFusedMoeLinear,
+ RowParallelTritonFusedMoeLinear,
+)
+from compactor_vllm.layers.rotary_embedding import get_rope
+from compactor_vllm.triton_kernels.routing import routing
+from compactor_vllm.utils.context import get_context
+
+
+class Qwen3MoeAttention(nn.Module):
+    """Self-attention layer of the Qwen3-MoE model.
+
+    Pipeline: fused QKV projection, RMSNorm on the per-head Q and K, rotary
+    embedding, ``Attention``, output projection. During prefill with
+    compression enabled, compression scores are computed before and after the
+    rotary embedding and passed on to ``Attention``.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+        sliding_window: int | None = None,
+    ) -> None:
+        super().__init__()
+        world_size = dist.get_world_size()
+
+        # Query heads and KV heads must each divide evenly across the ranks.
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % world_size == 0
+        self.num_heads = self.total_num_heads // world_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % world_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // world_size
+
+        # An explicit head_dim wins; otherwise derive it from the hidden size.
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        # NOTE(review): stored but never read inside this class.
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Attend over ``hidden_states`` and return the output projection."""
+        ctx = get_context()
+        q, k, v = self.qkv_proj(hidden_states).split(
+            [self.q_size, self.kv_size, self.kv_size], dim=-1
+        )
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+
+        # Pre-RoPE scoring receives v before it is reshaped per head.
+        if ctx.is_prefill and ctx.do_compression:
+            scores = apply_prerope_compression(q, k, v, ctx)
+        else:
+            scores = None
+
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        q, k = self.rotary_emb(positions, q, k)
+
+        # NOTE(review): nothing is attached to the compression context here
+        # before post-RoPE scoring - confirm whether any compression method
+        # needs extra per-layer data (e.g. output-projection weights).
+        if ctx.is_prefill and ctx.do_compression:
+            scores = apply_postrope_compression(q, k, v, scores, ctx)
+
+        attn_out = self.attn(q, k, v, scores)
+        return self.o_proj(attn_out.flatten(1, -1))
+
+
+class Qwen3MoeMLP(nn.Module):
+    """Dense gated feed-forward block: merged gate/up projection, ``SiluAndMul``, down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        # Gate and up projections share one merged linear layer.
+        merged_sizes = [intermediate_size, intermediate_size]
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, merged_sizes, bias=False
+        )
+        self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False)
+        # Only SiLU gating is implemented.
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
+        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
+class Qwen3MoeTritonSparseMoeBlock(nn.Module):
+    """Mixture-of-experts feed-forward block built on the Triton routing and fused expert layers."""
+
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        num_experts_per_tok: int,
+        norm_topk_prob: bool,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        # NOTE(review): norm_topk_prob is stored but not read by forward(), and
+        # hidden_act is not checked here - SiluAndMul is always used.
+        self.norm_topk_prob = norm_topk_prob
+        self.hidden_size = hidden_size
+        self.moe_intermediate_size = intermediate_size
+
+        # Router: one logit per expert.
+        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False)
+        merged_sizes = [intermediate_size, intermediate_size]
+        self.gate_up_proj = MergedColumnParallelTritonFusedMoeLinear(
+            hidden_size, merged_sizes, num_experts
+        )
+        self.down_proj = RowParallelTritonFusedMoeLinear(
+            intermediate_size, hidden_size, num_experts
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Route ``hidden_states`` through the selected experts and return the result."""
+        # Nothing to route: hand the empty input straight back.
+        if hidden_states.numel() == 0:
+            return hidden_states
+        router_logits = self.gate(hidden_states)
+        routing_data, gather_indx, scatter_indx = routing(
+            router_logits,
+            self.num_experts_per_tok,
+            simulated_ep=1,  # single device, replicated experts
+        )
+        expert_in = self.gate_up_proj(
+            hidden_states, routing_data=routing_data, gather_indx=gather_indx
+        )
+        activated = self.act_fn(expert_in)
+        return self.down_proj(
+            activated,
+            routing_data=routing_data,
+            scatter_indx=scatter_indx,
+            gammas=routing_data.gate_scal,
+        )
+
+
+class Qwen3MoeBlock(Qwen3MoeTritonSparseMoeBlock):
+    """Alias subclass of ``Qwen3MoeTritonSparseMoeBlock``; adds nothing of its own."""
+
+
+class Qwen3MoeRMSNorm(RMSNorm):
+    """Alias subclass of ``RMSNorm``; adds nothing of its own."""
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    """One Qwen3-MoE layer: pre-norm attention and a pre-norm MLP (sparse or dense), each with a residual add."""
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        layer_idx: int,
+    ) -> None:
+        super().__init__()
+        self.self_attn = Qwen3MoeAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            head_dim=getattr(config, "head_dim", None),
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, "attention_bias", False),
+            rope_theta=config.rope_theta,
+            rope_scaling=config.rope_scaling,
+            sliding_window=config.sliding_window,
+        )
+
+        # A layer is sparse unless it is listed in mlp_only_layers, the model
+        # has no experts, or it falls off the decoder_sparse_step cadence.
+        use_sparse_mlp = (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        )
+        if not use_sparse_mlp:
+            self.mlp = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+            )
+        else:
+            self.mlp = Qwen3MoeBlock(
+                num_experts=config.num_experts,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                num_experts_per_tok=config.num_experts_per_tok,
+                norm_topk_prob=config.norm_topk_prob,
+                hidden_act=config.hidden_act,
+            )
+        self.input_layernorm = Qwen3MoeRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply attention then the MLP, adding the block input back after each."""
+        # Self-attention sub-block.
+        attn_input = self.input_layernorm(hidden_states)
+        attn_out = self.self_attn(positions, attn_input)
+        hidden_states = hidden_states + attn_out
+
+        # Feed-forward sub-block.
+        mlp_input = self.post_attention_layernorm(hidden_states)
+        mlp_out = self.mlp(mlp_input)
+        return hidden_states + mlp_out
+
+
+class Qwen3MoeModel(nn.Module):
+    """Qwen3-MoE backbone: token embedding, a stack of decoder layers and a final norm."""
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        # Each layer is told its index so it can choose a sparse or dense MLP.
+        decoder_layers = []
+        for layer_idx in range(config.num_hidden_layers):
+            decoder_layers.append(Qwen3MoeDecoderLayer(config, layer_idx))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embed ``input_ids``, run every layer, and return the normalized hidden states."""
+        hidden_states = self.embed_tokens(input_ids)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, position_ids)
+        return self.norm(hidden_states)
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+    """Qwen3-MoE backbone plus LM head, with a safetensors checkpoint loader."""
+
+    # Checkpoint sub-name -> (fused parameter sub-name, shard id given to the
+    # fused parameter's ``weight_loader``).
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+    ) -> None:
+        super().__init__()
+        self.model = Qwen3MoeModel(config)
+        self.num_experts = config.num_experts
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            # Tied embeddings: make the head's weight use the embedding table's data.
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the backbone's hidden states; logits come from ``compute_logits``."""
+        return self.model(input_ids, position_ids)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply ``self.lm_head`` to ``hidden_states``."""
+        return self.lm_head(hidden_states)
+
+    def load_model(
+        self,
+        path: str,
+        *,
+        use_tqdm: bool = False,
+    ) -> None:
+        """Copy every tensor stored in the ``*.safetensors`` files under ``path`` into this model.
+
+        Per-expert checkpoint names (``<mlp>.experts.<idx>.<proj>...``) are
+        rewritten to ``<mlp>.<proj>...`` and the expert index is passed to the
+        parameter's ``weight_loader``. A name that contains a key of
+        ``packed_modules_mapping`` is redirected to the mapped (fused)
+        parameter and also receives the shard id. Any other parameter is loaded
+        with its ``weight_loader`` when it has one, otherwise with ``copy_``.
+
+        Args:
+            path: Directory that holds the checkpoint shards.
+            use_tqdm: Wrap the shard loop in a ``tqdm`` progress bar.
+
+        Raises:
+            FileNotFoundError: ``path`` contains no ``*.safetensors`` file.
+            AttributeError: a (remapped) checkpoint name does not resolve to a
+                parameter of this module (raised by ``get_parameter``).
+        """
+        rank = dist.get_rank()
+        # glob() returns files in arbitrary order; sort for reproducible loading.
+        all_shards = sorted(glob(os.path.join(path, "*.safetensors")))
+        if not all_shards:
+            # An empty directory used to be a silent no-op that left the model
+            # with whatever weights it was constructed with.
+            raise FileNotFoundError(f"No *.safetensors files found in {path!r}")
+
+        shard_iter = (
+            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
+        )
+        for file in shard_iter:
+            # NOTE(review): the distributed rank is used as the CUDA device
+            # index - confirm this holds for multi-node launches.
+            with safe_open(file, "pt", f"cuda:{rank}") as f:
+                for weight_name in f.keys():
+                    weight_tensor = f.get_tensor(weight_name)
+                    is_expert = "mlp.experts" in weight_name
+                    expert_idx = None
+
+                    # Strip the expert index out of per-expert names.
+                    if is_expert:
+                        mlp_module_name, expert_module_name = weight_name.split(
+                            ".experts."
+                        )
+                        # partition() removes only the leading "<idx>." segment;
+                        # the previous str.replace() removed every occurrence
+                        # and mis-handled zero-padded indices.
+                        idx_text, _, proj_name = expert_module_name.partition(".")
+                        expert_idx = int(idx_text)
+                        weight_name = f"{mlp_module_name}.{proj_name}"
+
+                    # Packed modules: the first mapping key found in the name wins.
+                    for k, (v, shard_id) in self.packed_modules_mapping.items():
+                        if k not in weight_name:
+                            continue
+                        param = self.get_parameter(weight_name.replace(k, v))
+                        weight_loader = param.weight_loader
+                        if is_expert:
+                            weight_loader(param, weight_tensor, expert_idx, shard_id)
+                        else:
+                            weight_loader(param, weight_tensor, shard_id)
+                        break
+                    else:
+                        # Everything else is loaded under its (rewritten) name.
+                        param = self.get_parameter(weight_name)
+                        weight_loader = getattr(
+                            param,
+                            "weight_loader",
+                            lambda p, lw: p.data.copy_(lw, non_blocking=True),
+                        )
+                        if is_expert:
+                            weight_loader(param, weight_tensor, expert_idx)
+                        else:
+                            weight_loader(param, weight_tensor)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d471befd0d710f96f01882fa9e8b8a84059bd9
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction.py
@@ -0,0 +1,76 @@
+import torch
+from .compaction_details._masked_compaction import _masked_compaction
+from .tensor import Bitmatrix
+
+
+def compaction(yv, yi, bitmask, sentinel=-1):
+    """
+    Return compacted copies of *yv* and *yi* based on a per-row bitmask.
+
+    Elements whose index has its bit set in *bitmask* are moved to the front
+    of their row, preserving their original left-to-right order; the remaining
+    positions are filled with *sentinel*.
+
+    Parameters
+    ----------
+    yv : torch.Tensor, shape (B, K)
+        Values tensor.
+    yi : torch.Tensor, shape (B, K)
+        Integer indices associated with *yv*. The kernel looks index ``i`` up
+        as bit ``i % 32`` of word ``i // 32`` of the row's bitmask.
+    bitmask : torch.Tensor, shape (B,) or (B, W), or Bitmatrix
+        Per-row packed mask of active indices. A 1-D tensor is treated as a
+        single word per row, so it only covers indices below 32. For a
+        ``Bitmatrix`` the underlying ``storage.data`` tensor is used.
+    sentinel : int, default -1
+        Value written into dropped positions of the returned tensors.
+
+    Returns
+    -------
+    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
+        New tensors created with ``torch.empty_like`` from the inputs.
+
+    Notes
+    -----
+    The kernel addresses row ``r`` at offset ``r * K``; *yv* and *yi* are
+    presumably expected to be contiguous - confirm before passing strided views.
+    """
+
+    n_rows, n_cols = yi.shape
+    ret_yv = torch.empty_like(yv)
+    ret_yi = torch.empty_like(yi)
+    if isinstance(bitmask, Bitmatrix):
+        bitmask = bitmask.storage.data
+    if bitmask.ndim == 1:
+        # The kernel needs both a row and a word stride; view (B,) as (B, 1).
+        bitmask = bitmask.unsqueeze(1)
+
+    # One program per row.
+    _masked_compaction[(n_rows,)](
+        yv,
+        yi,
+        bitmask,
+        bitmask.stride(0),
+        bitmask.stride(1),  # inputs
+        ret_yv,
+        ret_yi,  # outputs
+        sentinel,  # sentinel
+        K=n_cols,  # constants
+    )
+    return ret_yv, ret_yi
+
+
+def compaction_torch(
+    yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1
+):
+    """
+    Pure-PyTorch reference implementation of `compaction`.
+
+    For every row, the elements of *yi* (and the matching *yv*) whose index has
+    its bit set in *bitmask* are moved to the front in their original order;
+    all other positions are overwritten with *sentinel*.
+
+    *bitmask* holds 32 indices per word and may be (B, W) or, for a single
+    word per row, (B,). Returns ``(yv_out, yi_out)``.
+    """
+    if bitmask.ndim == 1:
+        # One word per row: view (B,) as (B, 1) so flatten/gather below work.
+        bitmask = bitmask.unsqueeze(1)
+    device = yi.device
+    # Expand bitmask to a boolean matrix of active bits (B, W * 32)
+    w = 1 << torch.arange(32, device=device, dtype=bitmask.dtype)
+    bits = (bitmask.unsqueeze(-1) & w) != 0
+    mask = bits.flatten(start_dim=-2)
+    # For every yi element decide whether it should be kept
+    keep = mask.gather(1, yi.long())
+    # Stable sort on "dropped" (kept == 0, dropped == 1) brings all kept
+    # items forward without changing their relative order
+    order = (~keep).to(torch.int).argsort(dim=1, stable=True)
+    # Re-order tensors according to the permutation above
+    yi_sorted = yi.gather(1, order)
+    yv_sorted = yv.gather(1, order)
+    # Fill the dropped positions with the sentinel
+    keep_sorted = keep.gather(1, order)
+    yi_sorted[~keep_sorted] = sentinel
+    yv_sorted[~keep_sorted] = sentinel
+    return yv_sorted, yi_sorted
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/_masked_compaction.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/_masked_compaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..58fe2412cf19386dbbe73bea1a5daf75d464ffb2
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/_masked_compaction.py
@@ -0,0 +1,22 @@
+import triton
+import triton.language as tl
+
+
+# Kernel behind `compaction`: each program handles one row of K elements and
+# writes the kept elements first (original order), sentinels after them.
+#   Yv, Yi        - input values / indices; row r is read at offset r * K
+#   BitMask       - packed mask, addressed with stride_bm (row) and stride_bn (word)
+#   RetYv, RetYi  - outputs, written at the same r * K row offsets
+#   sentinel      - value stored for dropped elements
+#   K             - row length (compile-time constant)
+@triton.jit
+def _masked_compaction(
+    Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr
+):
+    # Row handled by this program.
+    pid_m = tl.program_id(0)
+    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
+    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
+    # Index i lives in bit (i % 32) of word (i // 32) of this row's mask.
+    div = yi // 32
+    rem = yi % 32
+    active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
+    # Exclusive prefix sum: number of kept elements strictly before each position.
+    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
+    active_flags = active_bits.to(tl.int1)
+    # Kept elements get offset 0, so they land at exc_cumsum (packed at the
+    # front). A dropped element at position i gets K - 1 - i, so it lands at
+    # K - 1 - (number of dropped elements before it), i.e. at the back.
+    rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
+    write_indx = exc_cumsum + rev_arange
+    # Dropped elements are replaced by the sentinel before the scatter.
+    yv = tl.where(active_flags, yv, sentinel)
+    yi = tl.where(active_flags, yi, sentinel)
+    tl.store(RetYv + pid_m * K + write_indx, yv)
+    tl.store(RetYi + pid_m * K + write_indx, yi)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53ba9994bd2e5a7bfb6b04e195a339bebb3cb2c
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs.py
@@ -0,0 +1,609 @@
+# isort: off
+# fmt: off
+from dataclasses import dataclass
+import itertools
+import sys
+import torch
+import triton
+from enum import Enum, auto
+import math
+# utilities
+from compactor_vllm.triton_kernels import target_info
+from compactor_vllm.triton_kernels.numerics import InFlexData, OutFlexData
+from compactor_vllm.triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from compactor_vllm.triton_kernels.target_info import is_cuda
+# details
+from .matmul_ogs_details._matmul_ogs import _matmul_ogs
+from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
+from .matmul_ogs_details._reduce_grouped import _reduce_grouped
+from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
+from .specialize import specialize
+from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
+
+
+@dataclass(frozen=True)
+class FnSpecs:
+    """Immutable description of a function that is specialized into the matmul kernels."""
+
+    # Identifier; get_kernels() uses it as part of its cache key.
+    name: str
+    # The function itself; None in the default spec.
+    fn: "triton.runtime.jit.JITFunction"
+    # Names of the arguments the function takes.
+    fn_arg_names: tuple[str]
+    # Argument names forwarded as ``do_not_specialize`` when kernels are built.
+    fn_arg_do_not_specialize: tuple[str] = ()
+
+    @staticmethod
+    def default():
+        """Return the placeholder spec: name ``"dflt"``, no function, no arguments."""
+        return FnSpecs(name="dflt", fn=None, fn_arg_names=())
+
+
+@dataclass(frozen=True)
+class FusedActivation:
+    """Activation function fused into the matmul, with its argument values."""
+
+    # Spec of the activation function (placeholder spec by default).
+    specs: FnSpecs = FnSpecs.default()
+    # Values passed for the function's arguments at kernel launch.
+    fn_args: tuple[object] = ()
+    # The output width is divided by this factor (see init_allocation).
+    reduction_n: int = 1
+
+
+@dataclass(frozen=True)
+class Epilogue:
+    """Epilogue function applied by the kernels, with its argument values."""
+
+    # Spec of the epilogue function (placeholder spec by default).
+    specs: FnSpecs = FnSpecs.default()
+    # Argument values for the matmul kernel launch.
+    fn_arg_values_matmul: tuple[object] = ()
+    # Argument values for the finalize (grouped reduction) launch.
+    fn_arg_values_finalize: tuple[object] = ()
+    # Optional; None means no value was supplied.
+    effective_itemsize: float | None = None
+
+class FnName(Enum):
+    """Names of the known specializable functions."""
+
+    QUANTIZE_MXFP8 = auto()
+
+
+# Backward-compatible alias for FnSpecs.
+EpilogueSpecs = FnSpecs  # TODO: remove this alias when callers are updated
+
+# Cache of specialized kernel modules, keyed by (activation name, epilogue name).
+_kernels = {}
+
+
+def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
+    """Return the module of kernels specialized for this epilogue / activation pair.
+
+    The module is built on first request, registered in ``sys.modules`` and
+    cached in ``_kernels`` under ``(fused_activation.name, epilogue.name)``.
+    It exposes ``_matmul_ogs``, ``_p_matmul_ogs`` and ``_reduce_grouped``.
+    """
+    import types
+
+    key = (fused_activation.name, epilogue.name)
+    cached = _kernels.get(key)
+    if cached is not None:
+        return cached
+
+    spec_constants = {"ACTIVATION_FN": fused_activation.fn, "EPILOGUE_FN": epilogue.fn}
+    spec_tuples = {
+        "activation_fn_args": fused_activation.fn_arg_names,
+        "epilogue_fn_args": epilogue.fn_arg_names,
+    }
+    do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
+
+    # A fresh module holds the specialized copies; it is registered under its
+    # own name in sys.modules.
+    module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
+    sys.modules[module.__name__] = module
+    templates = (
+        ("_matmul_ogs", _matmul_ogs),
+        ("_p_matmul_ogs", _p_matmul_ogs),
+        ("_reduce_grouped", _reduce_grouped),
+    )
+    for attr_name, template in templates:
+        specialized = specialize(template, module, spec_constants, spec_tuples,
+                                 do_not_specialize=do_not_specialize)
+        setattr(module, attr_name, specialized)
+    _kernels[key] = module
+    return module
+
+
+# -----------------------------------------------------------------------------
+# Matrix Multiplication + Outer Gather/Scatter
+# -----------------------------------------------------------------------------
+
+
+def can_overflow_int32(tensor: torch.Tensor):
+    """Return True when the offset of the tensor's last element exceeds the int32 maximum.
+
+    The offset is ``sum((size - 1) * stride)`` over all dimensions, in elements.
+    """
+    largest_offset = sum(
+        (size - 1) * stride for size, stride in zip(tensor.shape, tensor.stride())
+    )
+    return largest_offset > 2**31 - 1
+
+
+def should_upcast_indices(*args):
+    """Return True if any non-None argument can overflow int32 offsets."""
+    for tensor in args:
+        if tensor is not None and can_overflow_int32(tensor):
+            return True
+    return False
+
+
+# ---------------------
+# Numerics
+# ---------------------
+
+# fmt: off
+
+# Immutable bundle of the flex-data descriptors for a matmul: one for each
+# input operand and one for the output.
+@dataclass(frozen=True)
+class FlexCtx:
+    # Flex data of the left-hand-side operand.
+    lhs_data: InFlexData = InFlexData()
+    # Flex data of the right-hand-side operand.
+    rhs_data: InFlexData = InFlexData()
+    # Flex data of the output.
+    out_data: OutFlexData = OutFlexData()
+
+# Numerics options for a matmul call. Fields defaulting to None are optional.
+@dataclass
+class PrecisionConfig:
+    max_num_imprecise_acc: int | None = None
+    allow_tf32: bool = True
+    flex_ctx: FlexCtx = FlexCtx()
+    # The default is the float 1.0.
+    acc_scale: float = 1.0
+    flexpoint_saturate_inf: bool = False
+    # Optional callable; left as the builtin ``callable`` annotation.
+    report_quantization_err_fn: callable = None
+    act_scale: Tensor | None = None
+    # get_swap_xw() only considers swapping when this is set.
+    weight_scale: Tensor| None = None
+    # When set, init_allocation() also reserves an "mx_out_scale" scratchpad
+    # whenever a "matmul" scratchpad is needed.
+    out_scale: Tensor | None = None
+    # Output dtype; init_allocation() falls back to the input dtype when None.
+    out_dtype: torch.dtype | None = None
+    enforce_bitwise_invariance: bool = False
+
+
+# TODO: merge in opt_flags
+def get_swap_xw(precision_config, opt_flags):
+    """Decide whether the x and w operands should be swapped.
+
+    Only considered when ``target_info.cuda_capability_geq(10, 0)`` holds; it
+    then requires a weight scale, ``block_m <= 64`` and a persistent kernel.
+    """
+    if not target_info.cuda_capability_geq(10, 0):
+        return False
+    has_weight_scale = precision_config.weight_scale is not None
+    return has_weight_scale and opt_flags.block_m <= 64 and opt_flags.is_persistent
+
+# ---------------------
+# Allocation
+# ---------------------
+
+# Plan of the buffers a matmul call needs; produced by init_allocation() and
+# turned into real tensors by apply_allocation().
+@dataclass
+class MatmulAllocation:
+    # Device the buffers are created on (init_allocation() passes ``x.device``).
+    device: str
+    # (shape, dtype) of the final output.
+    output: tuple[tuple[int], torch.dtype]
+    # Scratch buffer name -> (shape, dtype).
+    scratchpads: dict[str, tuple]
+
+def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags):
+    """Work out the output and scratchpad buffers (shape, dtype) a matmul call needs.
+
+    Nothing is allocated here; the returned ``MatmulAllocation`` only records
+    shapes and dtypes.
+    """
+    # ---- output ------
+    N = w.shape[-1]
+    # M is the number of activation rows, or the number of gather indices
+    # when the activations are gathered
+    M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
+    # final output rows
+    if routing_data.n_expts_act == 1 or scatter_indx is None:
+        y_rows = M
+    else:
+        # compressed number of rows
+        y_rows = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act
+    batch_dim = x.shape[0] if x.ndim == 3 else 1
+    out_dtype = precision_config.out_dtype or x.dtype
+    output = ((batch_dim, y_rows, N // fused_activation.reduction_n), out_dtype)
+    # ---- scratchpad -----#
+    scratchpad = {}
+    needs_matmul_scratch = opt_flags.split_k > 1 or (
+        scatter_indx is not None and not opt_flags.fused_scatter
+    )
+    if needs_matmul_scratch:
+        # split-k partial results are held in float32
+        scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
+        scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype)
+        if precision_config.out_scale is not None:
+            scale_shape = (opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE))
+            scratchpad["mx_out_scale"] = (scale_shape, torch.uint8)
+    return MatmulAllocation(x.device, output, scratchpad)
+
+def apply_allocation(allocation: MatmulAllocation, output):
+    """Create the buffers described by ``allocation``.
+
+    Returns a dict with ``"output"`` (the output with one extra leading
+    dimension) and ``"scratchpad"`` (name -> freshly allocated tensor). A
+    caller-supplied ``output`` is used as is after its shape is checked.
+    """
+    device = allocation.device
+    if output is not None:
+        assert output.shape == allocation.output[0]
+    else:
+        output = torch.empty(allocation.output[0], device=device, dtype=allocation.output[1])
+    scratch_buffers = {}
+    for name, spec in allocation.scratchpads.items():
+        scratch_buffers[name] = torch.empty(spec[0], device=device, dtype=spec[1])
+    return {"output": output[None, :, :], "scratchpad": scratch_buffers}
+
+# -----------------------------------------------------------------------------
+# Canonicalize
+# -----------------------------------------------------------------------------
+# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
+# we can canonicalize storages to make the implementation more uniform
+
+def _canonicalize_storage(storage, out_ndim, flex_data):
+    """Return a ``Storage`` whose data is padded with leading size-1 dims up to ``out_ndim``.
+
+    The original strides are kept; when ``flex_data`` is given, the padded
+    tensor is passed through ``flex_data.reinterpret``.
+    """
+    data = storage.data
+    n_extra = out_ndim - data.ndim
+    assert n_extra >= 0
+    # as_strided is needed instead of view because a tensor with
+    # shape[-2] == 1 is ambiguous with respect to being column-wise. For example,
+    # > t = torch.randn(2, 5, 1).mT
+    # > t_view = t.view(t.shape)
+    # > t.stride(), t_view.stride()
+    # ((5, 1, 1), (5, 5, 1))
+    # The column-wise check then fails for t_view since t_view.stride(-2) != 1
+    # This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
+    padded_shape = [1] * n_extra + list(data.shape)
+    # The view is only used to get a stride for the new leading dims.
+    lead_stride = data.view(padded_shape).stride(0)
+    padded_stride = [lead_stride] * n_extra + list(data.stride())
+    padded = data.as_strided(padded_shape, padded_stride)
+    if flex_data is not None:
+        padded = flex_data.reinterpret(padded)
+    return Storage(padded, storage.layout)
+
+#
+
+def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor,
+ fused_activation, epilogue,
+ x_flex: InFlexData | None = None,
+ out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None,
+ out_dtype: bool = None, flexpoint_saturate_inf: bool = False):
+ """
+ In-place grouped row reduction.
+
+ Arguments
+ - x: Tensor[AnyFloat] of shape [(num_groups * K), N]
+ - indx: Tensor[Int] of shape [num_groups, K]
+
+ Description
+ For each group g in [0, num_groups), this routine sums the K rows of `x`
+ specified by `indx[g, :]` and overwrites the row corresponding to the first
+ valid (non-negative) index with the per-group sum. Accumulation is performed
+ in float32 for numerical stability, and the result is written back in the
+ dtype of `x`.
+
+ Behavior and edge cases
+ - Invalid (-1) entries are skipped during accumulation and do not generate
+ memory traffic. If a group has no valid entries, nothing is written for
+ that group.
+ - Reduction is performed tile-by-tile along the N dimension within a single
+ kernel launch (persistent along N) to minimize launch overhead.
+
+ Performance notes
+ - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x),
+ plus index reads. With no invalid entries, this becomes (K + 1) reads/writes
+ of length N per group.
+
+ Returns
+ - The input tensor `x` (modified in place).
+ """
+ if indx is None and x.shape[0] == 1:
+ return x.squeeze(0), None
+ if indx is not None:
+ num_groups = indx.shape[0]
+ else:
+ num_groups = x.shape[-2]
+ if x_flex is None:
+ x_flex = InFlexData()
+ if out_flex is None:
+ out_flex = OutFlexData()
+ K = 1 if indx is None else indx.shape[1]
+ out_dtype = x.dtype if out_dtype is None else out_dtype
+ assert x.shape[-1] % fused_activation.reduction_n == 0
+ BLOCK_N = 512
+ # Resolve scalar flex scales (may be None)
+ x_expected_scale = None if x_flex is None else x_flex.scale
+ out_expected_scale = None if out_flex is None else out_flex.expected_scale
+ out_actual_scale = None if out_flex is None else out_flex.actual_scale
+ out_checksum_scale = None if out_flex is None else out_flex.checksum_scale
+ # Resolve MXFP output scale row stride
+ stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0)
+ stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1)
+ stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0)
+ kernels = get_kernels(epilogue.specs, fused_activation.specs)
+ kernels._reduce_grouped[(num_groups, )](
+ x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), #
+ x_expected_scale, # scalar input scale
+ out_flex.reinterpret(out), out.stride(1), out.stride(2), #
+ out_expected_scale, out_actual_scale, out_checksum_scale, indx, #
+ x.shape[0], x.shape[-1], #
+ x_mx_scale, stride_mxb, stride_mxs, #
+ out_mx_scale, stride_omxs, #
+ *fused_activation.fn_args, fused_activation.reduction_n,
+ *epilogue.fn_arg_values_finalize,
+ HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None,
+ FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf, #
+ BLOCK_N=BLOCK_N, K=K, #
+ num_warps=1, #
+ )
+ return out, out_mx_scale
+
+# -----------------------------------------------------------------------------
+# Triton Implementation
+# -----------------------------------------------------------------------------
+
+def matmul_ogs_set_idle_sms(num_idle_sms):
+ """
+ persistent kernels will leave `num_idle_sms` idle
+ """
+ update_opt_flags_constraints({"idle_sms": num_idle_sms})
+
+def matmul_ogs(x, w, bias,
+ routing_data: RoutingData | None = None,
+ gather_indx: GatherIndx | None = None,
+ scatter_indx: ScatterIndx | None = None,
+ precision_config: PrecisionConfig | None = None,
+ betas: torch.Tensor | None = None,
+ gammas: torch.Tensor | None = None,
+ out_alpha: float | None = None,
+ y: torch.Tensor | None = None,
+ fused_activation: FusedActivation | None = None,
+ epilogue: Epilogue | None = None,
+ ):
+ """
+ Y[:, :] = 0.
+ for e in num_experts:
+ Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
+
+ Host-side launcher: picks optimization flags, allocates output/scratchpad
+ memory, launches `_p_matmul_ogs` (persistent) or `_matmul_ogs`, and then
+ always calls `reduce_grouped` on the matmul output.
+
+ Batched mode (`x.ndim == 3`) requires `gather_indx`, `scatter_indx` and
+ `routing_data` to be None and `w` to be 3-D with a matching leading dim.
+ `precision_config`, `fused_activation`, `epilogue` and `routing_data` are
+ default-constructed when None.
+
+ Side effects visible in this function:
+ - a `w_scale` that is already a `Tensor` wrapper has its storage data and
+ dtype rebound to a uint8 view;
+ - `precision_config.out_scale` is overwritten when `reduce_grouped`
+ returns a non-None mx scale;
+ - `triton.set_allocator` is called when the persistent kernel is used.
+
+ Raises `InapplicableConstraint` when the opt flags request a fused scatter
+ that is not applicable, and `NotImplementedError` for unsupported
+ persistent / MXFP combinations.
+
+ Returns the final output tensor (leading dim squeezed when the input is
+ not batched).
+ """
+ is_input_batched = x.ndim == 3
+ if is_input_batched:
+ assert gather_indx is None, "gather not supported in batched mode"
+ assert scatter_indx is None, "scatter not supported in batched mode"
+ assert routing_data is None, "routing not supported in batched mode"
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
+ # canonicalize inputs
+ if precision_config is None:
+ precision_config = PrecisionConfig()
+ if fused_activation is None:
+ fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
+ if epilogue is None:
+ epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
+ if routing_data is None:
+ routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
+ # unpack scales
+ w_scale = precision_config.weight_scale
+ w_has_mx = w_scale is not None
+ is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
+ if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
+ if not isinstance(w, Tensor):
+ # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
+ dtype = FP4 if w.dtype == torch.uint8 else w.dtype
+ w = wrap_torch_tensor(w, dtype=dtype)
+ if w_scale is not None and not isinstance(w_scale, Tensor):
+ w_scale = Tensor(w_scale)
+ if w_scale is not None:
+ # NOTE: rebinds the wrapper's storage/dtype in place; a `Tensor` that came
+ # straight from precision_config.weight_scale is therefore modified.
+ w_scale.storage.data = w_scale.data.view(torch.uint8)
+ w_scale.dtype = torch.uint8
+ x_scale = precision_config.act_scale
+ x_has_mx = x_scale is not None
+ if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
+ if x_scale is not None and not isinstance(x_scale, Tensor):
+ x_scale = Tensor(x_scale)
+ if not isinstance(x, Tensor):
+ x = Tensor(x, dtype=x.dtype)
+ # determine shapes
+ has_gather = gather_indx is not None
+ has_scatter = scatter_indx is not None
+ is_ragged = routing_data.expt_hist is not None
+ M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
+ batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
+ K, N = w.shape[-2:]
+ assert K == x.shape[-1]
+ if x.ndim == 3 and w.ndim == 3:
+ assert x.shape[0] == w.shape[0]
+ # compute optimization flags
+ out_dtype = precision_config.out_dtype or x.dtype
+ can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
+ w.numel() > 0 and w.storage.is_tma_compliant() and \
+ (w_scale is None or w_scale.storage.is_tma_compliant())
+ # hopper w/ mxfp4 doesn't support TMA
+ can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+ can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
+ opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
+ M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+ )
+ if not can_use_fused_scatter and opt_flags.fused_scatter:
+ raise InapplicableConstraint("Fused scatter is not supported")
+ if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
+ raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
+ if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
+ raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
+ # fused activation
+ # By default the activation runs inside the matmul kernel. With split-k, or
+ # with a scatter that is not fused, the two are swapped so the activation
+ # is handed to reduce_grouped instead.
+ matmul_fused_activation = fused_activation
+ reduce_fused_activation = FusedActivation()
+ if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
+ matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation
+ # allocate output/scratchpad memory
+ allocation = init_allocation(x, w, precision_config, fused_activation,
+ routing_data, gather_indx, scatter_indx, opt_flags)
+ memory = apply_allocation(allocation, y)
+ # early exit
+ # Empty problem: return the allocated output without launching anything.
+ if batch_size * M * N == 0:
+ ret = memory["output"].squeeze(0)
+ if not is_input_batched:
+ ret = ret.squeeze(0)
+ return ret
+ # TMA descriptors require a global memory allocation
+ if opt_flags.is_persistent:
+ triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
+ # Intermediate tensors and postprocess kernels for each situation
+ has_scratchpad = "matmul" in memory["scratchpad"]
+ # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
+ out_matmul = memory["scratchpad"].get("matmul", memory["output"])
+ out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data
+ # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
+ out_matmul_scale = precision_config.out_scale
+ if out_matmul_scale is not None:
+ out_matmul_scale = out_matmul_scale.data.view(torch.uint8)
+ if has_scratchpad and "mx_out_scale" in memory["scratchpad"]:
+ out_matmul_scale = memory["scratchpad"]["mx_out_scale"]
+ out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1
+ # matrix multiplication
+ flex = precision_config.flex_ctx
+ bias_stride = None if bias is None else bias.stride(0)
+ num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
+ # moe metadata
+ expt_data = routing_data.expt_data
+ block_m = opt_flags.block_m
+ expt_hist = None if expt_data is None else expt_data.hist
+ expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
+ expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
+ expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
+ # spmd grid
+ grid_m = triton.cdiv(M, opt_flags.block_m)
+ if expt_block_pid_map is not None:
+ grid_m = routing_data.n_blocks(M, opt_flags.block_m)
+ grid_n = triton.cdiv(N, opt_flags.block_n)
+ max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
+ # Persistent kernels launch at most (num_sms - idle_sms) programs.
+ grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
+ # canonicalize storage
+ has_gather_tma = has_gather and target_info.has_tma_gather()
+ has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
+ # NOTE: from here on `y` is the wrapped matmul output view, not the caller's `y` argument.
+ y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
+ x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
+ w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
+ y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
+ # create tma descriptor for x
+ x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
+ x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
+ x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
+ x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
+ # create tma descriptor for y
+ y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter)
+ block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n
+ y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
+ y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
+ y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
+ # create tma descriptor for w
+ w_has_tma = opt_flags.is_persistent
+ w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
+ # create tma descriptor for w_scale
+ # NOTE(review): this first assignment is overwritten two lines below.
+ w_scale_tensor_or_tma = w_scale
+ w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
+ w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
+ # canonicalize strides
+ # Each stride tuple is left-padded with zeros up to three entries; absent
+ # scales contribute tuples of None.
+ x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
+ x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
+ x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
+ w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
+ w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
+ # Only the last three entries are passed to the kernel (see the [-3:] slice below).
+ out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
+ out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
+ # launch kernel
+ kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
+ # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
+ # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
+ # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
+ # w_transpose = w_storage.data.stride()[-1] != 1
+ w_transpose = w_storage.data.stride()[-2] == 1
+ (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
+ y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
+ *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
+ *out_matmul_scale_strides[-3:],
+ x_tensor_or_tma, x_storage.data, *x_strides,
+ flex.lhs_data.scale,
+ None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
+ w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
+ flex.rhs_data.scale,
+ w_scale_tensor_or_tma, *w_scale_strides,
+ bias, bias_stride,
+ x.shape[-2],
+ x.shape[-2] if routing_data.expt_hist is None else None,
+ N, K,
+ betas, gammas,
+ None if gather_indx is None else gather_indx.src_indx,
+ None if scatter_indx is None else scatter_indx.src_indx,
+ num_indx,
+ None if not opt_flags.fused_scatter else scatter_indx.dst_indx,
+ None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0],
+ expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
+ batch_size, grid_m, grid_n,
+ out_alpha,
+ *matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n,
+ *epilogue.fn_arg_values_matmul,
+ routing_data.n_expts_tot, routing_data.n_expts_act,
+ precision_config.max_num_imprecise_acc,
+ precision_config.allow_tf32,
+ precision_config.flexpoint_saturate_inf,
+ flex.rhs_data.is_per_batch,
+ opt_flags.block_m,
+ opt_flags.block_n,
+ opt_flags.block_k,
+ opt_flags.group_m,
+ XCD_SWIZZLE=opt_flags.xcd_swizzle,
+ SWIZZLE_MX_VALUE=w.storage.layout.name,
+ SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
+ EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
+ SPLIT_K=opt_flags.split_k,
+ EVEN_K=K % opt_flags.block_k == 0,
+ W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
+ num_warps=opt_flags.num_warps,
+ num_stages=opt_flags.num_stages,
+ arch=opt_flags.arch,
+ UPCAST_INDICES=should_upcast_indices(x, w, out_matmul),
+ X_TMA_MODE=x_tma_mode,
+ Y_TMA_MODE=y_tma_mode,
+ SWAP_XW=get_swap_xw(precision_config, opt_flags),
+ IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name,
+ NUM_SMS = grid if opt_flags.is_persistent else 0,
+ **opt_flags.target_kernel_kwargs)
+ # Build grouped reduction inputs in a uniform way
+ # No grouping indices when there is no scatter or when the scatter was fused
+ # into the matmul; otherwise one row of `n_expts_act` source indices per group.
+ group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act)
+ out_final, out_final_mx_scale = reduce_grouped(
+ out_matmul,
+ group_indx,
+ memory["output"].squeeze(0),
+ precision_config.out_scale,
+ reduce_fused_activation,
+ epilogue,
+ x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale),
+ out_flex=precision_config.flex_ctx.out_data,
+ x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None,
+ out_dtype=memory["output"].dtype,
+ flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
+ )
+ if not is_input_batched:
+ out_final = out_final.squeeze(0)
+ # NOTE: mutates the caller's precision_config.
+ if out_final_mx_scale is not None:
+ precision_config.out_scale = out_final_mx_scale
+ return out_final
+
+# -----------------------------------------------------------------------------
+# Reference Implementation
+# -----------------------------------------------------------------------------
+
+def matmul_ogs_torch(x, w, bias,
+ routing_data: RoutingData = None,
+ gather_indx: GatherIndx = None,
+ scatter_indx: ScatterIndx = None,
+ precision_config: PrecisionConfig = None,
+ betas = None,
+ gammas = None,
+ round_x = None, round_y = None,
+ ):
+ is_input_batched = x.ndim == 3
+ assert x.dtype.itemsize > 1
+ assert w.dtype.itemsize > 1
+ if is_input_batched:
+ assert gather_indx is None, "gather not supported in batched mode"
+ assert scatter_indx is None, "scatter not supported in batched mode"
+ assert routing_data is None, "routing not supported in batched mode"
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
+ if round_x is None:
+ round_x = lambda x, idx: x
+ if round_y is None:
+ round_y = lambda x: x
+ if bias is not None and bias.ndim == 1:
+ bias = bias.view(1, *bias.shape)
+ if w.ndim == 2:
+ w = w.view(1, *w.shape)
+ if x.ndim == 2:
+ x = x.view(1, *x.shape)
+ if routing_data is None:
+ routing_data = RoutingData(None, None, w.shape[0], 1)
+ n_expts_act = routing_data.n_expts_act
+ # memory offsets
+ if routing_data.n_expts_tot > 1 and not is_input_batched:
+ sizes = routing_data.expt_hist
+ off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
+ off[1:] = torch.cumsum(sizes, 0)
+ offs = list(itertools.pairwise(off))
+ else:
+ offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
+ # compute
+ n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
+ y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
+ for i, (lo, hi) in enumerate(offs):
+ if gather_indx is None:
+ idx = torch.arange(lo, hi, device=x.device)
+ else:
+ idx = gather_indx.src_indx[lo:hi] // n_expts_act
+ batch = i if is_input_batched else 0
+ out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+ w[i].float())
+ if bias is not None:
+ out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
+ if gammas is not None:
+ out *= gammas[lo:hi, None]
+ y[batch, lo:hi, :] = round_y(out)
+ if not is_input_batched:
+ y = y.view(y.shape[1], y.shape[2])
+ if scatter_indx is None:
+ return y
+ # accumulate output from all experts
+ n_rows = y.shape[0] // n_expts_act
+ out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
+ for i, (lo, hi) in enumerate(offs):
+ dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
+ msk = dst_idx != -1
+ out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
+ return out
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_common.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d5c99493872d779643aff2a9f7293685d8c4f2b
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_common.py
@@ -0,0 +1,179 @@
+import torch
+
+import triton
+import triton.language as tl
+
+# -----------------------------------------------------------------------------
+# Utilities
+# -----------------------------------------------------------------------------
+
+
+@triton.constexpr_function
+def get_scaled_dot_format_string(dtype: tl.dtype):
+ mapping = {
+ tl.float16: "fp16",
+ tl.bfloat16: "bf16",
+ tl.uint8: "e2m1",
+ tl.float8e4nv: "e4m3",
+ tl.float8e5: "e5m2",
+ }
+ return mapping[dtype]
+
+
+@triton.jit
+def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
+ """
+ Swizzle the program id based on integer XCD_SWIZZLE.
+ This is useful for reording how blocks are ordered. A scheduler may, for example,
+ assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
+ This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
+ becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
+ the same hardware unit.
+ """
+ # Number of pids per group in the new arrangement
+ pids_per_group = domain_size // XCD_SWIZZLE
+ extra_pid_groups = domain_size % XCD_SWIZZLE
+
+ # Compute current current and local pid within the group
+ group = pid % XCD_SWIZZLE
+ local_pid = pid // XCD_SWIZZLE
+
+ # Calculate new pid based on the new grouping
+ new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
+ return new_pid
+
+
+@triton.jit
+def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ tl.assume(group_size >= 0)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ return pid_m, pid_n
+
+
+def make_matmul_repr(base_name, order):
+ def matmul_repr(specialization):
+ signature = specialization.signature
+ constants = specialization.constants
+ reorder = lambda L: [L[i] for i in order]
+ layout = lambda stride: "N" if stride in constants else "T"
+
+ def convert_dtype(dtype):
+ if "tensordesc" in dtype:
+ ret = convert_dtype(dtype.split("<")[1].split("[")[0])
+ return ret
+ elif "u8" in dtype:
+ return "mxfp4"
+ elif dtype[0] == "*":
+ return dtype[1:]
+ else:
+ return dtype
+
+ dtypes = "x".join(
+ [convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])]
+ )
+ layouts = "".join(
+ [
+ f"{layout(i)}"
+ for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])
+ ]
+ )
+ blocks = "x".join(
+ [f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]]
+ )
+ # mode = []
+ # if "GatherIndx" not in constants:
+ # mode += ['g']
+ # if "ScatterSrcIndx" not in constants:
+ # mode += ['s']
+ # suffix = "" if not mode else "_o" + (''.join(mode))
+ # if base_name.startswith("_p"):
+ # suffix += "_ptma"
+ return f"{base_name}_{layouts}_{dtypes}_{blocks}"
+
+ return matmul_repr
+
+
+def matmul_launch_metadata(grid, kernel, args):
+ """
+ Build the launch-metadata dict for a matmul kernel launch.
+
+ `args` is the mapping of kernel argument names to values. The result always
+ contains "name". Unless the launch is ragged (`ExptHist` given) and the
+ token count cannot be determined, it also contains "flops<nbits>" and
+ "bytes". Reductions over `ExptHist` / `GatherIndx` are only performed when
+ `launch_metadata_allow_sync()` returns true.
+ """
+ from ..proton_opts import launch_metadata_allow_sync
+
+ ret = dict()
+ M, N, K = args["M"], args["N"], args["K"]
+ Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"]
+ tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+ hist = args["ExptHist"]
+ if hist is not None:
+ # If annotation is given, use that to generate name for profiling.
+ if tokens_per_expt is not None:
+ n_rows = f"{tokens_per_expt}*"
+ elif launch_metadata_allow_sync():
+ n_rows = int(hist.float().mean())
+ else:
+ n_rows = "unknown"
+
+ if launch_metadata_allow_sync():
+ n_tokens = float(hist.sum())
+ # Weight bytes of a single expert times the number of experts that
+ # received at least one token.
+ n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (
+ hist > 0
+ ).sum()
+ elif tokens_per_expt is not None:
+ n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+ # This may not be totally correct (e.g., we might not be using all experts)
+ # but it's better than nothing.
+ n_w_bytes = W.numel() * W.element_size()
+ else:
+ n_tokens = None
+ n_w_bytes = 0
+
+ # If annotation is given, use that to generate name for profiling.
+ # NOTE(review): repeats the lookup and override already done at the top
+ # of this branch.
+ tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+ n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
+ else:
+ n_tokens = None
+ n_w_bytes = W.numel() * W.element_size()
+ # Formats one dimension for the kernel name. A None value falls back to the
+ # per-expert form, which reads `hist` and `n_rows`; `n_rows` is only
+ # assigned when `hist` is not None.
+ # NOTE(review): the local name `repr` shadows the builtin.
+ repr = (
+ lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
+ )
+ nbits = X.dtype.itemsize * 8
+ batch_repr = ""
+ if "batch_size" in args and args["batch_size"] > 1:
+ batch_repr = repr("B", args["batch_size"]) + ", "
+ ret["name"] = (
+ f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
+ )
+ ep_subtile = args["EPILOGUE_SUBTILE"]
+ if ep_subtile is not None and ep_subtile > 1:
+ ret["name"] += f" ep/{ep_subtile}"
+
+ if hist is not None and n_tokens is None:
+ return ret # Don't fill metadata because we can't compute them properly.
+
+ # A None M or K is replaced by the token count.
+ fM = M if M is not None else n_tokens
+ fK = K if K is not None else n_tokens
+ ret[f"flops{nbits}"] = 2.0 * fM * N * fK
+
+ gindx = args.get("GatherIndx", None)
+ # sindx = args.get("WriteBackIndx", None)
+ n_x_bytes = X.numel() * X.element_size()
+ n_y_bytes = Y.numel() * Y.element_size()
+ if hist is not None:
+ assert n_tokens is not None
+ n_expts_act = args["N_EXPTS_ACT"]
+
+ if (gindx is not None) and launch_metadata_allow_sync():
+ # recreate inverse GatherIndx.
+ dst = torch.full_like(gindx, -1)
+ idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
+ mask = gindx != -1
+ dst[gindx[mask]] = idx[mask]
+ # Count the groups of `n_expts_act` entries that have at least one valid entry.
+ n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
+ else:
+ n_read_rows = n_tokens
+ n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
+ n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
+ ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
+
+ return ret
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea994d20ab32497adc26b8049350b0ad959ee8fe
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_matmul_ogs.py
@@ -0,0 +1,429 @@
+# isort: off
+# fmt: off
+import triton
+import triton.language as tl
+from compactor_vllm.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from compactor_vllm.triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from compactor_vllm.triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
+from compactor_vllm.triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
+from compactor_vllm.triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
+from compactor_vllm.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+
+
+@triton.jit
+def _zero_masked_rows(
+ pid_m, pid_n,
+ Y, stride_y_m, stride_y_n,
+ N,
+ ScatterSrcIndx, num_idxs,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+ offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
+ offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
+ YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+ mask_n = offs_n < N
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
+ tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask)
+
+
+# `_matmul_ogs`: each program accumulates one (BLOCK_M, BLOCK_N) float32 tile of X @ W for one
+# batch entry / expert (`pid_e` or the id decoded from ExptData) and one split-K slice (`pid_k`),
+# then applies the flexpoint scales, bias/betas, `out_alpha`, the optional fused activation,
+# gammas and the optional epilogue before storing into Y (through WriteBackIndx when given).
+_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+            repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+def _matmul_ogs(
+    Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
+    YExpectedScale, YActualScale, YChecksumScale,
+    stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+    X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+    XScale,
+    XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
+    W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
+    WScale,
+    WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
+    B, stride_b_e, # Bias
+    NRows, M, N, K, # shapes
+    # expt data
+    Betas, Gammas,
+    GatherIndx,
+    ScatterSrcIndx, num_idxs,
+    WriteBackIndx, writeback_size,
+    ExptHist, ExptOffs, ExptOffsSum, ExptData,
+    # true grid size
+    batch_size, grid_m, grid_n,
+    # Out scale
+    out_alpha,
+    # fused activation function
+    ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
+    # epilogue transform
+    EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
+    # MoE config
+    N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+    # precision config
+    MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
+    FLEXPOINT_SATURATE_INF: tl.constexpr,
+    PER_BATCH_SCALE: tl.constexpr,
+    # optimization config
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+    # One of ["HOPPER_VALUE", None] (enforced by the static_assert in the body)
+    SWIZZLE_MX_VALUE: tl.constexpr,
+    # One of ["BLACKWELL_SCALE", "HOPPER_SCALE", "CDNA4_SCALE", None] (the strings the body branches on)
+    SWIZZLE_MX_SCALE: tl.constexpr,
+    EPILOGUE_SUBTILE: tl.constexpr,
+    EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
+    W_CACHE_MODIFIER: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    X_TMA_MODE: tl.constexpr,
+    Y_TMA_MODE: tl.constexpr,
+    TOKENS_PER_EXPT_FOR_ANNOTATION=None,
+    UPCAST_INDICES: tl.constexpr = False,
+    SWAP_XW: tl.constexpr = False,
+    IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
+
+    # Let the compiler assume that strides and grid sizes are non-negative.
+    tl.assume(stride_y_k >= 0)
+    tl.assume(stride_y_z >= 0)
+    tl.assume(stride_y_m >= 0)
+    tl.assume(stride_y_n >= 0)
+    tl.assume(stride_x_z >= 0)
+    tl.assume(stride_x_m >= 0)
+    tl.assume(stride_x_k >= 0)
+    tl.assume(stride_w_e >= 0)
+    tl.assume(stride_w_k >= 0)
+    tl.assume(stride_w_n >= 0)
+    if stride_w_mx_e is not None:
+        tl.assume(stride_w_mx_e >= 0)
+    if stride_w_mx_k is not None:
+        tl.assume(stride_w_mx_k >= 0)
+    if stride_w_mx_n is not None:
+        tl.assume(stride_w_mx_n >= 0)
+    if B is not None:
+        tl.assume(stride_b_e >= 0)
+    tl.assume(batch_size >= 0)
+    tl.assume(grid_m >= 0)
+    tl.assume(grid_n >= 0)
+
+    # Compile-time validation of the microscaled (mx) weight / activation configuration.
+    is_w_microscaled: tl.constexpr = WMxScale is not None
+    MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
+    if is_w_microscaled:
+        w_type: tl.constexpr = W.dtype.element_ty
+        # uint8 weights are treated as mxfp4 throughout this kernel.
+        is_mxfp4: tl.constexpr = w_type == tl.uint8
+        tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
+                         "mx_weight_ptr must be uint8 or fp8")
+        tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+        tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
+    else:
+        tl.static_assert(SWIZZLE_MX_VALUE is None)
+        tl.static_assert(SWIZZLE_MX_SCALE is None)
+    is_x_microscaled: tl.constexpr = XMxScale is not None
+    if is_x_microscaled:
+        x_type: tl.constexpr = X.dtype.element_ty
+        # Microscaled activations are only accepted together with microscaled weights.
+        tl.static_assert(is_w_microscaled)
+        tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
+        tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+    is_out_microscaled: tl.constexpr = stride_y_mx_z is not None
+
+    # Output tile width / output N once the fused activation has divided N by ACTIVATION_REDUCTION_N.
+    OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
+    yN = N // ACTIVATION_REDUCTION_N
+
+    pid = tl.program_id(0)
+    if ExptOffsSum is not None and XCD_SWIZZLE > 1:
+        # Determine how much padding there is on the expert data. This allows us to
+        # know the true grid size and avoid processing padding tiles.
+        padding_m = grid_m - tl.load(ExptOffsSum)
+    else:
+        padding_m: tl.constexpr = 0
+
+    HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
+    index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
+
+    unpadded_m = grid_m - padding_m
+    tl.assume(unpadded_m >= 0)
+    total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K
+    # Programs past the real tiles do no matmul: they at most zero masked rows of a padding tile, then exit.
+    if padding_m > 0 and pid >= total_actual_tiles:
+        tl.device_assert(batch_size == 0)
+        pid_mn = pid - total_actual_tiles
+        if pid_mn < padding_m * grid_n:
+            pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
+
+            # set masked out rows to 0
+            if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+                _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
+        return
+
+    # swizzle program ids
+    pid_emnk = pid
+    if XCD_SWIZZLE != 1:
+        pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
+    # Split the linear id into (batch/expert, split-K slice, m-block, n-block).
+    pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K)
+    pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
+    pid_k = pid_mnk % SPLIT_K
+    pid_mn = pid_mnk // SPLIT_K
+    pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M)
+    # For split-k, advance to the output k slice
+    if SPLIT_K > 1:
+        Y += pid_k.to( index_type) * stride_y_k
+        if is_out_microscaled:
+            # NOTE(review): the output scale pointer is stepped with stride_x_mx_k; there is no
+            # dedicated y-scale k stride in the signature -- confirm this is intended.
+            YActualScale += pid_k.to(index_type) * stride_x_mx_k
+    # set masked out rows to 0
+    if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+        _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
+    # unpack expert data
+    if ExptData is None:
+        tl.static_assert(M is not None)
+        expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m
+    else:
+        tl.static_assert(M is None)
+        expt_data = tl.load(ExptData + pid_m)
+        # An entry of -1 makes this program exit without computing anything.
+        if expt_data == -1:
+            return
+        # Low 16 bits hold the expert id, the remaining upper bits the block id within that expert.
+        expt_id = expt_data & 0x0000FFFF
+        block_id = expt_data >> 16
+        # From here on M is the per-expert row count read from ExptHist.
+        M = tl.load(ExptHist + expt_id)
+        start_m = tl.load(ExptOffs + expt_id)
+        start_z = 0
+    expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
+    start_m, start_z = start_m.to(index_type), start_z.to(index_type)
+    pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
+    # A pointers
+    offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
+    # Row offsets are wrapped modulo M, so the X loads below carry no row mask.
+    offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
+    X += start_z * stride_x_z
+    if GatherIndx is None:
+        X += start_m * stride_x_m
+    else:
+        GatherIndx += start_m
+        # no need to bounds-check here because `offs_x_m` wraps around M dim
+        # presumably the division maps a (row, expert-slot) index back to a row of X -- TODO confirm
+        offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
+    offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
+    XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
+
+    # TODO: refactor if/else when triton front end improves
+    if is_w_microscaled:
+        if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+            tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
+            tl.static_assert(not is_x_microscaled)
+            # We pack 2 fp4 values in a byte but we divide the dimension by 2
+            # when swizzling
+            W_K_DIVISOR: tl.constexpr = 1
+            W_K_MULTIPLIER: tl.constexpr = 2
+            W_N_DIVISOR: tl.constexpr = 4
+        else:
+            # We pack 2 fp4 values in a byte
+            W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
+            W_K_MULTIPLIER: tl.constexpr = 1
+            W_N_DIVISOR: tl.constexpr = 1
+
+        if W_TRANSPOSE:
+            # When weight is transposed, 2 fp4 values are packed per Byte along
+            # the contiguous dimension, K.
+            PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
+            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
+        else:
+            # When weight is not transposed, fp4 values are *not* packed along
+            # the contiguous dimension, N.
+            PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
+        # Number of scale entries along K per tile: one per MX_PACK_DIVISOR values.
+        MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+
+        WMxScale += expt_id * stride_w_mx_e
+
+        # Tile shape and K stride of the scale tensor depend on the scale swizzling layout.
+        if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+            # TODO: support non W_TRANSPOSE with blackwell swizzling
+            tl.static_assert(W_TRANSPOSE)
+            tl.static_assert(BLOCK_N % 128 == 0)
+            tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
+            PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
+            SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
+            stride_scale_k: tl.constexpr = 1
+        elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+            # TODO: support non W_TRANSPOSE with Hopper swizzling
+            tl.static_assert(W_TRANSPOSE)
+            n_warps: tl.constexpr = tl.extra.cuda.num_warps()
+            tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
+            tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
+            PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
+            SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
+            stride_scale_k = stride_w_mx_k
+        elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+            tl.static_assert(stride_w_mx_k is not None)
+            tl.static_assert(stride_w_mx_n is not None)
+            NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
+            PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
+            SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
+            stride_scale_k = stride_w_mx_k
+        else:
+            # Unswizzled scales: one scale entry per (n, k-group).
+            PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
+            SCALE_BLOCK_N: tl.constexpr = BLOCK_N
+            stride_scale_k = stride_w_mx_k
+        # N offsets wrap modulo N, so the scale loads need no N mask.
+        offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
+        offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
+        # K dimension must be the last dimension for the scales
+        offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
+        WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
+    else:
+        # Plain (non-microscaled) weights: no packing, no scale pointers.
+        WMxScalePtrs = None
+        offs_k_scale = None
+        W_K_DIVISOR: tl.constexpr = 1
+        W_K_MULTIPLIER: tl.constexpr = 1
+        W_N_DIVISOR: tl.constexpr = 1
+        PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+        PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
+
+    # B pointers
+    offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
+    # Column offsets wrap modulo the (packed) N extent, so the W loads carry no N mask.
+    offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
+
+    if is_x_microscaled:
+        XMxScale += start_z.to(index_type) * stride_x_mx_z
+        if GatherIndx is None:
+            XMxScale += start_m * stride_x_mx_m
+        offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
+        XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
+    else:
+        XMxScalePtrs = None
+
+    offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
+    W += expt_id * stride_w_e
+    WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
+    # compute output
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    # `k` counts down from K in steps of BLOCK_K * SPLIT_K. The offsets stay fixed while the
+    # pointers advance, so comparing the fixed offsets against `k` masks the K tail.
+    for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):
+        if EVEN_K:
+            mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
+            mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
+            if is_w_microscaled and SWIZZLE_MX_SCALE is None:
+                mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
+            if is_x_microscaled:
+                mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
+        else:
+            mask_k = offs_k < k
+            mask_k_w = offs_w_k < ((k // (W_K_DIVISOR if W_TRANSPOSE else 1)) * W_K_MULTIPLIER)
+            if is_w_microscaled and SWIZZLE_MX_SCALE is None:
+                mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
+            if is_x_microscaled:
+                mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k
+
+        x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+        w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
+        if is_w_microscaled:
+            x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
+            w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+
+            if is_x_microscaled:
+                x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
+            elif x_format == "fp16" or x_format == "bf16":
+                x_scales: tl.constexpr = None
+            else:
+                # Scale of 1 in E8M0 format
+                x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
+
+            # Swizzled scale layouts are loaded unmasked and unswizzled; only the plain layout uses mask_k_scale.
+            if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+                w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
+            elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+                # Handshake with the swizzling code
+                num_warps: tl.constexpr = tl.extra.cuda.num_warps()
+                w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
+            elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+                w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K)
+            else:
+                w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])
+
+            if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+                # Handshake with the swizzling code
+                tl.static_assert(x_format == "bf16")
+                tl.static_assert(w_format == "e2m1")
+                # Upcast the weights to bf16, accumulate the transposed product (w @ x^T) with a
+                # regular dot, then transpose the accumulator back.
+                w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
+                tl.static_assert(w.dtype == tl.bfloat16)
+                acc = acc.trans()
+                x = x.trans()
+                # w = w.trans()
+                acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+                acc = acc.trans()
+            else:
+                rhs_k_pack: tl.constexpr = W_TRANSPOSE or not is_w_microscaled or W_K_DIVISOR != 2
+                acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True, rhs_k_pack=rhs_k_pack)
+            # Advance the scale pointers to the next K step of this split-K slice.
+            if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+                WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
+            else:
+                WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
+            if is_x_microscaled:
+                XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
+        else:
+            acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+        # Advance the operand pointers to the next K step of this split-K slice.
+        XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
+        WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
+    # bias + scale
+    offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
+    offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
+    mask_m = offs_m < M
+    mask_n = offs_y_n < N
+    if B is not None:
+        BPtrs = B + expt_id * stride_b_e + offs_y_n
+        # Only the first split-K slice loads the bias; the others use zeros
+        # (presumably so the bias is counted once when the slices are combined -- TODO confirm).
+        if pid_k == 0:
+            bias = tl.load(BPtrs, mask=mask_n, other=0)
+        else:
+            bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+    else:
+        bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+    # Per-row multipliers: betas scale the bias, gammas scale the final output; both default to 1.
+    if Betas is not None:
+        betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
+    else:
+        betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+    if Gammas is not None:
+        gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
+    else:
+        gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+    # flexpoint
+    x_scale = load_scale(XScale)
+    if PER_BATCH_SCALE:
+        w_scale = load_scale(WScale + expt_id)
+    else:
+        w_scale = load_scale(WScale)
+    acc *= x_scale * w_scale
+    acc = acc + bias[None, :] * betas[:, None]
+    if out_alpha is not None:
+        acc *= out_alpha
+    if ACTIVATION_FN is not None:
+        out = ACTIVATION_FN(acc, *activation_fn_args)
+        tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
+        # The activation narrowed the tile, so recompute the output column offsets and mask.
+        offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
+        mask_n = offs_y_n < yN
+    else:
+        tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
+        out = acc
+    out *= gammas[:, None]
+    # write-back
+    Y += start_z.to(index_type) * stride_y_z
+    if WriteBackIndx is not None:
+        WriteBackIndx += start_m
+        # Destination rows come from WriteBackIndx; entries of -1 (and rows past writeback_size) are not stored.
+        dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
+        mask_m = mask_m & (dst_idx != -1)
+        offs_y_m = dst_idx
+    else:
+        Y += start_m * stride_y_m
+        offs_y_m = offs_m
+
+    YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
+    mask = mask_m[:, None] & mask_n[None, :]
+    if is_out_microscaled:
+        # Microscaled output: the epilogue returns both the values and their scales; the scales
+        # are stored here, the values by the final tl.store below.
+        MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
+        N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
+        tl.static_assert(EPILOGUE_FN is not None)
+        out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
+        tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
+        offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
+        mask_n_scale = offs_y_n_scale < N_MX_BLOCK
+        YActualScale += start_z.to(index_type) * stride_y_mx_z
+        if WriteBackIndx is None:
+            YActualScale += start_m * stride_y_mx_m
+            YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+        else:
+            # NOTE(review): with a write-back index the scale row is offs_y_m - NRows; the reason
+            # for the NRows offset is not visible in this kernel -- confirm against the caller.
+            YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+        tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
+    else:
+        out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
+        if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
+            out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
+    tl.store(YPtrs, out, mask=mask)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e254a8eaba62c4f228ad9691c36824d1409912
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
@@ -0,0 +1,471 @@
+# isort: off
+# fmt: off
+import torch
+import triton
+import triton.language as tl
+from triton.tools.ragged_tma import load_ragged, store_ragged
+from compactor_vllm.triton_kernels import target_info
+from compactor_vllm.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from compactor_vllm.triton_kernels.numerics_details.flexpoint import (
+ float_to_flex,
+ load_scale,
+ nan_propagating_absmax_reduce,
+ compute_scale,
+)
+from compactor_vllm.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+
+
+@triton.constexpr_function
+def cuda_capability_geq(major, minor):
+ return target_info.cuda_capability_geq(major, minor)
+
+@triton.constexpr_function
+def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
+ if isinstance(tensor_or_desc, tl.tensor):
+ return tensor_or_desc.dtype.element_ty
+ elif isinstance(tensor_or_desc, tl.tensor_descriptor):
+ return tensor_or_desc.dtype
+ else:
+ raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
+
+@triton.jit
+def _load_tile_attrs(
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
+ # unpack and swizzle program ids
+ pid_emnk = tile_id
+ if XCD_SWIZZLE != 1:
+ pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
+ pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
+ pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
+ if SPLIT_K > 1:
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ else:
+ pid_k: tl.constexpr = 0
+ pid_mn = pid_mnk
+ pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
+
+ # unpack expert data
+ if ExptData is None:
+ tl.static_assert(M is not None)
+ expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
+ else:
+ tl.static_assert(M is None)
+ expt_data = tl.load(ExptData + pid_m)
+ expt_id = expt_data & 0x0000FFFF
+ block_id = expt_data >> 16
+ eM = tl.load(ExptHist + expt_id)
+ start_m = tl.load(ExptOffs + expt_id)
+ start_z = 0
+
+ off_m = BLOCK_M * block_id
+ off_n = BLOCK_N * pid_n
+
+ return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
+
+@triton.jit
+def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
+ mask = mask & (offs < writeback_size)
+ offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1)
+ mask = offs != -1
+ return (offs, mask)
+
+
+_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+def _p_matmul_ogs(
+ Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
+ YExpectedScale, YActualScale, YChecksumScale,
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+ XScale,
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
+ W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
+ WScale,
+ MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
+ B, stride_b_e, # Bias
+ NRows, M, N, K, # shapes
+ # expt data
+ Betas, Gammas,
+ GatherIndx,
+ ScatterSrcIndx, num_idxs,
+ WriteBackIndx, writeback_size,
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,
+ # true grid size
+ batch_size, grid_m, grid_n,
+ # Out scale
+ out_alpha,
+ # fused activation function
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
+ # epilogue transform
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
+ # MoE config
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+ # precision config
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
+ PER_BATCH_SCALE: tl.constexpr,
+ # optimization config
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+ # NYI: Must be None
+ SWIZZLE_MX_VALUE: tl.constexpr,
+ # One of ["BLACKWELL", None]
+ SWIZZLE_MX_SCALE: tl.constexpr,
+ EPILOGUE_SUBTILE: tl.constexpr,
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
+ W_CACHE_MODIFIER: tl.constexpr,
+ NUM_SMS: tl.constexpr,
+ X_TMA_MODE: tl.constexpr,
+ Y_TMA_MODE: tl.constexpr,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
+ UPCAST_INDICES:tl.constexpr=False,
+ SWAP_XW: tl.constexpr = False,
+ IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
+ # tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
+
+ # why is this faster than using host-side tensor descriptor?!
+ if Y_TMA_MODE is not None:
+ Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)
+
+ is_microscaled_format: tl.constexpr = MxScale is not None
+ tl.static_assert(not is_microscaled_format or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
+ if is_microscaled_format:
+ w_type: tl.constexpr = get_dtype(W)
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
+ "mx_weight_ptr must be uint8")
+ tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
+
+ # We have pack 2 fp4 values in a byte
+ W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+ else:
+ W_PACK_DIVISOR: tl.constexpr = 1
+ MX_SCALE_BLOCK_K: tl.constexpr = 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
+
+ if ExptOffsSum is not None:
+ # Determine how much padding there is on the expert data. This allows us to
+ # know the true grid size and avoid processing padding tiles.
+ padding_m = grid_m - tl.load(ExptOffsSum)
+ else:
+ padding_m: tl.constexpr = 0
+
+ index_type: tl.constexpr = tl.int64
+
+ USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
+ HAS_SCATTER: tl.constexpr = WriteBackIndx is not None
+ HAS_GATHER: tl.constexpr = GatherIndx is not None
+ USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense"
+ USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense"
+
+ if EPILOGUE_SUBTILE is None:
+ SUBTILE_FACTOR: tl.constexpr = 1
+ else:
+ SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
+ EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
+ OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
+ yN = N // ACTIVATION_REDUCTION_N
+
+ # set masked out rows to 0
+ if HAS_SCATTER and N_EXPTS_ACT == 1:
+ # Iterate with reversed pids so that later pids will get more tiles if the number of
+ # tiles isn't evenly divisible by the number of SMs.
+ # The main loop after this iterates in the forward direction such that earlier
+ # pids get more tiles if the number of tiles isn't evenly divisible.
+ # This helps balance the work across the SMs.
+ for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
+
+ z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
+ offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
+ offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
+ YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+ mask_n = offs_n < yN
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
+ tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
+
+
+ k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
+ num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
+
+ # If true, do not share loop-carried variables between the prologue and the
+ # epilogue to enable better pipelining with mmav5
+ INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
+
+ # start negative; will be incremented at the top of the loop
+ if INDEPENDENT_EPILOGUE:
+ tile_id1 = tl.program_id(0) - NUM_SMS
+
+ # Keep track of local max for updating flexpoint scales.
+ THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
+ local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
+
+ DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
+
+ for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True):
+ expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M, BLOCK_N, SPLIT_K,
+ GROUP_M, XCD_SWIZZLE)
+
+ # Base pointers and offsets.
+ if X_TMA_MODE is None:
+ XBase = X + start_z.to(index_type) * stride_x_z
+ offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
+ if SPLIT_K > 1:
+ offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
+
+ if USE_GATHER_TMA:
+ offs_m = off_m + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < (M if M is not None else eM)
+ if ExptData is None:
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
+ # Bump rows to account for the Z offset.
+ offs_x_m += start_z * (stride_x_z // stride_x_m)
+ offs_x_m = tl.where(mask_m, offs_x_m, -1)
+ else:
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
+ mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
+ elif X_TMA_MODE is None:
+ tl.static_assert(HAS_GATHER)
+ offs_m = off_m + tl.arange(0, BLOCK_M)
+ if M is not None:
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
+ else:
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
+ # no needs to bounds-check here because `offs_m` wraps around M dim
+ offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
+ offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
+
+
+ acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
+ off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
+ off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
+ off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
+
+ # --- load x ---
+ if USE_GATHER_TMA:
+ x = X.gather(offs_x_m, off_k)
+ elif X_TMA_MODE == "dense":
+ x = X.load([start_z, start_m + off_m, off_k])
+ x = x.reshape(BLOCK_M, BLOCK_K)
+ elif X_TMA_MODE == "ragged":
+ x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
+ x = x.reshape(BLOCK_M, BLOCK_K)
+ else:
+ tl.static_assert(X_TMA_MODE is None)
+ XPtrs = XBase + offs_x_m + offs_x_k
+ XBase += BLOCK_K * SPLIT_K * stride_x_k
+ mask_k = tl.arange(0, BLOCK_K) < K - off_k
+ if EVEN_K:
+ if SPLIT_K > 1:
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+ else:
+ x = tl.load(XPtrs)
+ else:
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+
+ # --- load w ---
+ if W_TRANSPOSE:
+ w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T
+ else:
+ w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])
+
+ # --- load w_scale ---
+ if is_microscaled_format:
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
+ mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+ if x_format == "fp16" or x_format == "bf16":
+ x_scales: tl.constexpr = None
+ else:
+ x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
+ w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
+ w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
+ w_scales = unswizzle_mx_scale_bw(w_scales)
+ else:
+ w_scales = MxScale.load([expt_id, off_k_mx, off_n])
+ w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
+
+ # --- update accumulator ---
+ if is_microscaled_format:
+ if SWAP_XW:
+ acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+ else:
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
+ else:
+ if SWAP_XW:
+ acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ else:
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+
+ if INDEPENDENT_EPILOGUE:
+ tile_id1 += NUM_SMS
+ expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
+ tile_id1, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M, BLOCK_N, SPLIT_K,
+ GROUP_M, XCD_SWIZZLE)
+ else:
+ tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
+ off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
+
+ offs_m = off_m1 + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < (M if M is not None else eM1)
+ if USE_SCATTER_TMA:
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
+ if SPLIT_K > 1:
+ # Compute the split k offset in number of rows, and add it to offs_y_m.
+ # This allows us to write to the correct slice in the output tensor while using
+ # a 2D TMA scatter.
+ tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
+ split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
+ offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
+ elif Y_TMA_MODE is None:
+ tl.static_assert(HAS_SCATTER)
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
+ else:
+ offs_y_m = start_m1 + offs_m
+ MASK_ACC = False if USE_GATHER_TMA else USE_FLEXPOINT_SCALE
+
+ # bias + scale
+ offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
+ mask_n = offs_y_n < N
+ if B is not None:
+ BPtrs = B + expt_id1 * stride_b_e + offs_y_n
+ if pid_k1 == 0:
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ if Betas is not None:
+ betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
+ else:
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ if Gammas is not None:
+ gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
+ else:
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ x_scale = load_scale(XScale)
+ if PER_BATCH_SCALE:
+ w_scale = load_scale(WScale + expt_id1)
+ else:
+ w_scale = load_scale(WScale)
+
+ accs = (acc,)
+ biases = (bias,)
+
+ if SUBTILE_FACTOR >= 2:
+ acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
+ accs = (acc0, acc1)
+ bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
+ biases = (bias0, bias1)
+
+ if SUBTILE_FACTOR >= 4:
+ acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+ acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+ accs = (acc00, acc01, acc10, acc11)
+ bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+ bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+ biases = (bias00, bias01, bias10, bias11)
+
+ tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
+ tl.static_assert(len(accs) == SUBTILE_FACTOR)
+
+ for a_i in tl.static_range(len(accs)):
+ acc_tile = accs[a_i]
+ acc_tile *= x_scale * w_scale
+
+ if SWAP_XW:
+ acc_tile = acc_tile.T
+
+ acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
+ if out_alpha is not None:
+ acc_tile *= out_alpha
+
+ if ACTIVATION_FN is not None:
+ out = ACTIVATION_FN(acc_tile, *activation_fn_args)
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
+ else:
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
+ out = acc_tile
+
+ out *= gammas[:, None]
+
+ if MASK_ACC:
+ out = tl.where(mask_m[:, None], out, 0.0)
+ # Flexpoint
+ out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
+ local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
+ out = float_to_flex(
+ out, YExpectedScale,
+ None, # ActualScale: local absmax is tracked and updated after the loop
+ YChecksumScale,
+ None, # mask: out is manually masked to 0
+ YPtr, FLEXPOINT_SATURATE_INF
+ )
+ if EPILOGUE_FN is not None:
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
+
+ out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
+ out = out.to(YPtr.dtype.element_ty)
+ if USE_SCATTER_TMA:
+ # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
+ # there shouldn't be any other negative values.
+ offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
+ Y.scatter(out, offs_y_m, out_off_n)
+ elif Y_TMA_MODE == "dense":
+ out = tl.reshape(out, [1] + out.shape)
+ off_kz = pid_k * batch_size + start_z1
+ Y.store([off_kz, off_m1, out_off_n], out)
+ elif Y_TMA_MODE == "ragged":
+ out = tl.reshape(out, [1] + out.shape)
+ store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1)
+ else:
+ tl.static_assert(Y_TMA_MODE is None)
+ offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
+ mask_n = offs_y_n < yN
+
+ YPtrs = YPtr + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
+ mask = mask_m[:, None] & mask_n[None, :]
+ tl.store(YPtrs, out, mask=mask)
+
+
+ # Update the flexpoint scales
+ if YActualScale is not None:
+ tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")
+
+
+# Allocator closures created so far, keyed by device.
+_per_device_alloc_fns = {}
+
+
+def get_per_device_per_stream_alloc_fn(device):
+    """Return the scratch allocator associated with ``device``.
+
+    The allocator is built on first request and cached, so repeated calls
+    with the same ``device`` return the same callable. The callable has the
+    signature ``alloc_fn(size, alignment, stream)`` and hands back an int8
+    tensor of at least ``size`` elements; one tensor is kept per stream and
+    replaced only when a larger one is requested.
+    """
+    cached_fn = _per_device_alloc_fns.get(device)
+    if cached_fn is not None:
+        return cached_fn
+
+    scratch_by_stream = {}
+
+    def alloc_fn(size: int, alignment: int, stream):
+        # Only 128-byte alignment requests are expected here.
+        assert alignment == 128
+        scratch = scratch_by_stream.get(stream)
+        if scratch is None or scratch.numel() < size:
+            scratch = torch.empty(size, device=device, dtype=torch.int8)
+            scratch_by_stream[stream] = scratch
+            # Tag attached to every freshly allocated scratch tensor.
+            scratch.__hibernate__ = {"type": "ignore"}
+        return scratch
+
+    _per_device_alloc_fns[device] = alloc_fn
+    return alloc_fn
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_reduce_grouped.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_reduce_grouped.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb58a4b1b81be577d8fd3e50f64ed07dda1d6cf
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_reduce_grouped.py
@@ -0,0 +1,126 @@
+from compactor_vllm.triton_kernels.numerics_details.flexpoint import (
+ float_to_flex,
+ load_scale,
+)
+from compactor_vllm.triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _reduce_grouped(
+    X,
+    stride_xb: tl.uint64,
+    stride_xm: tl.uint64,
+    stride_xn,  #
+    XScale,  # input scalar flex scale
+    Out,
+    stride_om: tl.uint64,
+    stride_on,  # output tensor
+    OutExpectedScale,
+    OutActualScale,
+    OutChecksumScale,  # output scalar flex scales
+    InIndx,
+    B,
+    N,  #
+    XMxScale,
+    stride_mxb: tl.uint64,
+    stride_mxs: tl.uint64,  # optional per-32-col input MXFP scales (uint8)
+    OutMxScale,
+    stride_omxs: tl.uint64,  # optional per-32-col output MXFP scales (uint8)
+    # fused activation function
+    ACTIVATION_FN: tl.constexpr,
+    activation_fn_args,
+    ACTIVATION_REDUCTION_N: tl.constexpr,
+    # epilogue transform
+    EPILOGUE_FN: tl.constexpr,
+    epilogue_fn_args,
+    #
+    HAS_IN_MX_SCALE: tl.constexpr,
+    HAS_OUT_MX_SCALE: tl.constexpr,
+    FLEXPOINT_SATURATE_INF: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Sum groups of rows of ``X`` into one row of ``Out`` per program.
+
+    Program ``pid_t`` owns output row ``pid_t``. Its source rows are the
+    ``K`` entries ``InIndx[pid_t * K : pid_t * K + K]`` (``-1`` marks an
+    unused slot), or just row ``pid_t`` when ``InIndx`` is None. For every
+    source row the ``B`` slices along ``stride_xb`` are summed first, then
+    ``ACTIVATION_FN`` (if given) is applied, and the results are accumulated
+    across the ``K`` rows. The accumulator is multiplied by the scalar
+    ``XScale``, optionally quantized with ``quantize_mxfp8_fn``, passed
+    through ``float_to_flex`` and stored. Columns are processed in tiles of
+    ``BLOCK_N`` inputs / ``BLOCK_N // ACTIVATION_REDUCTION_N`` outputs.
+
+    ``EPILOGUE_FN`` / ``epilogue_fn_args`` are accepted but not used here.
+    """
+    pid_t = tl.program_id(0)
+    BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
+    # persistent along N: single program on N, iterate tiles of size BLOCK_N
+    start = pid_t * K
+    # load indices into a tuple
+    if InIndx is None:
+        indxs = (pid_t,)
+    else:
+        indxs = ()
+        for i in tl.static_range(0, K):
+            indxs = indxs + (tl.load(InIndx + start + i),)
+    XPtrs = X + tl.arange(0, BLOCK_N) * stride_xn
+    OutPtrs = Out + tl.arange(0, BLOCK_N_OUT) * stride_on
+    if HAS_IN_MX_SCALE:
+        XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn
+    if HAS_OUT_MX_SCALE:
+        OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on
+    x_scale = load_scale(XScale)
+    for n_curr in tl.range(0, N, BLOCK_N, num_stages=4):
+        acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
+        x_n_mask = tl.arange(0, BLOCK_N) < N - n_curr
+        x_n_mask_scale = tl.arange(0, BLOCK_N // 32) < tl.cdiv(N - n_curr, 32)
+        # accumulate contributions for this tile
+        for i in tl.static_range(0, K):
+            curr = tl.zeros([BLOCK_N], dtype=tl.float32)
+            # a -1 index marks an unused slot: all of its loads are masked off
+            is_valid = indxs[i] != -1
+            # iterate over split_k partial values
+            for b in tl.range(0, B):
+                x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
+                vals = tl.load(x_row_ptr, mask=x_n_mask & is_valid, other=0.0)
+                vals = vals.to(tl.float32)
+                if HAS_IN_MX_SCALE:
+                    scale_row_ptr = XScalePtrs + indxs[i] * stride_mxs + b * stride_mxb
+                    scale = tl.load(
+                        scale_row_ptr, mask=x_n_mask_scale & is_valid, other=0.0
+                    )
+                    # place the uint8 scale in the float32 exponent field
+                    scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
+                    vals = vals.reshape([BLOCK_N // 32, 32])
+                    vals = (scale[:, None] * vals).reshape([BLOCK_N])
+                curr += vals
+            # apply nonlinearity to split-k output
+            if ACTIVATION_FN is not None:
+                curr = ACTIVATION_FN(curr[None, :], *activation_fn_args)
+            # flatten back to 1D (a no-op when no activation was applied)
+            curr = tl.reshape(curr, [curr.shape[-1]])
+            # update final accumulator
+            acc += curr
+        acc *= x_scale
+        # Compute per-32-col MXFP scales for this tile if requested
+        Nrem = (N - n_curr) // ACTIVATION_REDUCTION_N
+        out_n_mask = tl.arange(0, BLOCK_N_OUT) < Nrem
+        out_n_mask_scale = tl.arange(0, BLOCK_N_OUT // 32) < tl.cdiv(Nrem, 32)
+        if HAS_OUT_MX_SCALE:
+            acc, acc_scale = quantize_mxfp8_fn(acc[None, :], out_n_mask[None, :])
+            acc = tl.reshape(acc, [acc.shape[-1]])
+            acc_scale = tl.reshape(acc_scale, [acc_scale.shape[-1]])
+        # Convert to flexpoint output if configured (scalar scales)
+        acc = float_to_flex(
+            acc,
+            OutExpectedScale,
+            OutActualScale,
+            OutChecksumScale,
+            None,
+            Out,
+            FLEXPOINT_SATURATE_INF,
+        )
+        # write-back for this tile
+        out_ptr = OutPtrs + pid_t * stride_om
+        tl.store(out_ptr, acc, mask=out_n_mask)
+        if HAS_OUT_MX_SCALE:
+            out_scale_ptr = OutScalePtrs + pid_t * stride_omxs
+            tl.store(out_scale_ptr, acc_scale, mask=out_n_mask_scale)
+        # advance every pointer block to the next column tile
+        XPtrs += BLOCK_N * stride_xn
+        OutPtrs += BLOCK_N_OUT * stride_on
+        if HAS_IN_MX_SCALE:
+            XScalePtrs += BLOCK_N // 32 * stride_xn
+        if HAS_OUT_MX_SCALE:
+            # must use the same stride OutScalePtrs was initialised with
+            OutScalePtrs += BLOCK_N_OUT // 32 * stride_on
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc09cfaddff2f1bf764aacfa621e87936a846857
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -0,0 +1,303 @@
+# isort: off
+# fmt: off
+from dataclasses import dataclass
+import triton
+from compactor_vllm.triton_kernels.target_info import get_cdna_version
+import torch
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+
+
+@dataclass
+class OptFlags:
+    """Bundle of launch/tuning flags produced by the ``make_*_opt_flags`` helpers.
+
+    Fields are plain values; the only validation is in ``__post_init__``.
+    """
+
+    block_m: int
+    block_n: int
+    block_k: int
+    num_warps: int
+    num_stages: int
+    group_m: int
+    xcd_swizzle: int
+    # the factories in this module pass either ".cg" or None
+    w_cache_modifier: str | None
+    split_k: int
+    is_persistent: bool
+    fused_scatter: bool
+    idle_sms: int
+    epilogue_subtile: int | None
+    # the factories in this module currently always pass None
+    arch: str | None
+    target_kernel_kwargs: dict
+
+    def __post_init__(self):
+        """Reject the combination of fused scatter with ``split_k != 1``.
+
+        Raises:
+            ValueError: if ``fused_scatter`` is truthy and ``split_k != 1``.
+        """
+        if self.fused_scatter and self.split_k != 1:
+            raise ValueError("Not supported")
+
+
+def make_default_opt_flags_amd(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    """Build the default ``OptFlags`` for the HIP backend.
+
+    Values present in ``constraints`` take precedence over the heuristics.
+    Only the keys listed in ``allowed_keys`` are accepted, and the result is
+    asserted to agree with every non-None constraint before it is returned.
+    """
+    allowed_keys = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+    assert all(key in allowed_keys for key in constraints), constraints.keys()
+
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+        if tokens_per_expt is None:
+            tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+
+    # block_m: explicit (truthy) constraint first, then size heuristics
+    on_cdna4 = get_cdna_version() == 4
+    requested_block_m = constraints.get("block_m", None)
+    if requested_block_m:
+        block_m = requested_block_m
+    elif enforce_bitwise_invariance or (tokens_per_expt >= 512 and n >= 2048):
+        block_m = 256 if on_cdna4 else 128
+    elif on_cdna4 and m >= 512:
+        block_m = 128
+    else:
+        block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
+
+    grid_m = triton.cdiv(m, block_m) if routing_data is None else routing_data.n_blocks(m, block_m)
+
+    # pid grouping and XCD swizzle (the swizzle equals the number of xcds)
+    group_m = 4
+    num_xcds = 8
+    xcd_swizzle = num_xcds
+
+    # block_n / block_k heuristics, then explicit overrides.
+    # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
+    block_n, block_k = opt_flags_amd.compute_block_nk(
+        n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
+    )
+    forced_block_k = constraints.get("block_k", None)
+    if forced_block_k is not None:
+        block_k = forced_block_k
+    forced_block_n = constraints.get("block_n", None)
+    if forced_block_n is not None:
+        block_n = forced_block_n
+
+    is_persistent = constraints.get("is_persistent", False)
+
+    # split_k: fill the compute units when the tile grid is small
+    forced_split_k = constraints.get("split_k", None)
+    if forced_split_k is not None:
+        split_k = forced_split_k
+    elif is_persistent or enforce_bitwise_invariance:
+        split_k = 1
+    else:
+        n_tiles = grid_m * ((n + block_n - 1) // block_n)
+        n_cu = torch.cuda.get_device_properties(0).multi_processor_count
+        split_k = max(1, n_cu // n_tiles)
+
+    w_cache_modifier = ".cg" if block_m <= 32 else None
+    num_warps = 2 if (m is not None and m <= 16) else 8
+    num_stages = 2
+    # AMD-specific
+    target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
+
+    epilogue_subtile = constraints.get("epilogue_subtile", None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=w_cache_modifier,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        fused_scatter=constraints.get("fused_scatter", False),
+        idle_sms=0,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=target_kernel_kwargs,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
+def make_default_opt_flags_nvidia(
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ precision_config,
+ m,
+ n,
+ k,
+ routing_data,
+ can_use_persistent_tma,
+ can_use_fused_scatter,
+ enforce_bitwise_invariance,
+ epilogue_effective_itemsize,
+ constraints,
+):
+ """Build the default ``OptFlags`` for the CUDA backend.
+
+ Entries of ``constraints`` override the heuristics computed through the
+ ``opt_flags_nvidia`` helpers. Only keys listed in ``constraints_supported``
+ are accepted, and the result is asserted to match every non-None
+ constraint before being returned.
+ """
+ constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+ # tokens per expert
+ if routing_data is None:
+ tokens_per_expt = m
+ elif routing_data.expected_tokens_per_expt is None:
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+ else:
+ tokens_per_expt = routing_data.expected_tokens_per_expt
+ # pid swizzling
+ group_m = 8
+ xcd_swizzle = 1
+ # block_m
+ # a falsy constraint (None or 0) falls through to the heuristics
+ if constraints.get("block_m", None):
+ block_m = constraints["block_m"]
+ elif enforce_bitwise_invariance:
+ block_m = 128
+ else:
+ block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+ # block n
+ # arch is fixed to None here, so the `arch is None` test below always holds
+ arch = None
+ block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
+ # is_persistent
+ grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
+ # SM count is read from device index 0
+ n_sms = torch.cuda.get_device_properties(0).multi_processor_count
+ tiles_per_sm = grid_size / n_sms
+ supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
+ if constraints.get("is_persistent", None) is not None:
+ is_persistent = constraints["is_persistent"]
+ else:
+ has_simple_epilogue = precision_config.max_num_imprecise_acc is None
+ is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
+ # TEMP CHANGE
+ # NOTE(review): presumably applies only on the heuristic path above;
+ # otherwise an is_persistent=True constraint could fail the final assert - confirm
+ if precision_config.act_scale is not None or precision_config.out_scale is not None:
+ is_persistent = False
+ # block k
+ if constraints.get("block_k", None) is not None:
+ block_k = constraints["block_k"]
+ else:
+ block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
+ # split_k
+ if constraints.get("split_k", None) is not None:
+ split_k = constraints["split_k"]
+ elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+ split_k = 1
+ else:
+ # routing_data is deliberately passed as None for this estimate
+ estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
+ split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
+ if split_k > 1:
+ # With split_k, results are written in f32. Use that for the following computations.
+ out_dtype = torch.float32
+ # positional arguments shared by every compute_num_stages call below
+ compute_num_stages_args = (
+ precision_config,
+ is_persistent,
+
+ block_m,
+ block_n,
+ block_k,
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ )
+
+ if constraints.get("epilogue_subtile", None) is not None:
+ subtiles_to_check = [constraints["epilogue_subtile"]]
+ else:
+ subtiles_to_check = [1, 2, 4]
+ # keep the candidate with the most stages; the strict `>` makes the
+ # earliest candidate win ties
+ num_stages = -1
+ for ep in subtiles_to_check:
+ ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
+ if ns > num_stages:
+ epilogue_subtile, num_stages = ep, ns
+ assert num_stages >= 1
+ # a truthy num_stages constraint replaces the computed value
+ if constraints.get("num_stages", None):
+ num_stages = constraints["num_stages"]
+ # fused scatter scratchpad
+ if constraints.get("fused_scatter", None) is not None:
+ fused_scatter = constraints["fused_scatter"]
+ else:
+ fused_scatter = can_use_fused_scatter and split_k == 1
+ # Handshake with the HBM swizzling
+ num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
+ ret = OptFlags(
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ fused_scatter=fused_scatter,
+ group_m=group_m,
+ xcd_swizzle=xcd_swizzle,
+ w_cache_modifier=None,
+ split_k=split_k,
+ is_persistent=is_persistent,
+ epilogue_subtile=epilogue_subtile,
+ arch=arch,
+ target_kernel_kwargs=dict(),
+ idle_sms=constraints.get("idle_sms", 0),
+ )
+ # check constraints
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+ return ret
+
+# --------------
+# User Interface
+# --------------
+
+# Module-wide constraints that make_opt_flags forwards to the per-backend
+# factories as their `constraints` argument.
+_opt_flags_constraints: dict = dict()
+# Manual override: when not None, make_opt_flags returns it unchanged.
+_opt_flags: OptFlags | None = None
+
+def update_opt_flags_constraints(constraints: dict[str, int]):
+    """Merge ``constraints`` into the module-wide constraint dict in place."""
+    # In-place mutation only, so no `global` declaration is needed.
+    _opt_flags_constraints.update(constraints)
+
+def reset_opt_flags_constraints():
+    """Drop every registered constraint by rebinding the global to a new dict."""
+    global _opt_flags_constraints
+    _opt_flags_constraints = {}
+
+def set_opt_flags(opt_flags: OptFlags):
+    """Install ``opt_flags`` as a manual override returned by ``make_opt_flags``.
+
+    Asserts that no constraints are registered and that no override is
+    already installed; call ``reset_opt_flags()`` before installing a new one.
+    """
+    global _opt_flags
+    assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
+    assert not _opt_flags, "opt_flags already set; please reset to None first"
+    _opt_flags = opt_flags
+
+def reset_opt_flags():
+    """Remove the manual override so ``make_opt_flags`` uses heuristics again.
+
+    This is the reset step that the ``set_opt_flags`` assertion message asks for.
+    """
+    global _opt_flags
+    _opt_flags = None
+
+class InapplicableConstraint(Exception):
+    """Raised by ``make_opt_flags`` when a requested constraint cannot be enforced."""
+
+def make_opt_flags(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    epilogue_effective_itemsize,
+):
+    """Return the ``OptFlags`` to use for one matmul launch.
+
+    A manual override installed with ``set_opt_flags`` is returned as is.
+    Otherwise the call is dispatched on the active Triton backend ("hip" or
+    "cuda") to the matching ``make_default_opt_flags_*`` factory, together
+    with the module-wide constraints.
+
+    Raises:
+        InapplicableConstraint: if ``is_persistent`` or ``fused_scatter`` is
+            constrained to a truthy value but the matching ``can_use_*``
+            argument is falsy.
+        AssertionError: if the active backend is neither "hip" nor "cuda".
+    """
+    if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
+        raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
+    if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
+        raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+    enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
+    if _opt_flags is not None:
+        assert not _opt_flags_constraints
+        return _opt_flags
+    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
+            routing_data, can_use_persistent_tma, can_use_fused_scatter,
+            enforce_bitwise_invariance, epilogue_effective_itemsize,
+            _opt_flags_constraints]
+    backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "hip":
+        return make_default_opt_flags_amd(*args)
+    if backend == "cuda":
+        return make_default_opt_flags_nvidia(*args)
+    # Raise explicitly: a bare `assert False` is removed under `python -O`,
+    # which would make this function silently return None.
+    raise AssertionError(f"unsupported backend: {backend!r}")
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc396aa930731201b79404faf03ac19551ede81
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
@@ -0,0 +1,37 @@
+import torch
+import triton
+from compactor_vllm.triton_kernels.target_info import get_cdna_version
+from compactor_vllm.triton_kernels.tensor import bitwidth
+
+
+def compute_block_nk(
+    n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
+):
+    """Choose ``(block_n, block_k)`` for the HIP backend.
+
+    ``block_n`` comes from ``n`` when it is known, otherwise from ``block_m``,
+    and is forced to 512 on CDNA4 when ``block_m == 128``. ``block_k`` is
+    derived from the narrower operand, or fixed to 128 when a weight scale is
+    present on a non-CDNA4 target.
+    """
+    # operand element sizes in bytes (true division, so these are floats)
+    lhs_bytes = bitwidth(lhs_dtype) / 8
+    rhs_bytes = bitwidth(rhs_dtype) / 8
+
+    # block_n:
+    n_cu = torch.cuda.get_device_properties(0).multi_processor_count
+    if n is None:
+        block_n = 256 if block_m > 64 else 128
+    elif n <= 128 and (n & (n - 1)) == 0:
+        # small power-of-two n: a single tile covers all columns
+        block_n = n
+    else:
+        block_n = max(
+            32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))
+        )
+
+    if get_cdna_version() == 4 and block_m == 128:
+        block_n = 512
+
+    # block_k needs to match the cacheline size (128B)
+    block_k = int(128 // min(lhs_bytes, rhs_bytes))
+
+    # TODO: block_k = 128 seems to work better for now.
+    # perhaps due to increased number of k loops to pipeline
+    has_weight_scale = precision_config.weight_scale is not None
+    if has_weight_scale and get_cdna_version() != 4:
+        block_k = 128
+    return block_n, block_k
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
new file mode 100644
index 0000000000000000000000000000000000000000..978f8330e9c5bee5b10246591be7db8da6cd0955
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
@@ -0,0 +1,119 @@
+import torch
+import triton
+from compactor_vllm.triton_kernels import target_info
+from compactor_vllm.triton_kernels.tensor import get_layout, bitwidth, FP4
+from compactor_vllm.triton_kernels.tensor_details.layout import HopperMXScaleLayout
+from compactor_vllm.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import (
+ MXFP_BLOCK_SIZE,
+)
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    """Return the number of output tiles: row blocks times column blocks.
+
+    Row blocks come from ``routing_data.n_blocks`` when routing data is
+    given, otherwise from a ceiling division of ``m`` by ``block_m``.
+    """
+    if routing_data is None:
+        tiles_m = triton.cdiv(m, block_m)
+    else:
+        tiles_m = routing_data.n_blocks(m, block_m)
+    tiles_n = (n + block_n - 1) // block_n
+    return tiles_m * tiles_n
+
+
+def compute_block_n(n: int, arch, precision_config):
+    """Choose ``block_n`` for the CUDA backend (``arch`` is not consulted).
+
+    Returns 128 for a 4-warp Hopper MX scale layout, 256 when no
+    imprecise-accumulation limit is set and ``n > 128``, and otherwise the
+    next power of two of ``n`` clamped to ``[16, 128]``.
+    """
+    # block_n:
+    scale_layout = get_layout(precision_config.weight_scale)
+    if isinstance(scale_layout, HopperMXScaleLayout) and scale_layout.num_warps == 4:
+        return 128
+    if precision_config.max_num_imprecise_acc is None and n > 128:
+        return 256
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(
+    m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config
+):
+    """Choose ``block_k`` for the CUDA backend (``m`` is not consulted).
+
+    Starts from 1024 bits divided by the narrower operand width, then applies
+    the 4-bit-weight, known-``k`` and persistent-MX adjustments below.
+    """
+    lhs_bits = bitwidth(lhs_dtype)
+    rhs_bits = bitwidth(rhs_dtype)
+    # block_k needs to match the cacheline size (1024 bits)
+    block_k = int(1024 // min(lhs_bits, rhs_bits))
+    native_mxfp = target_info.cuda_capability_geq(10, 0)
+    if rhs_bits == 4 and not native_mxfp:
+        block_k = 128
+    elif k is not None:
+        # never exceed the cacheline-derived value, never go below 32
+        block_k = max(32, min(triton.next_power_of_2(k), block_k))
+    mx_weight_scale_present = (
+        precision_config is not None and precision_config.weight_scale is not None
+    )
+    if native_mxfp and is_persistent and mx_weight_scale_present:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    """Return how many ways to split the K dimension (always >= 1).
+
+    The starting point is the number of SMs of device 0 per output tile.
+    When ``k`` is known, the result is capped at a quarter of the number of
+    ``block_k`` steps so that small ``k`` is not split.
+    """
+    n_sms = torch.cuda.get_device_properties(0).multi_processor_count
+    # An empty problem gives grid_size == 0; treat it as one tile instead of
+    # dividing by zero.
+    split_k = n_sms // max(grid_size, 1)
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    return max(split_k, 1)
+
+
+def compute_num_warps(block_m, block_n, precision_config):
+    """Return the warp count: the Hopper MX scale layout's own value if one is
+    in use, otherwise one warp per 4096 tile elements with a minimum of 4."""
+    scale_layout = get_layout(precision_config.weight_scale)
+    if not isinstance(scale_layout, HopperMXScaleLayout):
+        return max(block_m * block_n // 4096, 4)
+    return scale_layout.num_warps
+
+
+def compute_num_stages(
+ precision_config,
+ is_persistent,
+ block_m,
+ block_n,
+ block_k,
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ epilogue_subtile,
+ epilogue_effective_itemsize,
+):
+ """Estimate the number of pipeline stages, capped at 4.
+
+ Returns 3 immediately when ``precision_config.max_num_imprecise_acc`` is
+ set. Otherwise divides the shared-memory capacity of device 0 (reduced by
+ an accumulator buffer in the persistent case) by the estimated bytes
+ needed per stage.
+ """
+ if precision_config.max_num_imprecise_acc is not None:
+ return 3
+ # bytes per weight element (true division, so a float)
+ weight_size = bitwidth(rhs_dtype) / 8
+ # one stage holds one lhs tile plus one rhs tile
+ stage_size = (
+ block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
+ )
+ device_props = torch.cuda.get_device_properties(0)
+ smem_capacity = device_props.shared_memory_per_block_optin
+ has_native_mxfp = target_info.cuda_capability_geq(10, 0)
+ if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None:
+ if rhs_dtype == FP4:
+ # 4-bit e2m1 weights are padded 2x
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
+ stage_size += block_k * block_n * weight_size
+
+ if is_persistent:
+ # Per-stage wait barrier
+ stage_size += 8
+ # a falsy epilogue_effective_itemsize falls back to the output itemsize
+ if target_info.cuda_capability_geq(10, 0):
+ acc_size = epilogue_effective_itemsize or out_dtype.itemsize
+ else:
+ acc_size = out_dtype.itemsize
+ if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
+ acc_block_n = block_n // epilogue_subtile
+ else:
+ acc_block_n = block_n
+ # pipelined TMA store local to global, or
+ # pipelined layout conversion before store of the accumulator
+ # note: layout conversion has some padding
+ smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
+ if precision_config.weight_scale is not None:
+ # mx scales
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
+ # NOTE(review): presumably this elif pairs with `if is_persistent` above
+ # rather than with the weight_scale test - confirm
+ elif has_native_mxfp:
+ # mx scales
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
+ num_stages = min(4, smem_capacity // int(stage_size))
+ return num_stages
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics.py
new file mode 100644
index 0000000000000000000000000000000000000000..024d3fcf0b819646a485596070b14c7a0a2e17ed
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics.py
@@ -0,0 +1,42 @@
+import torch
+from dataclasses import dataclass
+
+# Largest finite magnitude of each float8 format named in the constant.
+MAX_FINITE_FLOAT8E5 = 57344.0
+MAX_FINITE_FLOAT8E4NV = 448.0
+MAX_FINITE_FLOAT8E4B8 = 240.0
+
+
+@dataclass(frozen=True)
+class BaseFlexData:
+    """Base record holding an optional dtype used to re-view tensors."""
+
+    dtype: torch.dtype | None = None
+
+    def view(self, x: torch.Tensor):
+        """Return ``x`` viewed as ``self.dtype``, or ``x`` itself if no dtype is set."""
+        return x if self.dtype is None else x.view(self.dtype)
+
+    def reinterpret(self, x):
+        """Like ``view``, but leave ``x`` untouched when its elements are wider than one byte."""
+        keep_unchanged = self.dtype is None or x.dtype.itemsize > 1
+        return x if keep_unchanged else x.view(self.dtype)
+
+
+@dataclass(frozen=True)
+class InFlexData(BaseFlexData):
+    """Flex data for an input: the base dtype plus an optional scale tensor."""
+
+    scale: torch.Tensor | None = None
+
+    @property
+    def is_per_batch(self):
+        """True when a scale is present and has more than one entry."""
+        return self.scale is not None and len(self.scale) > 1
+
+
+@dataclass(frozen=True)
+class OutFlexData(BaseFlexData):
+    """Flex data for an output: the base dtype plus three optional scale tensors."""
+
+    expected_scale: torch.Tensor | None = None
+    actual_scale: torch.Tensor | None = None
+    checksum_scale: torch.Tensor | None = None
+
+    def __iter__(self):
+        """Yield the expected, actual and checksum scales, in that order."""
+        yield from (self.expected_scale, self.actual_scale, self.checksum_scale)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/flexpoint.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/flexpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..92fd0ef9b075cf7d9e0db1320164ce2a54366c59
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/flexpoint.py
@@ -0,0 +1,204 @@
+from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+import triton
+import triton.language as tl
+from compactor_vllm.triton_kernels.target_info import cuda_capability_geq
+
+# -------------------------------
+# Kernels stuff
+# -------------------------------
+
+# Maximum finite magnitude per output dtype, wrapped as tl.constexpr so the
+# @triton.jit helpers in this module can return them (see max_finite).
+TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5)
+TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV)
+TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8)
+TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750)
+TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0)
+
+# Reciprocals of the maxima above, stored as raw 32-bit patterns: compute_scale
+# bitcasts the selected value to float32 before using it. The trailing comments
+# give the same values in hexadecimal floating-point notation.
+TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925) # 0x1.24924Ap-16
+TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925) # 0x1.24924Ap-9
+TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889) # 0x1.111112p-8
+TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925) # 0x1.24924Ap-1
+TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008) # 0x1.004010p-16
+
+
+@triton.jit
+def max_finite(dtype):
+ """Return the maximum-finite constant registered for ``dtype``.
+
+ Handles float8e5, float8e4nv, float8e4b8, float8e4b15 and float16; any
+ other dtype trips the ``tl.static_assert`` in the final branch.
+ """
+ if dtype == tl.constexpr(tl.float8e5):
+ return TL_MAX_FINITE_FLOAT8E5
+ elif dtype == tl.constexpr(tl.float8e4nv):
+ return TL_MAX_FINITE_FLOAT8E4NV
+ elif dtype == tl.constexpr(tl.float8e4b8):
+ return TL_MAX_FINITE_FLOAT8E4B8
+ elif dtype == tl.constexpr(tl.float8e4b15):
+ return TL_MAX_FINITE_FLOAT8E4B15
+ elif dtype == tl.constexpr(tl.float16):
+ return TL_MAX_FINITE_FLOAT16
+ else:
+ # Unsupported dtype: fail at compile time rather than return a wrong limit.
+ tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
+
+
+@triton.jit
+def rcp_max_finite(dtype):
+ """Return the reciprocal-of-max-finite constant registered for ``dtype``.
+
+ The returned constant is an integer bit pattern, not a float; the caller is
+ expected to bitcast it to float32. Handles the same dtypes as ``max_finite``;
+ any other dtype trips the ``tl.static_assert`` in the final branch.
+ """
+ if dtype == tl.constexpr(tl.float8e5):
+ return TL_RCP_MAX_FINITE_FLOAT8E5
+ elif dtype == tl.constexpr(tl.float8e4nv):
+ return TL_RCP_MAX_FINITE_FLOAT8E4NV
+ elif dtype == tl.constexpr(tl.float8e4b8):
+ return TL_RCP_MAX_FINITE_FLOAT8E4B8
+ elif dtype == tl.constexpr(tl.float8e4b15):
+ return TL_RCP_MAX_FINITE_FLOAT8E4B15
+ elif dtype == tl.constexpr(tl.float16):
+ return TL_RCP_MAX_FINITE_FLOAT16
+ else:
+ # Unsupported dtype: fail at compile time rather than return a wrong value.
+ tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
+
+
+@triton.jit
+def sm86_min_nan_xorsign_abs_f32(a, b):
+ """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
+
+ Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
+ NaN inputs are propagated to the output.
+
+ Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
+ """
+ # Compile-time guards: reject unsupported targets and non-float32 operands.
+ tl.static_assert(
+ cuda_capability_geq(8, 6),
+ "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+",
+ )
+ tl.static_assert(
+ a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+ tl.static_assert(
+ b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+
+ # In the asm template $0 is the result and $1/$2 are a and b; the constraint
+ # string declares one output register followed by two input registers.
+ return tl.inline_asm_elementwise(
+ """{
+ min.NaN.xorsign.abs.f32 $0, $1, $2;
+ }""",
+ "=r,r,r",
+ [a, b],
+ dtype=tl.float32,
+ is_pure=True,
+ pack=1,
+ )
+
+
+@triton.jit
+def sm86_max_nan_xorsign_abs_f32(a, b):
+ """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
+
+ Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
+ NaN inputs are propagated to the output.
+
+ Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
+ """
+ # Compile-time guards: reject unsupported targets and non-float32 operands.
+ tl.static_assert(
+ cuda_capability_geq(8, 6),
+ "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+",
+ )
+ tl.static_assert(
+ a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+ tl.static_assert(
+ b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+
+ # In the asm template $0 is the result and $1/$2 are a and b; the constraint
+ # string declares one output register followed by two input registers.
+ return tl.inline_asm_elementwise(
+ """{
+ max.NaN.xorsign.abs.f32 $0, $1, $2;
+ }""",
+ "=r,r,r",
+ [a, b],
+ dtype=tl.float32,
+ is_pure=True,
+ pack=1,
+ )
+
+
+@triton.jit
+def load_scale(scale_ptr):
+ """Return the value loaded from ``scale_ptr``, or 1.0 when ``scale_ptr`` is None."""
+ return 1.0 if scale_ptr is None else tl.load(scale_ptr)
+
+
+@triton.jit
+def flex_to_float(x, scale_ptr):
+    """Cast ``x`` to float32 and multiply it by the scale at ``scale_ptr`` (1.0 if None)."""
+    return x.to(tl.float32) * load_scale(scale_ptr)
+
+
+@triton.jit
+def clip(x, limit):
+    """Clamp ``x`` to the closed interval ``[-limit, limit]``."""
+    upper_bounded = tl.minimum(x, limit)
+    return tl.maximum(-limit, upper_bounded)
+
+
+@triton.jit
+def nan_propagating_absmax_reduce(x, axis=None):
+ """Reduce ``x`` along ``axis`` to its largest absolute value.
+
+ Both branches return the result as a uint32 bit pattern with the sign bit
+ cleared, not as a float; the caller bitcasts it back.
+ NOTE(review): presumably ``x`` is float32 (the sm86 path asserts it) — confirm.
+ """
+ if cuda_capability_geq(8, 6):
+ # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
+ x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32)
+ # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
+ x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
+ else:
+ # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
+ masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
+ x_absmax = tl.max(masked_abs_x, axis)
+
+ return x_absmax
+
+
+@triton.jit
+def compute_scale(x, Out):
+ """Compute the flexpoint scale for ``x`` given the element type of ``Out``.
+
+ The result is ``absmax(x) * rcp_max_finite(Out element type) + 1.0e-30``,
+ evaluated with a fused multiply-add. NaNs in ``x`` yield a scale derived
+ from +inf instead of NaN.
+ """
+ # Flatten first so the reduction covers every element of x.
+ x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True))
+
+ # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
+ # We use integer minimum because NaNs are above +inf in integer representation.
+ x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True)
+ # The reciprocal comes back as a bit pattern and is bitcast to float32 below.
+ RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty)
+ return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30)
+
+
+@triton.jit
+def update_scale(x, scale_ptr, Out) -> None:
+    """Atomically raise the value at ``scale_ptr`` to the scale computed from ``x``.
+
+    Does nothing when ``scale_ptr`` is None.
+    """
+    if scale_ptr is not None:
+        tl.atomic_max(scale_ptr, compute_scale(x, Out), sem="relaxed")
+
+
+@triton.jit
+def float_to_flex(
+ x,
+ expected_scale_ptr_or_val,
+ actual_scale_ptr,
+ checksum_scale_ptr,
+ mask,
+ Out,
+ saturate_infs: tl.constexpr,
+):
+ """Scale ``x`` for storage into ``Out`` and update the optional scale/checksum slots.
+
+ expected_scale_ptr_or_val: pointer or value of the expected scale; ``x`` is
+ divided by it. When None no scaling or clipping is applied.
+ actual_scale_ptr: if not None, receives the atomic max of the scale computed from ``x``.
+ checksum_scale_ptr: if not None, receives an atomic add of the XOR of ``x``'s bit patterns.
+ mask: optional validity mask; masked-out elements are excluded from checksum and scale.
+ saturate_infs: when set (and an expected scale is given) the result is clipped
+ to the max finite value of ``Out``'s element type.
+
+ Returns the scaled ``x``; no cast to ``Out``'s dtype happens here.
+ """
+ # Resolve the inverse of the expected scale (1.0 when none is provided).
+ if expected_scale_ptr_or_val is not None:
+ if expected_scale_ptr_or_val.dtype.is_ptr():
+ invscale = 1.0 / tl.load(expected_scale_ptr_or_val)
+ else:
+ invscale = 1.0 / expected_scale_ptr_or_val
+ else:
+ invscale = 1.0
+ # Optional checksum: XOR-reduce the raw bits of the unscaled values.
+ if checksum_scale_ptr is not None:
+ x_int32 = x.to(tl.int32, bitcast=True)
+ zero = tl.cast(0.0, tl.int32)
+ if mask is not None:
+ x_int32 = tl.where(mask, x_int32, zero)
+ checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0)
+ tl.atomic_add(checksum_scale_ptr, checksum_local)
+ # Zero masked-out elements so they cannot influence the computed scale.
+ if mask is not None:
+ if actual_scale_ptr is not None:
+ x = tl.where(mask, x, 0.0)
+ # update_scale is a no-op when actual_scale_ptr is None.
+ update_scale(x, actual_scale_ptr, Out)
+ x = x * invscale
+ # if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case.
+ if expected_scale_ptr_or_val is not None:
+ if saturate_infs:
+ CLIP_VALUE = max_finite(Out.dtype.element_ty)
+ x = clip(x, CLIP_VALUE)
+ return x
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c69c83c1dd77668ae80cbee0f21bafc5767815
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp.py
@@ -0,0 +1,303 @@
+# isort: off
+# fmt: off
+from enum import Enum
+import triton
+import torch
+import torch.nn.functional as F
+from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
+from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, MXFP_BLOCK_SIZE, _quantize_mxfp8_fn
+
+# -----------------------------------------------------------------------------
+# Dequantization / Quantization Utilities
+# -----------------------------------------------------------------------------
+
+
+class DequantScaleRoundingMode(Enum):
+    """How the dequantization scale is rounded to a power of two."""
+
+    # Round the scale up to the next power of two.
+    ROUND_UP = 0
+    # Round the scale down to the previous power of two.
+    ROUND_DOWN = 1
+
+
+def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
+ """
+ Convert the src weights to mx format. The src weight is quantized along the axis dimension.
+
+ If out_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
+ Note that this means the k_dim of the tensor will be half of the logical k_dim.
+
+ If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 where the float8s are stored
+ in their respective formats.
+
+ Returns:
+ (out_quant_tensor, out_scale): the quantized tensor and its uint8 scale tensor, both transposed
+ back so that the quantized dimension sits at `axis` again.
+ """
+ ndim = src_tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ axis = axis if axis >= 0 else axis + ndim
+ # downcast
+ # Move the quantization axis to the last position for the kernel.
+ src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
+ is_fp4 = out_quant_type == torch.uint8
+ is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
+ assert is_fp4 or is_fp8
+ # mxfp4 packs two values per byte, so the stored axis is half as long.
+ divisor = 2 if is_fp4 else 1
+ L = src_tensor.shape[-1]
+ if is_fp4:
+ assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
+ out_shape = src_tensor.shape[:-1] + (L // divisor, )
+ # One scale entry per MXFP_BLOCK_SIZE elements along the quantized axis.
+ out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
+
+ out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
+ out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
+
+ # Skip the kernel launch entirely for empty inputs.
+ if src_tensor.numel() > 0:
+ # Collapse all leading dimensions so the kernel sees 2D tensors.
+ kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
+ kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
+ kernel_scale = out_scale.view(-1, out_scale.shape[-1])
+
+ BLOCK_OUT_DIM = 128
+ BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+ # One program per (BLOCK_OUT_DIM rows, BLOCK_QUANT_DIM columns) tile.
+ grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
+ grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
+
+ _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
+ *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
+ *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
+ DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
+
+ # Undo the initial transpose so the outputs match the caller's axis order.
+ out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
+ out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
+ return out_quant_tensor, out_scale
+
+
+def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
+ """
+ Upcasts an mxfp (packed) weight tensor back to float16, bfloat16 or float32.
+
+ The function assumes that the tensors were quantized along the given axis.
+ It permutes the tensor so that the quantized axis is last, reshapes to 2D,
+ launches the Triton upcast kernel, and then unpermutes back to the original order.
+
+ Returns:
+ A contiguous tensor of dtype `target_dtype`. For uint8 (mxfp4) inputs the
+ size along `axis` is twice that of `tensor`.
+ """
+ ndim = tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ axis = axis if axis >= 0 else axis + ndim
+ assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
+ f"Got {tensor.ndim=} and {scale.ndim=}")
+ # dtype checks
+ assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
+ f"Invalid tensor dtype {tensor.dtype=}"
+ assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
+ assert target_dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {target_dtype=}"
+ # upcast
+ # uint8 inputs hold two values per byte, so the logical axis is twice as long.
+ logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
+ tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
+ scale = scale.transpose(axis, scale.ndim - 1).contiguous()
+ out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=target_dtype, device=tensor.device)
+ # Collapse all leading dimensions so the kernel sees 2D tensors.
+ reshaped_out = out.view(-1, out.shape[-1])
+ reshaped_tensor = tensor.view(-1, tensor.shape[-1])
+ reshaped_scale = scale.view(-1, scale.shape[-1])
+ BLOCK_OUT_DIM = 128
+ BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+ # One program per (BLOCK_OUT_DIM rows, BLOCK_QUANT_DIM columns) tile of the output.
+ blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
+ blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
+ _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
+ *reshaped_scale.stride(), reshaped_tensor,
+ *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
+ BLOCK_QUANT_DIM, num_warps=8)
+ # Move the quantized axis back to its original position.
+ out = out.transpose(axis, scale.ndim - 1).contiguous()
+ return out
+
+
+# ------------
+
+
+def right_shift_unsigned(x, shift):
+    """Emulate a logical (zero-filling) right shift on 32-bit values.
+
+    CUDA torch does not support bit ops on uint32, so the shifted result is
+    masked to keep only its low ``32 - shift`` bits.
+    """
+    low_bits_mask = (1 << (32 - shift)) - 1
+    shifted = x >> shift
+    return shifted & low_bits_mask
+
+
+def get_max_quant_val(dtype: torch.dtype):
+    """Return the largest representable magnitude for a supported quantized dtype.
+
+    ``torch.uint8`` stands for packed e2m1 (fp4) values. Unsupported dtypes fail
+    the assertion.
+    """
+    limits = {
+        torch.float8_e4m3fn: 448.0,
+        torch.float8_e5m2: 57344.0,
+        torch.uint8: 6.0,
+    }
+    assert dtype in limits
+    return limits[dtype]
+
+
+def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
+ """
+ Converts the src tensor to the output format specified by out_quant_type.
+ axis: The axis along which the tensors are contiguous and quantization is applied.
+ DEQUANT_SCALE_ROUNDING_MODE: A DequantScaleRoundingMode member. It is compared against
+ DequantScaleRoundingMode.ROUND_UP; any other value (including a plain int) takes the ROUND_DOWN path.
+
+ Returns:
+ out_quant_tensor: Quantized tensor in mx format.
+ • For mxfp8, the output has the same shape as src_tensor.
+ • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
+ scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
+ Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
+ where L is the original length along that axis.
+ """
+ # This should probably be packed into its own tiny class
+ ndim = src_tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ assert src_tensor.dtype in {torch.float32, torch.bfloat16,
+ torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"
+
+ axis = axis if axis >= 0 else axis + ndim
+ is_fp4 = out_quant_type == torch.uint8
+ is_fp8 = "float8" in str(out_quant_type)
+ assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
+
+ device = src_tensor.device
+
+ # For mxfp4 conversion, we assume the contiguous axis length is even.
+ if is_fp4:
+ axis_shape = src_tensor.size(axis)
+ assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."
+
+ # Permute the tensor so that the contiguous axis becomes the last dimension.
+ src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
+ axis_shape = src.shape[-1]
+
+ # Pad the axis to be divisible by 32, in case it is not.
+ next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
+ pad_amount = next_multiple - axis_shape
+ padded_src = F.pad(src, (0, pad_amount))
+ # valid_mask is True for original elements and False for the padding.
+ valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
+ padded_axis_shape = padded_src.size(-1) # now divisible by 32
+
+ # --- Compute per-group maximums for scale ---
+ # Set padded entries to -1 so they don’t affect the max.
+ abs_f = torch.abs(padded_src)
+ abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
+ # Reshape the last dimension into groups of 32.
+ new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
+ abs_groups = abs_f.view(*new_shape)
+ # Compute maximum along the group dimension (of size 32).
+ max_val, _ = abs_groups.max(dim=-1, keepdim=True)
+
+ # Choose a max quantization value depending on type.
+ max_quant_val = get_max_quant_val(out_quant_type)
+ dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1)
+
+ # Convert to int to round the FP32 scale, prior to quantization!
+ ds_int = dequant_scale.view(torch.int32)
+ if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
+ # Adding the full mantissa mask bumps the exponent unless the mantissa is zero.
+ ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
+ else:
+ # Keep only the exponent bits, i.e. drop the mantissa.
+ ds_int_rounded = ds_int & 0x7F800000
+ # Reinterpret back as float32.
+ dequant_scale_rounded = ds_int_rounded.view(torch.float32)
+
+ # Compute the quantization scale.
+ # A zero dequant scale maps to a zero quant scale, avoiding a division by zero.
+ quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)
+
+ # Quantize the tensor
+ orig_padded_shape = padded_src.shape
+ padded_src_groups = padded_src.view(*new_shape)
+ quant_tensor = padded_src_groups * quant_scale
+ # Reshape back to the original shape and trim padding
+ quant_tensor = quant_tensor.view(orig_padded_shape)
+ quant_tensor = quant_tensor[..., :axis_shape]
+
+ # Finally, convert the quantized tensor to the target format
+ if is_fp8:
+ # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
+ quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
+ out_weight = quant_tensor.to(out_quant_type)
+ else:
+ assert is_fp4, f"Invalid output quantization type {out_quant_type}"
+ # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
+ # First, reinterpret the quantized tensor bits.
+ q_int = quant_tensor.contiguous().view(torch.int32)
+ # Extract sign, exponent, and mantissa.
+ signs = q_int & 0x80000000
+ exponents = right_shift_unsigned(q_int, 23) & 0xFF
+ mantissas = q_int & 0x7FFFFF
+
+ E8_BIAS = 127
+ E2_BIAS = 1
+ # Adjust mantissas for subnormals.
+ mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
+ (E8_BIAS - exponents - 1), mantissas)
+ # Rebias the exponent from 127 to 1, clamping at zero for values below 1.0.
+ exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
+ # Keep one extra mantissa bit, add 1 and shift to round, then saturate at 0x7.
+ e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
+ e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
+ e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape)
+
+ # Pack pairs of 4-bit values along the last dimension.
+ e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
+ evens = e2m1_value[..., 0]
+ odds = e2m1_value[..., 1]
+ out_weight = evens | (odds << 4) # shape: (..., axis_shape//2)
+
+ # --- Process and output the scale ---
+ # Only the biased exponent byte of the rounded scale is stored.
+ dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1)
+ dq_scale = dq_scale.squeeze(-1)
+ # Undo the initial transpose so the outputs match the caller's axis order.
+ out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
+ dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
+ return out_weight, dq_scale
+
+
+def cvt_e2m1_to_fp32(input_tensor):
+    """Unpack a uint8 tensor of paired e2m1 values into float32.
+
+    Each byte holds two 4-bit codes: the low nibble comes first in the output,
+    the high nibble second, so the last dimension doubles in size.
+    """
+    assert input_tensor.dtype == torch.uint8
+
+    packed = input_tensor.to(torch.int32)
+
+    # Codes 0-7 are the non-negative values; codes 8-15 are their negations.
+    magnitudes = torch.tensor([0.0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32, device=packed.device)
+    lookup = torch.cat([magnitudes, -magnitudes])
+
+    low_vals = lookup[packed & 0xF]
+    high_vals = lookup[(packed >> 4) & 0xF]
+    interleaved = torch.stack([low_vals, high_vals], dim=-1)
+    return interleaved.view(*packed.shape[:-1], -1)
+
+
+def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
+    """
+    Dequantize an mxfp4/mxfp8 tensor into ``target_dtype`` using plain torch ops.
+
+    axis: The axis along which dequantization is applied.
+
+    Returns:
+        out_weight: Tensor in the target format, with the quantized axis back at ``axis``.
+    """
+    ndim = tensor.ndim
+    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+    is_fp4 = tensor.dtype == torch.uint8
+    is_fp8 = tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"
+
+    # Bring the quantization axis to the end of both the scale and the payload.
+    if axis < 0:
+        axis = axis + ndim
+    scale = scale.transpose(axis, scale.ndim - 1)
+    tensor = tensor.transpose(axis, tensor.ndim - 1)
+
+    # The stored byte is the float32 exponent field: shift it into place and bitcast.
+    dq_scale = (scale.to(torch.int32) << 23).view(torch.float32)
+    fp32_tensor = cvt_e2m1_to_fp32(tensor) if is_fp4 else tensor.to(torch.float32)
+
+    # Pad the last axis up to a whole number of scale groups.
+    logical_quant_dim = tensor.shape[-1] * (2 if is_fp4 else 1)
+    axis_shape = fp32_tensor.size(-1)
+    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
+    pad_size = padded_axis_shape - axis_shape
+    padded_tensor = F.pad(fp32_tensor, (0, pad_size))
+
+    # Split the last axis into (groups, group_size) and apply one scale per group.
+    new_axis_shape = padded_tensor.shape[-1]
+    grouped_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
+    scaled = padded_tensor.view(*grouped_shape) * dq_scale.unsqueeze(-1)
+
+    # Flatten the groups again and drop the padded tail.
+    flat = scaled.view(*fp32_tensor.shape[:-1], new_axis_shape)
+    trimmed = flat[..., :axis_shape]
+
+    converted = trimmed.to(target_dtype).contiguous()
+    return converted.transpose(axis, tensor.ndim - 1)
+
+
+# Public alias for the jitted mxfp8 quantization helper imported at the top of this module.
+quantize_mxfp8_fn = _quantize_mxfp8_fn
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eac6467e2d8d49385106574ec073cf677c622e0
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
@@ -0,0 +1,158 @@
+import triton
+import triton.language as tl
+
+# fmt: off
+
+
+# Number of consecutive elements along the quantized axis that share one scale value.
+MXFP_BLOCK_SIZE = tl.constexpr(32)
+
+
+@triton.jit
+def _get_max_quant_val(dtype: tl.constexpr):
+ """Return the largest magnitude used when scaling values into ``dtype``.
+
+ ``tl.uint8`` stands for packed e2m1 (fp4) values. Any dtype other than
+ uint8, float8e5 or float8e4nv fails the static assert.
+ """
+ if dtype == tl.uint8:
+ return 6.0
+ elif dtype == tl.float8e5:
+ return 57344.0
+ elif dtype == tl.float8e4nv:
+ return 448.0
+ else:
+ tl.static_assert(False, f"Invalid {dtype=}")
+
+@triton.jit
+def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
+ DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
+ """Quantize a 2D tile to an mx format and compute its per-group scale exponents.
+
+ src_tensor: 2D tile; its second dimension is split into groups of MXFP_BLOCK_SIZE.
+ valid_src_mask: elements where this is False are ignored for the group maximum
+ and written as 0 in the quantized output.
+ mx_tensor_dtype: tl.float8e4nv or tl.float8e5 selects a plain float8 cast; any
+ other dtype takes the e2m1 path that packs two 4-bit values per uint8.
+ DEQUANT_SCALE_ROUNDING_MODE: 0 rounds the scale up to a power of two, 1 rounds it down.
+
+ Returns (out_tensor, dequant_scale_exponent): the quantized tile and the uint8
+ exponent byte of each group's rounded scale.
+ """
+ is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+ BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
+ BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
+
+ # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
+ f32_tensor = src_tensor.to(tl.float32)
+ abs_tensor = tl.abs(f32_tensor)
+ abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
+ abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+ # Per-group maximum magnitude, kept 3D so it broadcasts over each group below.
+ max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
+ dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
+ if DEQUANT_SCALE_ROUNDING_MODE == 0:
+ # DequantScaleRoundingMode.ROUND_UP
+ # compute 2 ** ceil(log2(dequant_scale))
+ # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
+ # A corner case: exponent is 0xFF that will overflow but that's already
+ # NaN so assume we don't care.
+ dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
+ else:
+ # DequantScaleRoundingMode.ROUND_DOWN
+ # compute 2 ** floor(log2(dequant_scale))
+ assert DEQUANT_SCALE_ROUNDING_MODE == 1
+ dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
+ dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
+ # A zero scale maps to a zero quant scale, avoiding a division by zero.
+ quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
+
+ f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+ quant_tensor = f32_tensor * quant_scale
+
+ # Reshape the tensors after scaling
+ quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+ # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
+ quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
+ dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
+
+ # First, we simply extract the exponent part of the scales and store the result
+ dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
+ # Now we must convert the tensors to the mx format.
+ if is_fp8:
+ out_tensor = quant_tensor.to(mx_tensor_dtype)
+ else:
+ # e2m1 path: split the float32 bits into sign, exponent and mantissa fields.
+ quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
+ signs = quant_tensor & 0x80000000
+ exponents = (quant_tensor >> 23) & 0xFF
+ mantissas = (quant_tensor & 0x7FFFFF)
+
+ # 0.25 <= x < 0.75 maps to 0.5, a denormal number
+ E8_BIAS = 127
+ E2_BIAS = 1
+ # Move implicit bit 1 at the beginning to mantissa for denormals
+ adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
+ mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
+
+ # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
+
+ # Combine sign, exponent, and mantissa, while saturating
+ # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
+ e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
+ e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
+
+ # Pair up neighbouring values: the first goes in the low nibble, the second in the high nibble.
+ e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
+ evens, odds = tl.split(e2m1_value)
+ out_tensor = evens | (odds << 4)
+
+ return out_tensor, dequant_scale_exponent
+
+@triton.jit
+def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
+ mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
+ src_ptr, stride_src_outer, stride_src_quant,
+ outer_dim, quant_dim,
+ BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
+ DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
+ """Kernel that quantizes one tile of a 2D source tensor to mx format.
+
+ Each program handles a BLOCK_SIZE_OUT_DIM x BLOCK_SIZE_QUANT_DIM tile selected
+ by program ids 0 (outer dimension) and 1 (quantized dimension). It loads the
+ tile from ``src_ptr``, then stores the quantized values to ``mx_tensor_ptr``
+ and the uint8 scales to ``mx_scale_ptr``.
+
+ ``outer_dim`` and ``quant_dim`` are the source extents used for bounds masks.
+ The output must have unit stride along the quantized dimension.
+ """
+
+ tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
+ tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
+
+ # uint8 signifies two fp4 e2m1 values packed into a single byte
+ mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+ tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
+ f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
+
+ src_dtype: tl.constexpr = src_ptr.dtype.element_ty
+ tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
+ tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+
+ # Program ids are widened to int64 before being used in offset arithmetic.
+ outer_block = tl.program_id(0).to(tl.int64)
+ quant_block = tl.program_id(1).to(tl.int64)
+
+ # Tile widths along the quantized axis for the packed output and for the scales.
+ K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+ BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+ start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+ start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+ start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+ start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+ # Advance each base pointer to the start of this program's tile.
+ src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
+ mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
+ mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
+
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+ offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+ offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+ offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+
+ # Bounds masks for the source tile, the packed output tile and the scale tile.
+ mask_src_quant = start_src_quant + offs_src_quant < quant_dim
+ mask_n = start_out + offs_outer < outer_dim
+ full_mask_src = mask_src_quant & mask_n
+
+ mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
+ full_mask_mxt = mask_mxt_quant & mask_n
+
+ scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+ full_scale_mask = scale_mask_k & mask_n
+
+ src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
+ mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
+ mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
+ src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
+
+ # The source mask is passed along so out-of-bounds elements are excluded from the scale.
+ out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
+ DEQUANT_SCALE_ROUNDING_MODE)
+
+ tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
+ tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
+
+
+# NOTE(review): the repr string says "_dequantize_mxfp8" although this helper quantizes — confirm it is intentional.
+@triton.jit(repr=lambda _: "_dequantize_mxfp8")
+def _quantize_mxfp8_fn(input, mask, pid=None):
+ """Quantize ``input`` to mxfp8 (tl.float8e4nv) using the default rounding mode.
+
+ Returns the (quantized tile, scale exponents) pair from ``_compute_quant_and_scale``.
+ ``pid`` is accepted but not used here.
+ """
+ return _compute_quant_and_scale(input, mask, tl.float8e4nv)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5f027fa986c06f402405a4a5047b649b3e1bfe
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
@@ -0,0 +1,125 @@
+import triton
+import triton.language as tl
+
+from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
+
+
+# fmt: off
+@triton.jit
+def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
+                      stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
+                      outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
+    # Upcast one [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM] tile: packed values times their group scale -> out_ptr.
+    tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
+    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
+    # uint8 signifies two fp4 e2m1 values packed into a single byte
+    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+    dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
+    tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 or dst_dtype == tl.float32)
+    tl.static_assert(
+        mx_tensor_dtype == tl.uint8
+        or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
+        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
+    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+
+    # Classify the stored element type: uint8 is treated as packed fp4, float8e4nv / float8e5 as fp8.
+    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1  # fp4 stores two values per byte
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE  # scales per tile row
+    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR  # stored elements per tile row
+
+    # Program axis 0 walks the outer dimension, axis 1 walks the quantized dimension (int64 to avoid offset overflow).
+    outer_block = tl.program_id(0).to(tl.int64)
+    quant_block = tl.program_id(1).to(tl.int64)
+
+    start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR  # start in stored (packed) elements
+    start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM  # start in output elements
+    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE  # start in scale entries
+    start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+    mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
+    mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
+    out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
+
+    # Compute offsets and masks (row vectors index the quantized axis, the column vector indexes the outer axis).
+    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+    offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+    offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+
+    mask_outer = start_out + offs_outer < outer_dim
+    mask_out_quant = start_out_quant + offs_out_quant < quant_dim
+    full_mask_out = mask_out_quant & mask_outer
+
+    mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)  # bound in stored elements
+    full_mask_src = mask_src_quant & mask_outer
+
+    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)  # bound in scale entries
+    full_scale_mask = mask_scale & mask_outer
+
+    tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
+    scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
+    out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
+
+    # Load the packed tensor and scale (no `other` is given, so masked-off lanes hold unspecified values).
+    tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
+    scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
+
+    # Upcast the scale to the destination type by shifting the uint8 into the float exponent field.
+    if dst_dtype == tl.bfloat16:
+        dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)  # shift past the 7 bf16 mantissa bits
+    else:
+        dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)  # shift past the 23 fp32 mantissa bits
+        if dst_dtype == tl.float16:
+            dst_scale = dst_scale.to(tl.float16)
+
+    # Now upcast the tensor (float32 destinations go through a bfloat16 intermediate).
+    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
+    if is_fp8:
+        dst_tensor = tensor.to(intermediate_dtype)
+        if tensor.dtype == tl.float8e5:
+            from_e_bits: tl.constexpr = 5
+            from_m_bits: tl.constexpr = 2
+            to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
+            to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
+
+            # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
+            non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits  # all exponent bits set (source)
+            non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits  # all exponent bits set (destination)
+            dst_tensor = tl.where(
+                (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
+                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
+                dst_tensor,
+            )
+    else:
+        assert is_fp4  # NOTE(review): the static_assert above also admits mx_tensor_dtype == dst_dtype, which fails here
+        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
+        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800  # bit pattern of 0.5 in the dst type
+        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
+        # e2m1: per nibble, bit 3 is the sign, bits 2-1 the exponent, bit 0 the mantissa
+        em0 = tensor & 0x07  # exponent+mantissa of the low nibble
+        em1 = tensor & 0x70  # exponent+mantissa of the high nibble
+        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)  # sign moved to bit 15
+        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)  # sign moved to bit 15
+        # Three cases:
+        # 1) x is normal and non-zero: Correct bias
+        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
+        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
+        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
+        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
+        # 3) x is zero, do nothing
+        dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)  # low-nibble value first, then high
+    dst_tensor = dst_tensor.to(dst_dtype)
+
+    # Reshape for broadcasting: one scale applies to each group of MXFP_BLOCK_SIZE values along the quantized axis.
+    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
+    scale = scale.reshape(dst_scale.shape)
+
+    out_tensor = dst_tensor * dst_scale
+    # Correct any NaNs encoded via the scale (a raw scale byte of 0xFF marks the whole group as NaN).
+    out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
+    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+    tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
new file mode 100644
index 0000000000000000000000000000000000000000..a187eecc2d66659c278be3668e7865ee8a785694
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
@@ -0,0 +1,19 @@
+# proton options
+
+import os
+
+_launch_metadata_allow_sync = None  # cached flag; stays None until first queried or explicitly set
+
+
+def launch_metadata_allow_sync():  # returns the cached flag, resolving it from the environment on first use
+    global _launch_metadata_allow_sync
+    if _launch_metadata_allow_sync is None:  # not yet resolved and not overridden through the setter
+        _launch_metadata_allow_sync = not (
+            os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") == "1"  # only the exact string "1" turns sync off
+        )
+    return _launch_metadata_allow_sync
+
+
+def set_launch_metadata_allow_sync(allow_sync: bool):  # explicit override of the cached flag
+    global _launch_metadata_allow_sync
+    _launch_metadata_allow_sync = allow_sync  # once set to a non-None value the environment variable is no longer read
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..398482c321e119dfeb059fc420341ca58d1cceb1
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
@@ -0,0 +1,133 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def vpopc(x):
+    """
+    Vertical popcount: count, per bit position, how many words along the last axis have that bit set.
+    Input  x : uint32[..., N]
+    Output y : uint32[..., 32]
+    semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
+    credits: @apgoucher
+    """
+
+    tl.static_assert(
+        x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers"
+    )
+
+    BLOCK_N: tl.constexpr = x.shape[-1]  # summation axis
+    BATCHES: tl.constexpr = x.numel // BLOCK_N  # number of batches
+    if BLOCK_N >= 8:
+        sa1: tl.constexpr = 8  # first-stage group size: at most 8 one-bit values are added per 4-bit field
+    else:
+        sa1: tl.constexpr = BLOCK_N
+    # create 8-way sums in 4-bit fields:
+    y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
+    y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111  # keep every 4th bit, for shifts 0..3
+    y = tl.sum(y, 2)  # [BATCHES, BLOCK_N // sa1, 4]
+    if BLOCK_N >= 128:
+        sa2: tl.constexpr = 16  # second-stage group size: 16 fields of at most 8 each fit in 8 bits
+    else:
+        sa2: tl.constexpr = BLOCK_N // sa1
+    # create 128-way sums in 8-bit fields:
+    y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
+    y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0F0F0F0F  # split even / odd 4-bit fields
+    y = tl.sum(y, 2)  # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
+    sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)  # number of partial sums left to combine
+    # create N-way sums in 32-bit fields:
+    y = tl.reshape(y, [BATCHES, 1, sa3, 8])
+    y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000FF  # extract each of the four bytes
+    y = tl.sum(y, 2)  # [BATCHES, 4, 8]
+    y = tl.reshape(y, x.shape[:-1] + [32])  # restore the leading dims with a trailing axis of 32 counts
+    return y
+
+
+@triton.jit
+def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):  # zero-fill kernel: each program clears BLOCK consecutive elements
+    pid = tl.program_id(0)
+    offs = pid * BLOCK + tl.arange(0, BLOCK)
+    tl.store(Ret + offs, 0)  # unmasked store: the buffer must cover grid_size * BLOCK elements
+
+
+@triton.jit
+def _sum_bitmatrix_rows(
+    B,
+    shape_bm,  # row count, either a scalar or a pointer to one (see the load below)
+    stride_bm: tl.constexpr,
+    stride_bn: tl.constexpr,  # input bitmatrix
+    Ret,
+    Partials,
+    stride_pm: tl.constexpr,
+    stride_pn,
+    shape_pn,  # outputs (shape_pn is not read in this body)
+    BLOCK_MM: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+):
+    tl.static_assert(BLOCK_MM % BLOCK_M == 0)
+    TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M  # number of BLOCK_M-row groups handled per program
+    if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
+        shape_bm = tl.load(shape_bm)  # row count supplied through device memory
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
+    offs_n = pid_n * 32 + tl.arange(0, 32)  # the 32 bit positions covered by one word column
+    n_rows = shape_bm
+    bits = tl.load(
+        B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0
+    )
+    bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
+    ret = vpopc(bits)  # [TILE_SIZE, 32]
+
+    offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)  # rows of Partials written by this program
+
+    tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")  # accumulate into the running totals in Ret
+    tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)  # unmasked per-group counts
+
+
+def clear_sums(n_cols, device, MEMSET_BLOCK=512):  # allocate a zeroed int32 buffer, padded to a MEMSET_BLOCK multiple
+    cdiv = triton.cdiv
+    blocks = cdiv(n_cols, MEMSET_BLOCK)
+    out_ret = torch.empty((blocks * MEMSET_BLOCK,), device=device, dtype=torch.int32)  # length may exceed n_cols
+    _sum_bitmatrix_memset[(blocks,)](out_ret, MEMSET_BLOCK)  # one program per MEMSET_BLOCK elements
+    return out_ret
+
+
+def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):  # returns (out_ret, out_partials)
+    assert partials_block_size is not None  # required in practice despite the None default
+    cdiv = triton.cdiv
+    PARTIALS_BLOCK_M = partials_block_size
+    n_rows, n_cols = x.shape
+    n_rows_max = x.shape_max[0]  # presumably a static upper bound on the row count - confirm against x's class
+    assert out_ret.shape == (n_cols,)
+
+    TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)  # row groups per program
+    BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE  # rows per program
+
+    pids_x = cdiv(n_rows_max, BLOCK_MM)
+    pids_y = cdiv(n_cols, 32)  # the kernel handles 32 columns per program along axis 1
+    out_partials = torch.empty(
+        (pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32
+    )
+    out_partials = torch.transpose(out_partials, 0, 1)  # view of shape (pids_x * TILE_SIZE, pids_y * 32)
+
+    # launch: out_ret receives the accumulated totals, out_partials the per-group counts
+    _sum_bitmatrix_rows[(pids_x, pids_y)](
+        x.storage.data,
+        n_rows,
+        x.stride(0),
+        x.stride(1),  # input
+        out_ret,  # output [final reduction]
+        out_partials,
+        out_partials.stride(0),
+        out_partials.stride(1),
+        out_partials.shape[1],  # output [partial reductions]
+        BLOCK_M=PARTIALS_BLOCK_M,
+        BLOCK_MM=BLOCK_MM,  # constants
+        num_warps=8,
+    )
+
+    out_partials = out_partials[: cdiv(n_rows_max, PARTIALS_BLOCK_M), :]  # drop row groups beyond n_rows_max
+
+    return out_ret, out_partials
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd736f6f0867b95c67a3c857b4f0bcc80c79fc0
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py
@@ -0,0 +1,521 @@
+import torch
+import triton
+from dataclasses import dataclass, field
+from .routing_details._routing_compute import _combined_routing_compute
+from .routing_details._routing_compute import _combined_routing_memset
+from .routing_details._routing_compute import _routing_clear_bitmatrix
+from .routing_details._expt_data import _expt_data_memset
+from .routing_details._expt_data import _expt_data_compute
+from .target_info import is_hip
+
+
+@dataclass
+class GatherIndx:
+    """
+    Indices for an operation that performs:
+    Y = X[src_indx, :]
+    """
+
+    # array such that `dst_indx[src_indx] = arange(0, N)`
+    src_indx: torch.Tensor
+    dst_indx: torch.Tensor
+
+
+@dataclass
+class ScatterIndx:
+    """
+    Indices for an operation that performs:
+    Y[dst_indx, :] = X
+    """
+
+    # array such that `dst_indx[src_indx] = arange(0, N)`
+    src_indx: torch.Tensor
+    dst_indx: torch.Tensor
+
+
+@dataclass
+class ExptData:
+    # hist[i] is the number of tokens routed to expert i
+    hist: torch.Tensor
+    # token_offs_raw[i] is the offset of the first token routed
+    # to expert i in an expert-sorted array
+    token_offs_raw: torch.Tensor
+    # token_offs_pad[block][i] is the offset of the first token routed
+    # to expert i in an expert-sorted array, assuming histogram
+    # rounded to the next multiple of `block`
+    token_offs_pad: dict[int, torch.Tensor]
+    # block_pid_map[block] contains one value for each `pid` launched by
+    # the matrix multiplication kernel launched with BLOCK_M=block:
+    # - the value is -1 if the `pid` has no work to do
+    # - otherwise, the value is two int16 (packed as an int32) that
+    #   correspond respectively to (1) the expert assigned to
+    #   the tokens processed by this pid; (2) the block assigned to the
+    #   tokens processed by this pid (think `pid_m` in a regular matmul)
+    # see `test_routing.py` for a reference implementation and more details
+    block_pid_map: dict[int, torch.Tensor]
+
+    def __post_init__(self):  # every tensor that is provided must be int32; None fields are skipped
+        if self.hist is not None:
+            assert self.hist.dtype == torch.int32
+        if self.token_offs_raw is not None:
+            assert self.token_offs_raw.dtype == torch.int32
+        if self.token_offs_pad is not None:
+            for v in self.token_offs_pad.values():
+                assert v.dtype == torch.int32
+        if self.block_pid_map is not None:
+            for v in self.block_pid_map.values():
+                assert v.dtype == torch.int32
+
+
+@dataclass
+class RoutingData:
+    gate_scal: torch.Tensor = field()
+    expt_hist: torch.Tensor = field()
+    n_expts_tot: int = field()
+    n_expts_act: int = field()
+    expt_data: ExptData = None  # optional; defaults to None
+
+    # Used to make perf annotation cleaner: when we use expert sharding, we can
+    # use this to tell the "expected" number of local tokens per expert, because
+    # the actual number can vary per each input.
+    expected_tokens_per_expt: int = field(default=None)
+
+    def n_blocks(self, n_rows, block_m):  # block count for n_rows rows with tiles of block_m rows
+        if n_rows <= self.n_expts_tot:
+            return n_rows  # with no more rows than experts, one block per row is returned
+        else:
+            return (
+                triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m)
+                + self.n_expts_tot
+                - 1
+            )
+
+
+# --------------------------
+# sort tokens by expert
+# --------------------------
+
+
+class SortTokens(torch.autograd.Function):  # autograd wrapper around the two combined routing kernels
+    @staticmethod
+    def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
+        HIST_BLOCK_M = 32
+        INDX_OFFS_BLOCK_M = 512
+        MEMSET_BLOCK = 1024
+        cdiv = triton.cdiv
+
+        device = expt_scal.device
+        dtype = expt_scal.dtype
+        n_tokens_raw, _ = bitmatrix.shape
+        n_tokens_pad, n_expts_act = expt_scal.shape
+        n_gates_pad = n_tokens_pad * n_expts_act  # total number of (token, expert) slots in expt_scal
+
+        hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
+        hist = hist[:n_expts_tot]
+        assert hist.dtype == torch.int32
+        # scratchpad
+        expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
+        combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
+        # output (topk_indx and gate_indx are the two halves of combined_indx, so one memset covers both)
+        topk_indx = combined_indx[:n_gates_pad]
+        gate_indx = combined_indx[n_gates_pad:]
+        gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
+
+        (
+            token_offs_combined,
+            token_offs_raw,
+            token_offs_pad,
+            block_pid_map,
+            blocks1a,
+            blocks2a,
+            MEMSET_BLOCK_A,
+            HIST2_BLOCK_M,
+            block_m_log2_start,
+            block_m_num,
+        ) = _compute_expt_data_internal(hist, n_expts_tot, n_gates_pad)
+
+        blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
+        blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
+
+        _combined_routing_memset[(blocks1a + blocks1b,)](  # grid = expert-data programs plus routing programs
+            combined_indx,
+            n_gates_pad * 2,
+            -1,
+            MEMSET_BLOCK,
+            hist,  #
+            expt_offs,
+            hist.shape[0],
+            n_expts_tot,
+            partial_hist,  # inputs
+            partial_hist.shape[0],
+            partial_hist.stride(0),
+            partial_hist.stride(1),  # outputs
+            token_offs_combined,
+            token_offs_combined.stride(0),  #
+            blocks1a,
+            block_pid_map,  #
+            block_m_log2_start,
+            SIZES=block_m_num,
+            BLOCK_A=MEMSET_BLOCK_A,  # optimization parameters
+            BLOCK_N=512,
+            BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
+        )
+
+        indx_offs = partial_hist  # same tensor under a new name for the second launch
+
+        _combined_routing_compute[(blocks2a + blocks2b,)](
+            topk_indx,
+            gate_indx,
+            gate_scal,  # outputs
+            expt_scal,
+            expt_indx,
+            indx_offs,
+            indx_offs.stride(0),
+            indx_offs.stride(1),  # inputs
+            expt_offs,
+            n_tokens_raw,  # input shape
+            HIST_BLOCK_M,
+            n_expts_act,  # constants
+            hist,
+            token_offs_pad,
+            token_offs_pad.stride(0),
+            block_pid_map,
+            block_pid_map.stride(0),  # outputs
+            block_m_log2_start,
+            block_m_num,
+            HIST2_BLOCK_M,
+            blocks2a,  # etc.
+        )
+
+        ctx.n_tokens_raw = n_tokens_raw
+        ctx.n_tokens_pad = n_tokens_pad
+        ctx.n_expts_act = n_expts_act
+        ctx.save_for_backward(gate_indx)  # needed to route the gate_scal gradient back in backward
+        return (
+            hist,
+            topk_indx,
+            gate_indx,
+            gate_scal,
+            token_offs_raw,
+            token_offs_pad,
+            block_pid_map,
+        )
+
+    @staticmethod
+    def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):  # only the gate_scal output carries a gradient
+        (gate_indx,) = ctx.saved_tensors
+        dgate_scal = dgate_scal[gate_indx]  # gather the gradient through gate_indx
+        dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)  # back to the shape of expt_scal
+        return dgate_scal, None, None, None  # gradient for expt_scal only; the other three inputs get None
+
+
+def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):  # functional entry point for SortTokens
+    return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)  # returns the 7-tuple built by SortTokens.forward
+
+
+# --------------------------
+# prune routing
+# --------------------------
+
+
+class PruneRouting(torch.autograd.Function):  # forward-only: no backward is defined in this class
+    @staticmethod
+    def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+        from .compaction import compaction  # imported inside the function rather than at module level
+
+        n_tokens_pad = expt_scal.shape[0]
+        assert n_expts_tot % simulated_ep == 0  # experts must split evenly across the simulated shards
+        _routing_clear_bitmatrix[(n_tokens_pad,)](  # one program per (padded) token row
+            bitmatrix.storage.data,
+            bitmatrix.storage.data.stride(0),
+            bitmatrix.storage.data.stride(1),
+            bitmatrix.storage.data.shape[1],
+            n_expts_tot // simulated_ep,  # number of experts kept
+            BLOCK_N=512,
+        )
+        # perform compaction to update expt_scal / expt_indx
+        expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
+        n_expts_tot = n_expts_tot // simulated_ep
+        bitmatrix.shape[-1] = n_expts_tot  # mutates the caller's bitmatrix in place
+        return expt_scal, expt_indx, bitmatrix
+
+
+def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):  # functional entry point for PruneRouting
+    return PruneRouting.apply(
+        expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep  # forwarded unchanged, in the same order
+    )
+
+
+# --------------------------
+# expt_data
+# --------------------------
+
+
+def log2_power_of_two(x):  # exact base-2 logarithm of a positive power-of-two integer
+    assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"  # a power of two has exactly one bit set
+    return x.bit_length() - 1  # index of that single set bit
+
+
+block_m_log2_start = 4  # log2 of the smallest block_m considered (2**4 == 16)
+
+
+def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
+    MEMSET_BLOCK = 512
+    HIST2_BLOCK_M = 512
+    device = expt_hist.device
+    # Only allocates buffers and sizes launch grids; expt_hist contributes nothing but its device.
+    cdiv = triton.cdiv
+    # block_ms are the powers of two from 2**block_m_log2_start (16) up to 128, or up to 256 on HIP
+    block_m_log2_end = 9 if is_hip() else 8
+    block_m_num = block_m_log2_end - block_m_log2_start
+    if n_gates <= n_expts_tot:
+        max_n_tiles = n_gates
+    else:
+        max_n_tiles = (  # ceil_div(n_gates - n_expts_tot + 1, smallest block_m) + n_expts_tot - 1
+            n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
+        )
+    # allocate memory (rows are padded to a MEMSET_BLOCK multiple; the returned views are trimmed below)
+    pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
+    dtype = torch.int32
+
+    token_offs_combined = torch.empty(
+        (block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device
+    )
+
+    token_offs_raw = token_offs_combined[0][: n_expts_tot + 1]  # row 0: raw offsets
+    token_offs_pad = token_offs_combined[1:]  # rows 1..: one row per block_m
+
+    block_pid_map = torch.empty(
+        (block_m_num, pad(max_n_tiles)), dtype=dtype, device=device
+    )
+    memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK  # exact division
+    # compute outputs
+    token_offs_pad = token_offs_pad[:, : n_expts_tot + 1]
+    block_pid_map = block_pid_map[:, :max_n_tiles]
+
+    blocks1 = memset_grid + block_m_num + 1  # grid size for the memset stage
+    blocks2 = n_expts_tot * block_m_num  # grid size for the compute stage: one program per (expert, block_m)
+    return (
+        token_offs_combined,
+        token_offs_raw,
+        token_offs_pad,
+        block_pid_map,
+        blocks1,
+        blocks2,
+        MEMSET_BLOCK,
+        HIST2_BLOCK_M,
+        block_m_log2_start,
+        block_m_num,
+    )
+
+
+def _unpack_into_dict(x):  # split a row-per-block_m tensor into {block_m: row}
+    block_m_log2_end = block_m_log2_start + x.shape[0]  # one row per consecutive power of two
+    x = {
+        2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))  # key is block_m itself
+    }
+    return x
+
+
+def compute_expt_data(expt_hist, n_expts_tot, n_gates):  # build an ExptData from an expert histogram
+    if expt_hist is None:
+        return ExptData(None, None, None, None)  # no histogram: every field is left as None
+
+    # this just computes the kernel arguments:
+    (
+        token_offs_combined,
+        token_offs_raw,
+        token_offs_pad,
+        block_pid_map,
+        blocks1,
+        blocks2,
+        MEMSET_BLOCK,
+        HIST2_BLOCK_M,
+        block_m_log2_start,
+        block_m_num,
+    ) = _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates)
+
+    _expt_data_memset[(blocks1,)](  # first launch: writes token_offs_combined and fills block_pid_map
+        expt_hist,
+        n_expts_tot,  #
+        token_offs_combined,
+        token_offs_combined.stride(0),  #
+        block_pid_map,  #
+        block_m_log2_start,
+        SIZES=block_m_num,
+        BLOCK=MEMSET_BLOCK,  # optimization parameters
+        num_warps=4,
+    )
+    _expt_data_compute[(blocks2,)](  # second launch: writes the block_pid_map entries that have work
+        expt_hist,
+        token_offs_pad,
+        token_offs_pad.stride(0),
+        block_pid_map,
+        block_pid_map.stride(0),  # outputs
+        block_m_log2_start,
+        SIZES=block_m_num,
+        BLOCK=HIST2_BLOCK_M,  # optimization parameters
+        num_warps=4,
+    )
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
+    return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
+
+
+# --------------------------
+# routing
+# --------------------------
+
+
+def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):  # -> (RoutingData, GatherIndx, ScatterIndx)
+    (
+        hist,
+        topk_indx,
+        gate_indx,
+        gate_scal,
+        token_offs_raw,
+        token_offs_pad,
+        block_pid_map,
+    ) = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix)
+    token_offs_pad = _unpack_into_dict(token_offs_pad)  # tensor rows -> {block_m: row}
+    block_pid_map = _unpack_into_dict(block_pid_map)
+    expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
+    # pack the matmul data structure (gather and scatter use the same two index arrays with roles swapped)
+    gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
+    scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
+    return (
+        RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data),
+        gather_indx,
+        scatter_indx,
+    )
+
+
+def routing(  # top-k expert selection followed by token sorting; returns (RoutingData, GatherIndx, ScatterIndx)
+    logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None
+):
+    from .topk import topk  # imported inside the function rather than at module level
+
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)  # softmax over all experts before selection
+    expt_scal, expt_indx, bitmatrix = topk(
+        logits,
+        n_expts_act,  #
+        apply_softmax=not sm_first,  # otherwise topk is asked to apply the softmax itself
+        y_indx=expt_indx,
+        n_rows=n_rows,
+    )
+    n_expts_tot = logits.shape[-1] // simulated_ep  # expert count after (optional) simulated sharding
+    # mutate bitmatrix
+    if simulated_ep > 1:
+        expt_scal, expt_indx, bitmatrix = prune_routing(
+            expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep  # prune_routing gets the unsharded count
+        )
+
+    return routing_from_bitmatrix(
+        bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+    )
+
+
+# --------------------------
+# torch reference
+# --------------------------
+
+
+def compute_expt_data_torch(hist, n_expts_tot, n_gates):
+    # Torch reference that builds an ExptData from `hist`; token_offs_raw[i] = offset of expert i's first token.
+    device = hist.device
+    token_offs_raw = torch.cumsum(hist, dim=0)
+    token_offs_raw = torch.cat((token_offs_raw.new_zeros(1), token_offs_raw))  # integer zero: no float round-trip
+    token_offs_raw = token_offs_raw.int()
+    # maximum number of tiles for all values of `block_m` considered
+    block_ms = [16, 32, 64, 128]
+    if is_hip():
+        block_ms.append(256)
+    if n_gates <= n_expts_tot:
+        max_n_tiles = n_gates
+    else:
+        # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
+        # ceil_div(x, y): -(-x // y)
+        max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
+    # fill up tile offset/infos for each block
+    token_offs_pad = dict()
+    block_pid_map = dict()
+    for block_m in block_ms:
+        n_tiles = (hist + block_m - 1) // block_m  # matmul blocks needed
+        token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
+        token_offs_pad[block_m] = torch.cat(
+            (token_offs_pad[block_m].new_zeros(1), token_offs_pad[block_m])
+        )
+        token_offs_pad[block_m] = token_offs_pad[block_m].int()
+        # compute data required to drive ragged batch matmul (-1 marks a pid with no work)
+        block_pid_map[block_m] = -torch.ones(
+            max_n_tiles, dtype=torch.int32, device=device
+        )
+
+        # loop form of the vectorised scatter below, kept for reference:
+        # for e in range(n_expts_tot):
+        #     offset = token_offs_pad[block_m][e]
+        #     for b in range(n_tiles[e]): block_pid_map[block_m][offset + b] = (b << 16) + e
+
+        col = torch.arange(max_n_tiles, device=device)  # candidate block index b within an expert
+        map_vals = (
+            torch.arange(n_expts_tot, device=device)[:, None] + (col << 16)[None, :]
+        )
+        map_idxs = token_offs_pad[block_m][:-1, None] + col[None, :]
+        mask = col[None, :] < n_tiles[:, None]  # keep only blocks that exist for each expert
+        block_pid_map[block_m].index_put_((map_idxs[mask],), map_vals.int()[mask])
+    return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
+
+def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):  # returns (values, int32 indices)
+    # topk of experts
+    if has_user_provided_indx:
+        tk_indx = expt_indx  # caller-supplied selection is used as-is
+    else:
+        tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]  # k largest per row, stable on ties
+    tk_indx = tk_indx.long()  # take_along_dim needs int64 indices
+    tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
+    tk_indx = tk_indx.int()
+    return tk_val, tk_indx
+
+
+def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):  # torch-only routing reference
+    has_user_provided_indx = expt_indx is not None
+    n_gates_pad = logits.shape[0] * n_expts_act  # computed before the n_rows slice, so it counts padded rows
+
+    if n_rows is not None:
+        logits = logits[:n_rows, :]
+    _, n_expts_tot = logits.shape
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)
+    expt_scal, expt_indx = topk_torch(
+        logits, n_expts_act, expt_indx, has_user_provided_indx=has_user_provided_indx
+    )
+    if not sm_first:
+        expt_scal = torch.softmax(expt_scal, dim=-1)  # softmax over the selected experts only
+    # sort each token's selections by expert
+    if not has_user_provided_indx:
+        expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
+        expt_scal = torch.gather(expt_scal, 1, sort_indices)
+    # flatten topk data
+    expt_scal = expt_scal.reshape(-1)
+    expt_indx = expt_indx.reshape(-1).to(torch.int32)
+    # sort by expert_id so experts are contiguous for the matmul
+    topk_indx = torch.argsort(expt_indx, stable=True)
+    gate_indx = torch.argsort(topk_indx, stable=True)  # inverse permutation of topk_indx
+    gate_scal = expt_scal[topk_indx]
+    hist = torch.histc(
+        expt_indx, bins=n_expts_tot, max=n_expts_tot - 1
+    ).int()  # histogram of tokens over experts
+    # pack the matmul data structure
+    gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
+    scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
+    # compute expt_data
+    expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
+    return (
+        RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data),
+        gather_indx,
+        scatter_indx,
+    )
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd625868fb668d1a317e193ec4d5ec24a4da6206
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py
@@ -0,0 +1,75 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _cdiv_pow2(n, log2_k):  # ceil(n / 2**log2_k) using only shifts and adds
+    return (n + ((1 << log2_k) - 1)) >> log2_k  # add (divisor - 1), then shift right to divide
+
+
+@triton.jit
+def _expt_data_memset(
+    Hist,
+    n_expts_tot,
+    MDStarts,
+    tile_starts_stridem,
+    MDTileInfo,
+    first_tile_dim_log2,
+    SIZES: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    if pid <= SIZES:  # programs 0..SIZES each write one row of tile start offsets
+        MDStarts += pid * tile_starts_stridem
+        x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)  # running total carried across chunks
+        Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
+        tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)  # row 0 uses a tile size of 1
+
+        for i in range(0, n_expts_tot + 1, BLOCK):  # covers n_expts_tot + 1 entries
+            offs_n = tl.arange(0, BLOCK) + i
+            mask_n0 = offs_n < n_expts_tot
+            hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
+            hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)  # tiles needed per expert
+
+            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
+            x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
+            tl.store(Tile_ptrs, tile_starts - hist_tile)  # exclusive prefix sum; unmasked, assumes a BLOCK-padded row - confirm
+            Tile_ptrs += BLOCK
+
+    else:  # remaining programs fill MDTileInfo with the all-ones pattern, BLOCK elements each
+        pid -= SIZES + 1
+        TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
+        tl.store(TileInfoOut, 0xFFFFFFFF)
+
+
+@triton.jit
+def _expt_data_compute(
+    Hist,
+    MDTileStarts,
+    tile_starts_stridem,
+    MDTileInfo,
+    tile_info_stridem,
+    first_tile_dim_log2,
+    SIZES: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    expt_id = pid // SIZES  # one program per (expert, tile size) pair
+    buff_id = pid % SIZES  # which tile size, i.e. which row of the two buffers
+
+    MDTileStarts += buff_id * tile_starts_stridem
+    MDTileInfo += buff_id * tile_info_stridem
+
+    n_tokens = tl.load(Hist + expt_id)
+    tile_dim_log2 = first_tile_dim_log2 + buff_id
+    n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)  # tiles this expert needs at this tile size
+
+    tile_off = tl.load(MDTileStarts + expt_id)  # index of this expert's first tile
+    MDTileInfo += tile_off
+
+    for block_off in range(0, n_blocks, BLOCK):
+        block_offs = block_off + tl.arange(0, BLOCK)
+        data = (block_offs << 16) + expt_id  # block index in the upper 16 bits, expert id below
+        tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b097cc1cc8c1117363f031cfc9a785b94a7d5ed
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py
@@ -0,0 +1,241 @@
+import triton
+import triton.language as tl
+
+from ._expt_data import _expt_data_compute, _expt_data_memset
+
+
+# Exclusive prefix sum of ExpertHist, written to FinalExpertOffs. The histogram
+# is walked in chunks of BLOCK_N entries with a running carry `x`.
+@triton.jit
+def _routing_compute_expt_offs(
+ ExpertHist,
+ FinalExpertOffs,
+ hist_size, # histogram
+ BLOCK_N: tl.constexpr,
+):
+ loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
+ # carry: sum of all histogram entries in the chunks already processed
+ x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
+ for i in range(loop_iterations):
+ offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < hist_size
+ # note: no `other=` is given, so lanes past hist_size hold unspecified
+ # values; they sit after every valid lane and are masked out of the store
+ hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
+ # inclusive cumsum minus the element itself gives the exclusive prefix
+ tok_starts = tl.cumsum(hist2, 0) - hist2 + x
+ x += tl.sum(hist2, 0)
+ tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
+ # this increment is never read: offs_n is recomputed from `i` above
+ offs_n += BLOCK_N
+
+
+# In-place exclusive prefix sum of PartialHist along the `stride_pm` axis for the
+# single `expt_id` slice (offset expt_id * stride_pn), BLOCK_M entries at a time.
+@triton.jit
+def _routing_compute_indx_offs(
+ PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id
+):
+ offs_m = tl.arange(0, BLOCK_M)
+ # iterate over input data
+ # curr_sum carries the total of all chunks already processed
+ curr_sum = 0
+ for _ in range(0, shape_pm, BLOCK_M):
+ offs = offs_m * stride_pm + expt_id * stride_pn
+ curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
+ out = tl.cumsum(curr, 0) + curr_sum
+ curr_sum += tl.sum(curr, 0)
+ # `out - curr` turns the inclusive cumsum into an exclusive one
+ tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
+ offs_m += BLOCK_M
+
+
+# Combine function for values that carry a key in their upper 16 bits and a
+# payload in their lower 16 bits. When both keys match, the payloads are added
+# and the key is kept once; otherwise the result is just `y`.
+@triton.jit
+def _keyed_add(x, y):
+ # we keep the key in the upper 16 bits of a uint32:
+ key_mask: tl.constexpr = 0xFFFF0000
+
+ kx = x & key_mask
+ ky = y & key_mask
+ # x + y contains the shared key twice, so one copy (kx) is subtracted
+ z = tl.where(kx == ky, x + y - kx, y)
+ return z
+
+
+# For one block of BLOCK_M * N_EXPTS_ACT entries of ExptIndx (selected by pid_m),
+# sorts the entries by expert id and writes, for each valid entry, its
+# destination slot to ScatterIndx, the inverse mapping to GatherIndx, and the
+# matching ExptScal value to GateScal at the destination slot.
+@triton.jit
+def _routing_compute_indx(
+ pid_m,
+ GatherIndx,
+ ScatterIndx,
+ GateScal,
+ ExptScal,
+ ExptIndx,
+ PartialOffs,
+ stride_pm,
+ stride_pn,
+ TokensStart,
+ n_tokens,
+ BLOCK_M: tl.constexpr,
+ N_EXPTS_ACT: tl.constexpr,
+):
+ # n_tokens may be passed either as a value or as a pointer to the value
+ if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
+ n_tokens = tl.load(n_tokens)
+ n_gates = n_tokens * N_EXPTS_ACT
+
+ # local offsets are packed into the low 16 bits below, so they must fit
+ tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
+
+ local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
+ offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
+ # out-of-range lanes load -1, which becomes expert id 0xFFFF after the
+ # uint32 cast / shift round trip below and is filtered out through `mask`
+ expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
+
+ # stable-sort by expert ID:
+ # (expert id in the high 16 bits, original local offset in the low 16 bits)
+ kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
+ kv_pairs = tl.sort(kv_pairs, 0)
+ expert = kv_pairs >> 16
+ offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xFFFF)
+ mask = expert != 0xFFFF
+ gate_scal = tl.load(ExptScal + offs, mask=mask)
+
+ # compute run lengths in expert-sorted order:
+ # every element contributes (key, 1); the keyed scan counts within each key
+ x = kv_pairs & 0xFFFF0000 | 0x00000001
+ expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
+ exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF
+
+ # destination slot = PartialOffs entry for (pid_m, expert)
+ # + TokensStart[expert] + position inside this block's run
+ gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
+ gates += tl.load(TokensStart + expert, mask=mask)
+ gates += exclusive_run_lengths
+
+ tl.store(ScatterIndx + offs, gates, mask=mask)
+ tl.store(GatherIndx + gates, offs, mask=mask)
+ tl.store(GateScal + gates, gate_scal, mask=mask)
+
+
+# Trampoline kernel: programs with pid < blocks2a run _expt_data_compute, all
+# remaining programs run _routing_compute_indx with their pid rebased to zero.
+@triton.jit
+def _combined_routing_compute(
+ GatherIndx,
+ ScatterIndx,
+ GateScal,
+ ExptScal,
+ ExptIndx,
+ PartialOffs,
+ stride_pm,
+ stride_pn,
+ TokensStart,
+ n_tokens,
+ BLOCK_M: tl.constexpr,
+ N_EXPTS_ACT: tl.constexpr,
+ Hist,
+ MDTileStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ tile_info_stridem,
+ first_tile_dim_log2,
+ SIZES: tl.constexpr,
+ BLOCK: tl.constexpr,
+ blocks2a,
+):
+ pid = tl.program_id(0)
+ if pid < blocks2a:
+ # _expt_data_compute reads tl.program_id(0) itself, so no pid is passed
+ _expt_data_compute(
+ Hist,
+ MDTileStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ tile_info_stridem,
+ first_tile_dim_log2,
+ SIZES,
+ BLOCK,
+ )
+ else:
+ # rebase so the index computation sees block ids starting at 0
+ pid -= blocks2a
+ _routing_compute_indx(
+ pid,
+ GatherIndx,
+ ScatterIndx,
+ GateScal,
+ ExptScal,
+ ExptIndx,
+ PartialOffs,
+ stride_pm,
+ stride_pn,
+ TokensStart,
+ n_tokens,
+ BLOCK_M,
+ N_EXPTS_ACT,
+ )
+
+
+# Clears every bit at position >= `cutoff` in the slice of Bitmatrix selected by
+# tl.program_id(0). The slice is treated as a sequence of 32-bit words: words
+# after the cutoff word are zeroed and the cutoff word keeps only its low bits.
+@triton.jit
+def _routing_clear_bitmatrix(
+ Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr
+):
+ pid_m = tl.program_id(0)
+ # split the bit position into (word index, bit index inside that word)
+ cutoff_word = cutoff // 32
+ cutoff_bit = cutoff % 32
+ # mask with the `cutoff_bit` lowest bits set
+ cutoff_mask = (1 << (cutoff_bit)) - 1
+ for start_n in range(0, shape_bn, BLOCK_N):
+ offs_n = start_n + tl.arange(0, BLOCK_N)
+ values = tl.load(
+ Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn
+ )
+ # words before cutoff_word are written back unchanged
+ values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
+ values = tl.where(offs_n > cutoff_word, 0, values)
+ tl.store(
+ Bitmatrix + pid_m * stride_bm + offs_n * stride_bn,
+ values,
+ mask=offs_n < shape_bn,
+ )
+
+
+@triton.jit
+def _combined_routing_memset(
+ Indx,
+ size,
+ sentinel,
+ BLOCK: tl.constexpr,
+ ExpertHist,
+ FinalExpertOffs,
+ hist_size,
+ n_expts_tot,
+ PartialHist,
+ shape_pm,
+ stride_pm,
+ stride_pn,
+ MDStarts,
+ tile_starts_stridem,
+ blocks1a,
+ MDTileInfo,
+ first_tile_dim_log2,
+ SIZES: tl.constexpr,
+ BLOCK_A: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ """
+ This kernel essentially combines 6 different pieces of functionality,
+ statically branching on the value of tl.program_id(0) to decide which
+ codepath to take.
+
+ pid == 0: create the token cumsum
+ 1 <= pid <= SIZES: create a tile cumsum
+ SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
+ blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
+ pid == blocks1a + n_expts_tot: compute_expt_offs
+ pid > blocks1a + n_expts_tot: initialise Indx to sentinel
+
+ The first three cases are all handled inside _expt_data_memset, which
+ reads tl.program_id(0) itself.
+
+ As each of these is a relatively trivial workload, launching them from
+ this single trampoline is beneficial as they can execute on different
+ streaming multiprocessors in parallel.
+ """
+
+ pid = tl.program_id(0)
+
+ if pid < blocks1a:
+ _expt_data_memset(
+ ExpertHist,
+ n_expts_tot,
+ MDStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ first_tile_dim_log2,
+ SIZES,
+ BLOCK_A,
+ )
+ elif pid == n_expts_tot + blocks1a:
+ _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+ elif pid < n_expts_tot + blocks1a:
+ # pid - blocks1a is the expert id whose PartialHist slice is scanned
+ _routing_compute_indx_offs(
+ PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a
+ )
+ else:
+ # remaining programs each fill one BLOCK-sized chunk of Indx
+ offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
+ mask = offs < size
+ tl.store(Indx + offs, sentinel, mask=mask)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf44d70cb47664e6a837ec4cf0d28f04fbb1c16
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py
@@ -0,0 +1,143 @@
+import inspect
+import re
+import textwrap
+import types
+import triton
+
+
+def cacheable(f):
+ """
+ A decorator that allow you to write something of the form:
+
+ @cacheable
+ def my_kernel(): return (expression dynamically defining a kernel)
+
+ such that it interacts gracefully with triton cache and preload.
+ """
+
+ g = f()
+ g.fn.__name__ = f.__name__
+ g.fn.__module__ = f.__module__
+ g.fn.__qualname__ = f.__qualname__
+ g.__name__ = f.__name__
+ g.__module__ = f.__module__
+ g.__qualname__ = f.__qualname__
+ g._fn_name = f"{f.__module__}.{f.__qualname__}"
+ return g
+
+
+def define_kernel(src, module, attrs=None, **extra_globals):
+ """
+ Dynamically create a Triton function or kernel from a src string,
+ linking any symbols in the kernel to objects specified by extra_globals.
+ """
+
+ # create templace function
+ def _empty_fn():
+ pass
+
+ gdict = dict(**(_empty_fn.__globals__))
+ gdict.update(extra_globals)
+ f = types.FunctionType(_empty_fn.__code__, gdict)
+ f.__module__ = module.__name__
+
+ src = textwrap.dedent(src)
+ src = src[src.find("def ") :]
+
+ stored_functions = []
+ function_name = src[4:].split("(")[0].strip()
+
+ exec_globals = gdict
+ exec_globals.update({"stored_functions": stored_functions})
+ exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)
+
+ f.__signature__ = inspect.signature(stored_functions[0])
+ f.__name__ = function_name
+ f.__doc__ = stored_functions[0].__doc__
+
+ if attrs is None:
+ attrs = dict()
+ f = triton.JITFunction(f, **attrs)
+ f._unsafe_update_src(src)
+ return f
+
+
+def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
+ assert isinstance(fn, triton.runtime.jit.JITFunction)
+ if name is None:
+ name = f"{fn.__name__}"
+ # Get original source code
+ src = inspect.getsource(fn.fn)
+ src = textwrap.dedent(src)
+ lines = src.split("\n")
+ # Skip decorator and def line
+ def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
+ # separate header vs body LOC
+ header_end = def_idx
+ while not lines[header_end].rstrip().endswith(":"):
+ header_end += 1
+ body_lines = lines[header_end + 1 :]
+ header_lines = lines[def_idx : header_end + 1]
+ # clean-up header
+ header_clean = [
+ l.split("#", 1)[0].strip() # keep code, discard comment
+ for l in header_lines
+ if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines
+ ]
+ # decompose arguments
+ header_src = " ".join(header_clean) # turn it into a single line
+ m = re.search(r"\((.*)\)\s*:", header_src)
+ if not m:
+ raise ValueError("Could not parse function header")
+ args_str = m.group(1)
+ args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
+ non_specialized_args = []
+ for arg in args:
+ arg_key = arg.split(":")[0].split("=")[0].strip()
+ new_args = tuples.get(arg_key, [arg])
+ if arg_key not in constants:
+ non_specialized_args += new_args
+ # add global symbols
+ spec_fns = {
+ v.__name__: v
+ for k, v in constants.items()
+ if isinstance(v, triton.runtime.jit.JITFunction)
+ }
+ globals = spec_fns | fn.get_capture_scope()
+ # build new source code and define kernel dynamically
+ new_signature = f"def {name}({', '.join(non_specialized_args)}):"
+ constexpr_lines = [
+ f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}"
+ for key, value in constants.items()
+ ]
+ tuple_lines = [
+ f" {key} = {'(' + ','.join(value) + (',' if len(value) >= 1 else '') + ')'}"
+ for key, value in tuples.items()
+ ]
+ new_src = "\n".join(
+ ["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines
+ )
+ # find function parameters
+ sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
+ params = list(sig.parameters.values())[2:]
+ attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
+
+ # make a new repr which appends the repr of the specialized functions.
+ base_repr = attrs["repr"]
+
+ def new_repr(specialization):
+ ret = base_repr(specialization)
+ for spec_fn in spec_fns.values():
+ spec_repr = spec_fn.repr(None)
+ if spec_repr:
+ spec_repr = spec_repr.strip("_")
+ if spec_repr:
+ ret += f"_{spec_repr}"
+ return ret
+
+ attrs["repr"] = new_repr
+
+ if do_not_specialize:
+ attrs["do_not_specialize"] = do_not_specialize
+ ret = define_kernel(new_src, module, attrs, **globals)
+ return ret
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3619a4be74113d34c2bf0138a0bea5eb2d29788
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+from compactor_vllm.triton_kernels.numerics import InFlexData, OutFlexData
+import torch
+import triton
+from .swiglu_details._swiglu import _swiglu, _swiglu_fn
+from compactor_vllm.triton_kernels import target_info
+
+
+# Immutable bundle of the flexpoint settings used by the swiglu launcher below.
+@dataclass(frozen=True)
+class FlexCtx:
+ # output-side flex data: provides reinterpret() and the expected / actual /
+ # checksum scales passed to the kernel
+ out_data: OutFlexData = OutFlexData()
+ # input-side flex data: provides reinterpret() and the input scale
+ inp_data: InFlexData = InFlexData()
+ # forwarded to the kernel as `flexpoint_saturate_inf`
+ saturate_inf: bool = False
+
+
+# Immutable precision settings for swiglu.
+@dataclass(frozen=True)
+class PrecisionConfig:
+ # clamp bound applied to the activations; the torch reference in this file
+ # also accepts None, meaning "no clamping"
+ limit: float
+ # flexpoint settings, see FlexCtx
+ flex_ctx: FlexCtx = FlexCtx()
+
+
+# Module-level alias exposing the jitted `_swiglu_fn` helper without the underscore.
+swiglu_fn = _swiglu_fn
+
+
+# autograd.Function wrapper that launches the `_swiglu` Triton kernel.
+# Only `forward` is defined in this class.
+class SwiGLU(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, a, alpha, precision_config, routing_data):
+ # `a` is viewed as an (M, N) matrix: N is the last dimension and M the
+ # product of all leading ones. The output has half as many columns.
+ N = a.shape[-1]
+ M = a.numel() // N
+ assert a.stride()[-1] == 1
+ assert a.shape[-1] % 2 == 0
+ out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
+ flex_ctx = precision_config.flex_ctx
+ # optimization hyperparameters
+ BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
+ num_warps = 4
+ # `maxnreg` is only passed when the target is not HIP
+ kwargs = {"maxnreg": 64} if not target_info.is_hip() else {}
+ # launch semi-persistent kernel
+ N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
+ num_sms = target_info.num_sms()
+ if routing_data is not None:
+ # with routing data the grid is sized from the SM count rather than
+ # from M (the kernel re-reads the row count from NTokens)
+ waves_per_sm = 32 if target_info.is_hip() else 128
+ num_pid = num_sms * (waves_per_sm // num_warps)
+ M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
+ grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
+ else:
+ M_BLOCKS = triton.cdiv(M, BLOCK_M)
+ if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
+ grid = (8 * num_sms,)
+ else:
+ grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
+ n_tokens = None
+ if routing_data is not None:
+ # last entry of the raw token offsets, indexed by the expert count
+ n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
+ _swiglu[grid](
+ flex_ctx.out_data.reinterpret(out),
+ flex_ctx.out_data.expected_scale,
+ flex_ctx.out_data.actual_scale,
+ flex_ctx.out_data.checksum_scale,
+ flex_ctx.inp_data.reinterpret(a),
+ flex_ctx.inp_data.scale,
+ alpha,
+ M,
+ N // 2,
+ a.shape[-1],
+ 1,
+ out.shape[-1],
+ 1,
+ precision_config.limit,
+ n_tokens,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ EVEN_N=(N // 2) % BLOCK_N == 0,
+ M_BLOCKS=M_BLOCKS,
+ N_BLOCKS=N_BLOCKS,
+ flexpoint_saturate_inf=flex_ctx.saturate_inf,
+ num_warps=num_warps,
+ **kwargs,
+ )
+ # restore the leading dimensions of `a` on the (M, N // 2) result
+ out = out.view(a.shape[:-1] + out.shape[-1:])
+ return out
+
+
+def swiglu(a, alpha, precision_config, routing_data=None):
+ return SwiGLU.apply(a, alpha, precision_config, routing_data)
+
+
+def swiglu_torch(a, alpha, precision_config):
+ limit = precision_config.limit
+ a_gelu = a[..., ::2]
+ if limit is not None:
+ a_gelu = a_gelu.clamp(max=limit)
+ a_linear = a[..., 1::2]
+ if limit is not None:
+ a_linear = a_linear.clamp(min=-limit, max=limit)
+
+ out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
+ out = out_gelu * (a_linear + 1)
+ return out
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb7644271a5360cbd89b4041a179361fd4197b80
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
@@ -0,0 +1,141 @@
+from compactor_vllm.triton_kernels.numerics_details.flexpoint import (
+ load_scale,
+ float_to_flex,
+ update_scale,
+)
+import triton
+import triton.language as tl
+
+
+# Clamps `x` to at most `limit`; when `clip_lower` is set, also to at least `-limit`.
+@triton.jit
+def clip(x, limit, clip_lower: tl.constexpr):
+ res = tl.minimum(x, limit)
+ if clip_lower:
+ res = tl.maximum(-limit, res)
+ return res
+
+
+# Reduces |x| to NUM_THREADS partial maxima: the BLOCK_SIZE elements are reshaped
+# to [NUM_THREADS, BLOCK_SIZE // NUM_THREADS] and the max is taken along axis 1.
+# `can_reorder=True` lets the reshape pick any element order, so which elements
+# end up in which partial maximum is unspecified.
+@triton.jit
+def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
+ return tl.max(
+ tl.reshape(
+ tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True
+ ),
+ axis=1,
+ )
+
+
+def swiglu_repr(specialization):
+ signature = specialization.signature
+ constants = specialization.constants
+ convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
+ dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]])
+ blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]])
+ return f"_swiglu_{dtypes}_{blocks}"
+
+
+def swiglu_launch_metadata(grid, kernel, args):
+ M, N = args["M"], args["N"]
+ ret = dict()
+ ret["name"] = f"{kernel.name} [M = {M}, N = {N}]"
+ A, Out = args["A"], args["Out"]
+ ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size()
+ return ret
+
+
+# Computes s * (linear + 1) with s = gelu * sigmoid(alpha * gelu), in float32.
+# Both inputs are first multiplied by `scale`; when `limit` is not None, `gelu`
+# is clamped from above only and `linear` is clamped to [-limit, limit].
+@triton.jit
+def compute_swiglu(gelu, linear, scale, alpha, limit):
+ gelu = gelu.to(tl.float32) * scale
+ if limit is not None:
+ gelu = clip(gelu, limit, clip_lower=False)
+ linear = linear.to(tl.float32) * scale
+ if limit is not None:
+ linear = clip(linear, limit, clip_lower=True)
+ # gelu / (1 + exp(-alpha * gelu)) == gelu * sigmoid(alpha * gelu)
+ s = gelu / (1 + tl.exp(-alpha * gelu))
+ return tl.fma(s, linear, s) # (s * (linear + 1))
+
+
+# Swiglu on a 2D block whose second axis interleaves (gelu, linear) pairs: the
+# block is reshaped to (rows, cols // 2, 2), split along the last axis, and
+# passed to compute_swiglu with a scale of 1.0.
+@triton.jit(repr=lambda _: "_swiglu")
+def _swiglu_fn(input, alpha, limit):
+ gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
+ return compute_swiglu(gelu, linear, 1.0, alpha, limit)
+
+
+# Swiglu kernel. Each program loops over (pid_m, pid_n) tiles, starting at its
+# own program id and stepping by the number of launched programs. Per tile it
+# loads a BLOCK_M x (2 * BLOCK_N) slab of A, splits it into interleaved
+# (gelu, linear) pairs, applies compute_swiglu and stores BLOCK_M x BLOCK_N
+# results to Out. Here N is the number of OUTPUT columns.
+@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
+def _swiglu(
+ Out,
+ OutExpectedScale,
+ OutActualScale,
+ OutChecksumScale,
+ A,
+ AScale,
+ alpha,
+ M,
+ N,
+ stride_am,
+ stride_an,
+ stride_outm,
+ stride_outn,
+ limit: tl.constexpr,
+ NTokens,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ EVEN_N: tl.constexpr,
+ M_BLOCKS,
+ N_BLOCKS,
+ flexpoint_saturate_inf: tl.constexpr,
+):
+ # when a token-count pointer is given, it overrides M and M_BLOCKS
+ if NTokens is not None:
+ M = tl.load(NTokens)
+ M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
+
+ # running absolute maxima of the outputs, one entry per thread
+ local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)
+
+ a_scale = load_scale(AScale)
+ out_expected_scale = load_scale(OutExpectedScale)
+
+ for pid in tl.range(
+ tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2
+ ):
+ pid_m = pid // N_BLOCKS
+ pid_n = pid % N_BLOCKS
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_m = off_m < M
+ mask_n = off_n < N
+ # packed_off_n is computed twice: here each output column index appears
+ # twice in a row, which is only used to build the load mask ...
+ packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
+ packed_mask_n = packed_off_n < N
+ packed_mask_n = tl.max_constancy(packed_mask_n, [16])
+ # load a
+ # ... and here it is reassigned to the real input column offsets
+ packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
+ packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
+ if EVEN_N:
+ a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
+ else:
+ # only a tile that crosses the right edge needs the column mask
+ if pid_n * BLOCK_N + BLOCK_N <= N:
+ a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
+ else:
+ packed_mask = mask_m[:, None] & packed_mask_n[None, :]
+ a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.0)
+ # de-interleave: even input columns -> a_gelu, odd input columns -> a_linear
+ a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
+ out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
+ # update flexpoint stats and divide by scale
+ # we don't need masking because of the `other` when loading `A`
+ if OutActualScale is not None:
+ absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
+ local_max = tl.maximum(local_max, absmax)
+ out = float_to_flex(
+ out,
+ out_expected_scale,
+ None, # ActualScale: local absmax is tracked and updated after the loop
+ OutChecksumScale,
+ None,
+ Out,
+ flexpoint_saturate_inf,
+ )
+ mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
+ tl.store(
+ Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask
+ )
+
+ # publish the accumulated absolute maximum through update_scale
+ update_scale(local_max, OutActualScale, Out)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ae4303c512241455cc8aed5a85a2edb1c1c8eb
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py
@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+from triton.language.target_info import (
+ cuda_capability_geq,
+ is_cuda,
+ is_hip,
+ is_hip_cdna3,
+ is_hip_cdna4,
+)
+
+__all__ = [
+ "cuda_capability_geq",
+ "get_cdna_version",
+ "has_tma_gather",
+ "has_native_mxfp",
+ "is_cuda",
+ "is_hip",
+ "is_hip_cdna3",
+ "is_hip_cdna4",
+ "num_sms",
+]
+
+
+@triton.constexpr_function
+def get_cdna_version():
+ """
+ Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
+ only supports 3 (gfx942) or 4 (gfx950). Returns -1 if it is not AMD
+ hardware or unsupported architecture
+ """
+ target = tl.target_info.current_target()
+ if target.backend != "hip":
+ return -1
+ if target.arch == "gfx942":
+ return 3
+ if target.arch == "gfx950":
+ return 4
+ return -1
+
+
+@triton.constexpr_function
+def has_tma_gather():
+ return cuda_capability_geq(10, 0)
+
+
+@triton.constexpr_function
+def has_native_mxfp():
+ return cuda_capability_geq(10, 0)
+
+
+def num_sms():
+ return torch.cuda.get_device_properties(0).multi_processor_count
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6992e942365b2cf52701be8d013f174dd4458784
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
@@ -0,0 +1,227 @@
+from dataclasses import dataclass, fields
+from typing import Type
+
+import torch
+from triton.tools.tensor_descriptor import TensorDescriptor
+from triton.tools.ragged_tma import create_ragged_descriptor
+
+from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
+from .target_info import cuda_capability_geq
+from .tensor_details.layout import Layout, StridedLayout
+
+
+@dataclass
+class Storage:
+ data: torch.Tensor
+ layout: Layout = None
+
+ def __post_init__(self):
+ assert isinstance(self.data, torch.Tensor)
+ if self.layout is None:
+ self.layout = StridedLayout(self.data.shape)
+
+ @property
+ def device(self):
+ return self.data.device
+
+ def is_tma_compliant(self):
+ # TMAs didn't exist until Hopper
+ if not cuda_capability_geq(9, 0):
+ return False
+ # TMAs only exist for 2D, 3D, 5D inputs
+ if len(self.data.shape) not in [2, 3, 5]:
+ return False
+ # TMAs need at most one stride equal to 1
+ # and all other strides divisble by 16
+ strides = list(self.data.stride())
+ try:
+ major_dim = strides.index(1)
+ except ValueError:
+ major_dim = -1
+ ndim = self.data.ndim
+ bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
+ compliant = [
+ strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim
+ ]
+ return all(compliant)
+
+ def make_dense_tma(self, block_shape, transpose=False):
+ strides = list(self.data.stride())
+ shape = list(self.data.shape)
+ transpose = self.data.stride()[-1] != 1
+ if transpose:
+ block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
+ shape = shape[:-2] + [shape[-1], shape[-2]]
+ strides = strides[:-2] + [strides[-1], strides[-2]]
+ if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
+ indx = strides.index(1)
+ block_shape[indx] = block_shape[indx] // 2
+ if shape[-1] % 128 != 0:
+ raise ValueError(
+ "inner shape need to be multiple of 128 for "
+ "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
+ )
+ block_shape = self.layout.swizzle_block_shape(block_shape)
+ return TensorDescriptor(self.data, shape, strides, block_shape)
+
+ def make_tma(self, block_shape, mode, transpose=False):
+ if mode in ["dense", "gather", "scatter"]:
+ return self.make_dense_tma(block_shape, transpose)
+ assert mode == "ragged"
+ ragged_dim = len(self.data.shape) - 2
+ return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
+
+
+# Integer element type described only by its width in bits.
+@dataclass
+class IntegerType:
+ bitwidth: int
+
+
+# Floating-point element type described by exponent width, mantissa width and
+# signedness. `bitwidth` is derived from these after construction.
+@dataclass
+class FloatType:
+ bitwidth_exponent: int
+ bitwidth_mantissa: int
+ is_signed: bool
+
+ def __post_init__(self):
+ # total width = sign bit (if any) + exponent bits + mantissa bits
+ self.bitwidth = (
+ int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa
+ )
+
+
+# Predefined sub-byte element types: a single bit, and a 4-bit float
+# (1 sign bit + 2 exponent bits + 1 mantissa bit).
+BIT = IntegerType(1)
+FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
+
+
+def bitwidth(type: IntegerType | FloatType | torch.dtype):
+ if isinstance(type, torch.dtype):
+ return type.itemsize * 8
+ return type.bitwidth
+
+
+# Wrapper around a Storage that carries a logical dtype and shape, which may
+# differ from those of the underlying torch tensor (e.g. for sub-byte types).
+@dataclass
+class Tensor:
+ # a bare torch.Tensor is wrapped into a Storage in __post_init__
+ storage: Storage | torch.Tensor
+ # defaults to the dtype of the stored torch tensor
+ dtype: IntegerType | FloatType | torch.dtype = None
+ # logical shape; entries are ints or single-element tensors
+ shape: list[int] | None = None
+ # per-dimension integer bound; entries left as None are filled from `shape`
+ shape_max: list[int] | None = None
+
+ def __post_init__(self):
+ # set storage
+ if isinstance(self.storage, torch.Tensor):
+ self.storage = Storage(self.storage)
+ # initialize dtype
+ if self.dtype is None:
+ self.dtype = self.storage.data.dtype
+ # the storage shape cannot describe elements narrower than one byte
+ if bitwidth(self.dtype) < 8 and self.shape is None:
+ raise ValueError("shape must be provided for sub-byte types")
+ # initialize shape
+ if self.shape is None:
+ self.shape = list(self.storage.data.shape)
+ # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
+ is_int = lambda s: isinstance(s, int)
+ is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
+ assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
+ # initialize shape_max
+ if self.shape_max is None:
+ self.shape_max = [None] * len(self.shape)
+ for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
+ if smax is not None and not is_int(smax):
+ raise ValueError(
+ f"shape_max[{i}] must be `int` or `None`; got {type(smax)}"
+ )
+ if smax is None:
+ self.shape_max[i] = s
+ # validate shape_max: all elements must be `int`
+ assert all(map(is_int, self.shape_max))
+
+ # torch compatibility layer
+ @property
+ def ndim(self):
+ return len(self.shape)
+
+ @property
+ def device(self):
+ return self.storage.device
+
+ # strides, data pointer and numel are those of the stored torch tensor,
+ # not of the logical shape
+ def stride(self, i=None):
+ return self.storage.data.stride() if i is None else self.storage.data.stride(i)
+
+ def data_ptr(self):
+ return self.storage.data.data_ptr()
+
+ def numel(self):
+ return self.storage.data.numel()
+
+ # integer division: evaluates to 0 for dtypes narrower than 8 bits
+ def element_size(self):
+ return bitwidth(self.dtype) // 8
+
+ @property
+ def data(self):
+ t = self.storage
+ return t.data if isinstance(t, Storage) else t
+
+ def dim(self):
+ return self.ndim
+
+ # returns the logical shape (a list), or one entry of it when `i` is given
+ def size(self, i=None):
+ if i is None:
+ return self.shape
+ return self.shape[i]
+
+
+@dataclass
+class Bitmatrix(Tensor):
+ """
+ Represents a boolean matrix in a packed format where each element occupies
+ a single bit of memory.
+
+ scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
+ with the actual bitmatrix to avoid having to launch a separate memset
+ kernel when we call Bitmatrix::sum().
+ """
+
+ scratchpad: torch.Tensor = None
+
+ def __init__(self, storage, shape, shape_max=None, scratchpad=None):
+ # the element type is always BIT, so `shape` has to be given explicitly
+ super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
+ self.scratchpad = scratchpad
+
+ def sum(self, partials_block_size):
+ # Delegates to sum_bitmatrix_rows, using the first n_cols scratchpad
+ # entries as the output buffer.
+ _, n_cols = self.shape
+ dev = self.device
+ if self.scratchpad is None:
+ self.scratchpad = clear_sums(n_cols, dev)
+ out_ret = self.scratchpad[:n_cols]
+ self.scratchpad = None # buffer is consumed; a later sum() gets a fresh one from clear_sums
+ return sum_bitmatrix_rows(self, out_ret, partials_block_size)
+
+
+def get_layout(tensor: torch.Tensor | Tensor | None):
+ if tensor is None:
+ return None
+ if isinstance(tensor, Tensor):
+ return tensor.storage.layout
+ return StridedLayout
+
+
+def wrap_torch_tensor(torch_tensor, dtype=None):
+ if dtype is None:
+ dtype = torch_tensor.dtype
+ shape = list(torch_tensor.shape)
+ shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(
+ dtype
+ )
+ return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
+
+
+def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
+ assert isinstance(tensor, Tensor)
+ old_storage = tensor.storage
+ old_data = old_storage.layout.unswizzle_data(old_storage.data)
+ new_layout = layout_cls(old_data.shape, **layout_kwargs)
+ new_data = new_layout.swizzle_data(old_data)
+ attrs = {
+ k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"
+ }
+ return Tensor(Storage(new_data, new_layout), **attrs)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..98122f3517a593b1bc479c43d8d64fb64191a7af
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py
@@ -0,0 +1,40 @@
+from .layout_details.base import Layout
+from .layout_details.blackwell_scale import BlackwellMXScaleLayout
+from .layout_details.blackwell_value import BlackwellMXValueLayout
+from .layout_details.hopper_scale import HopperMXScaleLayout
+from .layout_details.hopper_value import HopperMXValueLayout
+from .layout_details.cdna4_scale import CDNA4MXScaleLayout
+from .layout_details.strided import StridedLayout
+from ..target_info import cuda_capability_geq, is_hip_cdna4
+
+__all__ = [
+ "Layout",
+ "BlackwellMXValueLayout",
+ "BlackwellMXScaleLayout",
+ "HopperMXScaleLayout",
+ "HopperMXValueLayout",
+ "CDNA4MXScaleLayout",
+ "StridedLayout",
+]
+
+
+def make_default_matmul_mxfp4_w_layout(mx_axis: int):
+ if cuda_capability_geq(10):
+ # return StridedLayout, dict()
+ return BlackwellMXValueLayout, dict()
+ elif cuda_capability_geq(9):
+ return HopperMXValueLayout, {"mx_axis": mx_axis}
+ else:
+ return StridedLayout, dict()
+
+
+def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
+ if is_hip_cdna4():
+ return CDNA4MXScaleLayout, dict()
+ else:
+ if cuda_capability_geq(10):
+ return BlackwellMXScaleLayout, dict()
+ elif cuda_capability_geq(9):
+ return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+
+ return StridedLayout, dict()
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d23dab8f42abd1d87bf77c08c3b64c1efe4d3e3
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+
+
+class Layout(ABC):
+ def __init__(self, shape) -> None:
+ self.initial_shape = shape
+
+ @abstractmethod
+ def swizzle_data(self, data):
+ pass
+
+ @abstractmethod
+ def unswizzle_data(self, data):
+ pass
+
+ @abstractmethod
+ def swizzle_block_shape(self, block_shape):
+ pass
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54a300cfdd906dec1a78aaf4f48259529659cdf
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py
@@ -0,0 +1,81 @@
+import math
+import triton
+import triton.language as tl
+import torch
+from .base import Layout
+
+SWIZZLE_ALIGN_INNER = 8
+SWIZZLE_SIZE_INNER = 4
+SWIZZLE_SIZE_OUTER = 128
+
+
+class BlackwellMXScaleLayout(Layout):
+ name: str = "BLACKWELL_SCALE"
+
+ def __init__(self, shape) -> None:
+ super().__init__(shape)
+ (
+ *self.leading_shape,
+ self.K,
+ self.N,
+ ) = shape
+ self.B = math.prod(self.leading_shape)
+ self.ALIGN_K = 8
+ self.ALIGN_N = 128
+ self.SWIZZLE_K = 4
+ self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
+ self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
+
+ def swizzle_data(self, data):
+ data = torch.nn.functional.pad(
+ data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)
+ )
+ data = data.transpose(-1, -2).contiguous()
+ data = data.reshape(
+ self.B,
+ self.N_pad // self.ALIGN_N,
+ self.ALIGN_N // 32,
+ 32,
+ self.K_pad // self.SWIZZLE_K,
+ self.SWIZZLE_K,
+ )
+ data = data.transpose(2, 4).contiguous()
+ data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
+ return data
+
+ def unswizzle_data(self, data):
+ data = data.reshape(
+ self.B,
+ self.N_pad // self.ALIGN_N,
+ self.K_pad // self.SWIZZLE_K,
+ 32,
+ self.ALIGN_N // 32,
+ self.SWIZZLE_K,
+ )
+ data = data.transpose(2, 4)
+ data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
+ data = data.transpose(-1, -2)
+ return data[..., : self.K, : self.N]
+
+ def swizzle_block_shape(self, block_shape):
+ MX_PACK_DIVISOR = 32
+ MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
+ return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
+
+
+# In-kernel un-swizzle of a 2D block: `x` of shape (shape_0, shape_1) is
+# reshaped to five dimensions, permuted, and flattened to
+# (shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER).
+# NOTE(review): presumably the device-side inverse of
+# BlackwellMXScaleLayout.swizzle_data -- confirm against the callers.
+@triton.jit
+def unswizzle_mx_scale_bw(
+ x,
+ SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
+ SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
+ ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER,
+):
+ shape_0: tl.constexpr = x.shape[0]
+ shape_1: tl.constexpr = x.shape[1]
+ # the second dimension must split evenly into SIZE_OUTER-sized groups
+ tl.static_assert(shape_1 % SIZE_OUTER == 0)
+ tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
+ x = x.reshape(
+ shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER
+ )
+ # swap axes 1 and 3, then merge back to two dimensions
+ x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
+ return x
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..622744888b91eb0c99ba6d9c7fb150acb2d89702
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py
@@ -0,0 +1,37 @@
+import torch
+from .base import Layout
+
+
+class BlackwellMXValueLayout(Layout):
+    """Value layout that pads the stride-1 dimension to a multiple of 128."""
+
+    name: str = "BLACKWELL_VALUE"
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+        # Keep the unpadded shape so unswizzle_data can trim back to it.
+        self.shape = shape
+
+    def swizzle_data(self, data):
+        """Return ``data`` with its stride-1 dimension zero-padded to a multiple of 128."""
+        ndim = data.ndim
+        # Dims ordered from largest to smallest (stride, index): the permutation
+        # that makes `data` row major.
+        to_row_major = sorted(
+            range(ndim), key=lambda d: (data.stride(d), d), reverse=True
+        )
+        # Inverse permutation: slot d holds the position of d in to_row_major.
+        inv = sorted(range(ndim), key=to_row_major.__getitem__)
+        # Distance from the stride-1 extent to the next multiple of 128.
+        extent = data.shape[data.stride().index(1)]
+        pad = -extent % 128
+        row_major = data.permute(to_row_major)
+        padded = torch.nn.functional.pad(row_major, (0, pad))
+        return padded.permute(inv)
+
+    def unswizzle_data(self, data: torch.Tensor):
+        """Slice every dimension of ``data`` back to the shape recorded at init."""
+        assert data.ndim == len(self.shape), (
+            "Rank mismatch between data and recorded shape"
+        )
+        index = tuple(
+            slice(0, min(data.size(i), self.shape[i])) for i in range(data.ndim)
+        )
+        return data[index]
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are unaffected by this layout."""
+        return block_shape
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..beecaee3e12d93294df0365010966e15d625635e
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py
@@ -0,0 +1,50 @@
+import triton
+import triton.language as tl
+from .base import Layout
+
+NON_K_PRESHUFFLE_BLOCK_SIZE = 32
+
+
+class CDNA4MXScaleLayout(Layout):
+    """Pre-shuffled MX scale layout, identified by the name ``CDNA4_SCALE``."""
+
+    name: str = "CDNA4_SCALE"
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+
+    def swizzle_data(self, data):
+        """Pre-shuffle a scale tensor of shape ``(SCALE_K, N)`` or ``(E, SCALE_K, N)``.
+
+        ``N`` must be a multiple of ``NON_K_PRESHUFFLE_BLOCK_SIZE`` and
+        ``SCALE_K`` a multiple of 8. The result has its last two dims equal to
+        ``(SCALE_K * 32, N // 32)`` (it is returned as a transposed view).
+
+        Raises:
+            AssertionError: if the rank is not 2 or 3, or a dimension is not
+                divisible as required.
+        """
+        block_shape = data.shape
+        # Check the preconditions before doing any work so failures are explicit
+        # instead of surfacing as an opaque view/reshape error.
+        assert len(block_shape) in (2, 3), (
+            f"expected a 2D or 3D tensor, got shape {tuple(block_shape)}"
+        )
+        SCALE_K = block_shape[-2]
+        N = block_shape[-1]
+        assert N % NON_K_PRESHUFFLE_BLOCK_SIZE == 0, (
+            f"N must be divisible by {NON_K_PRESHUFFLE_BLOCK_SIZE}. Got {N}"
+        )
+        assert SCALE_K % 8 == 0, f"SCALE_K must be divisible by 8. Got {SCALE_K}"
+        data = data.transpose(-1, -2)
+        # N is split as (N // 32, 2, 16) and SCALE_K as (SCALE_K // 8, 2, 4, 1).
+        data = data.view(
+            -1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1
+        )
+        data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
+        n_blocks = N // NON_K_PRESHUFFLE_BLOCK_SIZE
+        block_elems = SCALE_K * NON_K_PRESHUFFLE_BLOCK_SIZE
+        if len(block_shape) == 3:
+            E = block_shape[0]
+            data = data.reshape(E, n_blocks, block_elems)
+        else:
+            data = data.reshape(n_blocks, block_elems)
+        return data.transpose(-1, -2)
+
+    def unswizzle_data(self, data):
+        """Not implemented for this layout."""
+        raise NotImplementedError()
+
+    def swizzle_block_shape(self, block_shape):
+        """Return the block shape after pre-shuffling, as a list.
+
+        The last two entries ``(SCALE_K, N)`` become ``(N // 32, SCALE_K * 32)``;
+        leading entries are preserved. Any sequence is accepted (list, tuple,
+        ``torch.Size``).
+        """
+        SCALE_K = block_shape[-2]
+        N = block_shape[-1]
+        # list(...) keeps list inputs unchanged and avoids the TypeError that
+        # `tuple + list` would raise for tuple-like inputs.
+        return list(block_shape[:-2]) + [
+            N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+            SCALE_K * NON_K_PRESHUFFLE_BLOCK_SIZE,
+        ]
+
+
+@triton.jit
+def unswizzle_mx_scale_cdna4(
+    x,
+    BLOCK_N: tl.constexpr,
+    MX_SCALE_BLOCK_K: tl.constexpr,
+    N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
+):
+    """Undo the pre-shuffle of one scale block inside a Triton kernel.
+
+    The permutation is the inverse of the one applied by
+    ``CDNA4MXScaleLayout.swizzle_data``. Returns a tensor of shape
+    ``(BLOCK_N, MX_SCALE_BLOCK_K)``.
+    """
+    # Axes after reshape: (n_block, k_block, k4, n16, k2, n2, 1).
+    x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
+    # Regroup as (n_block, n2, n16, k_block, k2, k4, 1) so N and K axes are adjacent.
+    x = x.permute(0, 5, 3, 1, 4, 2, 6)
+    x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+    return x
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_scale.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef61e889b2c4c38bad4832bd160734a4b492b26
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_scale.py
@@ -0,0 +1,91 @@
+import torch
+import triton
+import triton.language as tl
+from .base import Layout
+
+
+class HopperMXScaleLayout(Layout):
+    """MX scale layout identified by the name ``HOPPER_SCALE``.
+
+    Args:
+        shape: shape of the tensor the layout is created for; everything but
+            the last two entries is stored as ``leading_shape``.
+        mx_axis: when equal to ``len(leading_shape)`` the data is transposed
+            (``.mT``) before and after (un)swizzling.
+        num_warps: positive power of two; sets the row alignment
+            ``2 * num_warps * 2 * 8`` used by ``swizzle_data``.
+    """
+
+    name: str = "HOPPER_SCALE"
+
+    def __init__(self, shape, mx_axis, num_warps=8) -> None:
+        # `n & (n - 1) == 0` alone also accepts 0, which would later cause a
+        # modulo-by-zero in swizzle_data, so require a positive value.
+        assert num_warps > 0 and num_warps & (num_warps - 1) == 0, (
+            "num_warps must be a power of 2"
+        )
+        super().__init__(shape)
+        self.mx_axis = mx_axis
+        self.num_warps = num_warps
+        *self.leading_shape, _, _ = shape
+
+    def _maybe_mT(self, data):
+        """Transpose the last two dims when ``mx_axis`` is the first non-leading axis."""
+        if self.mx_axis == len(self.leading_shape):
+            return data.contiguous().mT
+        return data
+
+    def swizzle_data(self, data):
+        """Pad and permute ``data`` from ``(*batch, M, K)`` to ``(*batch, M // 32, K * 32)``.
+
+        ``M`` is zero-padded to a multiple of ``2 * num_warps * 2 * 8`` and ``K``
+        to a multiple of 2 first; the shapes above refer to the padded sizes,
+        taken after the optional transpose.
+        """
+        data = self._maybe_mT(data).contiguous()
+        *batch, M, K = data.shape
+        SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
+        SWIZZLE_ALIGN_K = 2
+        pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
+        pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
+        data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
+        *batch, M, K = data.shape
+        assert data.is_contiguous()
+        assert M % SWIZZLE_ALIGN_M == 0 and K % SWIZZLE_ALIGN_K == 0, (
+            f"Input tensor must have a subtile of shape (..., {SWIZZLE_ALIGN_M}, 2)"
+        )
+        b = len(batch)
+        data = data.reshape(
+            *batch,
+            M // SWIZZLE_ALIGN_M,
+            2,
+            self.num_warps,
+            2,
+            8,
+            K // 2,
+            2,
+        )
+        # Permutation of the seven non-batch axes; batch axes stay in front.
+        perm = [0, 2, 5, 1, 4, 6, 3]
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.flatten(-5, -1)
+        data = data.flatten(-3, -2)
+        assert data.shape[-2] == M // 32
+        assert data.shape[-1] == K * 32
+        data = self._maybe_mT(data)
+        return data
+
+    def unswizzle_data(self, data):
+        """Permute swizzled ``data`` from ``(*batch, M, K)`` back to ``(*batch, M * 32, K // 32)``.
+
+        Padding added by ``swizzle_data`` is not removed here.
+        """
+        data = self._maybe_mT(data)
+        *batch, M, K = data.shape
+        # Same preconditions as the Triton counterpart; without them the
+        # reshape below fails with a less helpful message.
+        assert M % self.num_warps == 0, (
+            f"M must be divisible by {self.num_warps}. Got {M}"
+        )
+        assert K % 64 == 0, f"K must be divisible by 64. Got {K}"
+        b = len(batch)
+        data = data.reshape(
+            *batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2
+        )
+        perm = [0, 3, 1, 6, 4, 2, 5]
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.reshape(*batch, M * 32, K // 32)
+        data = self._maybe_mT(data)
+        return data
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are returned unchanged."""
+        return block_shape
+
+
+@triton.jit
+def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
+    """
+    Triton counterpart of ``HopperMXScaleLayout.unswizzle_data`` for a 2D tile.
+
+    ``x`` must be 2D. With ``mx_axis == 0`` the tile is transposed before and
+    after the rearrangement. Working on an ``(M, K)`` tile, ``M`` must be
+    divisible by ``num_warps`` and ``K`` by 64; the rearranged tile has shape
+    ``(M * 32, K // 32)``.
+    """
+    tl.static_assert(len(x.shape) == 2, "NYI")
+    # implementation assumes mxfp data is packed along the last dimension
+    x = x.trans() if mx_axis == 0 else x
+    M: tl.constexpr = x.shape[0]
+    K: tl.constexpr = x.shape[1]
+    tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
+    tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
+    # Same reshape / permutation as the host-side unswizzle, without batch dims.
+    x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
+    x = x.trans(0, 3, 1, 6, 4, 2, 5)
+    x = x.reshape(M * 32, K // 32)
+    # implementation assumed mxfp data is packed along the last dimension
+    x = x.trans() if mx_axis == 0 else x
+    return x
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_value.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ddfadf09427f519bc9867094c7855d9d12eac7
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/hopper_value.py
@@ -0,0 +1,362 @@
+import torch
+import triton
+import triton.language as tl
+from .base import Layout
+
+
+def right_shift_unsigned(x, shift):
+    """Shift ``x`` right by ``shift`` and keep only the low ``32 - shift`` bits.
+
+    Masking clears the bits that an arithmetic shift fills with copies of the
+    sign bit, so the result matches a logical shift of a 32-bit value.
+    """
+    keep_mask = (1 << (32 - shift)) - 1
+    return (x >> shift) & keep_mask
+
+
+# -----------------------------------------------------------------------
+# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
+# 1000000111000000 (first fp4)
+# 1000000111000000 (second fp4)
+# 1000000111000000 (third fp4)
+# 0110110000000000 (fourth fp4)
+# This is done so that dequantization can be done in 14 SASS instructions
+# -----------------------------------------------------------------------
+
+
+def _compress_fp4(x):
+    """Spread a 4-bit value into the bit pattern ``1000000111000000``.
+
+    Bit 3 moves to bit 15 and bits 0-2 move to bits 6-8; the result is int32.
+    """
+    bits = x.to(torch.int32)
+    top_bit = (bits & 0x8) << 12
+    low_bits = (bits & 0x7) << 6
+    return top_bit | low_bits
+
+
+def _compress_fourth(x):
+    """Spread a 4-bit value into the bit pattern ``0110110000000000``.
+
+    Bit 3 moves to bit 14, bits 1-2 move to bits 10-11 and bit 0 moves to
+    bit 13; the result is int32.
+    """
+    bits = x.to(torch.int32)
+    top_bit = (bits & 0x8) << 11
+    middle_bits = (bits & 0x6) << 9
+    low_bit = (bits & 0x1) << 13
+    return top_bit | middle_bits | low_bit
+
+
+def _pack_bits(x: torch.Tensor, mx_axis: int):
+    """Bit-interleave every group of four bytes along the last dim into one 32-bit word.
+
+    Each input byte is treated as two 4-bit values: the low nibble contributes
+    to the low 16 bits of the word and the high nibble to the high 16 bits.
+    ``mx_axis`` is accepted but not used in this function.
+
+    Returns the packed int32 words reinterpreted as ``torch.uint8``.
+    """
+    x = x.contiguous()
+    assert x.shape[-1] % 4 == 0, (
+        "Input tensor must have a last dimension divisible by 4"
+    )
+    # Group the last dimension into chunks of four bytes.
+    x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
+    # Low nibble -> low half-word, high nibble (x >> 4) -> high half-word.
+    first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16)
+    second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16)
+    third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16)
+    fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16)
+    # Shift the second and third patterns right by 3 and 6 bits so they land
+    # in the gaps left by the first one; the fourth already uses its own slots.
+    x = (
+        first
+        | right_shift_unsigned(second, 3)
+        | right_shift_unsigned(third, 6)
+        | fourth
+    )
+    assert x.is_contiguous()
+    # Reinterpret the int32 words as raw bytes.
+    x = x.view(torch.uint8)
+    return x
+
+
+# -----------------------------------------------------------------------
+# inverse operation of _pack_bits
+# -----------------------------------------------------------------------
+
+
+def _bf16_to_fp4e2m1(x):
+    """Collapse the int16 bit pattern ``0bAxxxxxxBCDxxxxxx`` to the uint8 ``0b0000ABCD``."""
+    assert x.dtype == torch.int16
+    # Bit 15 (A) becomes bit 3 of the result, bits 6-8 (BCD) become bits 0-2.
+    top_bit = right_shift_unsigned(x, 15) & 0x1
+    low_bits = right_shift_unsigned(x, 6) & 0x7
+    nibble = (top_bit << 3) | low_bits
+    return nibble.to(torch.uint8)
+
+
+def _bf16x2_to_fp4e2m1x2(x):
+    """Convert each int32 holding two 16-bit patterns into one byte of two nibbles.
+
+    The nibble derived from the low 16 bits lands in bits 0-3 of the result and
+    the one derived from the high 16 bits in bits 4-7.
+    """
+    assert x.dtype == torch.int32
+    low_half = (x & 0xFFFF).to(torch.int16)
+    high_half = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
+    low_nibble, high_nibble = (
+        _bf16_to_fp4e2m1(half) for half in (low_half, high_half)
+    )
+    return low_nibble | (high_nibble << 4)
+
+
+def _unpack_bits(x, mx_axis: int):
+    """Split each packed 32-bit word back into four bytes of two 4-bit values.
+
+    ``x`` is reinterpreted as int32 words. ``mx_axis`` is accepted but not used
+    in this function. The result comes from ``_bf16x2_to_fp4e2m1x2``.
+    """
+    x = x.view(torch.int32)
+    # Mask selecting the pattern 1000000111000000 in both 16-bit halves.
+    m = 0b10000001110000001000000111000000
+    # Pieces of the fourth value, moved back to the 1000000111000000 positions.
+    a = (x << 1) & 0b10000000000000001000000000000000
+    b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
+    c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
+    # First three values: shift left by 0, 3 and 6 bits and mask.
+    unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
+    # Put the four words next to each other along the last dimension.
+    x = torch.stack(unpacked, dim=-1)
+    x = x.flatten(-2, -1)
+    x = _bf16x2_to_fp4e2m1x2(x)
+    return x
+
+
+# -----------------------------------------------------------------------
+
+
+class HopperMXValueLayout(Layout):
+    # Layout identifier.
+    name: str = "HOPPER_VALUE"
+
+    def __init__(self, shape, mx_axis, mma_version=3):
+        """Record the unpadded shape, the packed axis and the mma version.
+
+        ``shape`` is unpacked as ``(*leading_shape, K, N)``; ``mx_axis`` must be
+        a valid axis index of ``shape``. ``swizzle_data`` only accepts
+        ``mma_version`` 2 or 3.
+        """
+        super().__init__(shape)
+        assert mx_axis in range(len(shape))
+        self.mx_axis = mx_axis
+        self.mma_version = mma_version
+        (
+            *self.leading_shape,
+            self.K,
+            self.N,
+        ) = shape
+
+    def _maybe_mT(self, data):
+        """Transpose the last two dims when ``mx_axis`` is the first non-leading axis."""
+        if self.mx_axis == len(self.leading_shape):
+            return data.mT
+        return data
+
+    def swizzle_data(self, data):
+        """
+        Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
+        (*, M // 4, K * 4) such that:
+
+        1) Groups contiguously all the elements owned by the same thread of 4
+           mma tiles along the K axis. The following animation shows a similar
+           grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
+           as done here:
+           https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
+
+        2) Moves the elements belonging to thread 4-7 to be contiguous with those
+           from thread 0-3. This is done to get a full cache line when loading them
+           from HBM.
+
+        mx_axis selects the lhs or rhs of the matmul.
+
+        WARNING: Assumes that the matmul will be done in bf16 or fp16!
+        Implementing it for fp8 is as easy as making the tile size (8, 8)
+        """
+        batch = data.ndim - 2
+        assert batch >= 0
+        assert self.mma_version in (2, 3)
+        data = self._maybe_mT(data)
+        # NOTE: this value is overwritten after padding, before it is read.
+        init_shape = data.shape
+
+        # We are loading 8 bf16 elements per thread to use ld.global.v4
+        # Every u8 represents 2 mxfp4 elements
+        # (evaluates to 4 for mma_version 2 and to 1 otherwise)
+        u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
+
+        # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
+        # Each pair is (factor for dim -2, factor for dim -1).
+        contig = (1, u8_kwidth)
+        scott_trick = (2, 1)
+        threads = (4, 4)
+        warp_tile = (2, 2)
+        k_tile = (1, 4 // u8_kwidth)
+
+        sizes = list(data.shape[:-2])
+        pads = []
+        # [rest, K, tile, threads] per dimension
+        for i, (a, b, c, s, d) in enumerate(
+            zip(k_tile, warp_tile, threads, scott_trick, contig)
+        ):
+            # Each of the last two dims is padded to a multiple of the product
+            # of its factors, then split into (rest, a, b, c, s, d).
+            pack = a * b * c * s * d
+            size = data.shape[batch + i]
+            pad = (pack - size % pack) % pack
+            pads += [(0, pad)]
+            sizes.append((size + pad) // pack)
+            sizes += [a, b, c, s, d]
+
+        # F.pad expects the padding of the last dimension first.
+        pads = tuple(x for t in pads[::-1] for x in t)
+        data = torch.nn.functional.pad(data, pads)
+        init_shape = data.shape
+        # 0: rest[0]
+        # 1: k_tile[0]
+        # 2: warp_tile[0]
+        # 3: threads[0]
+        # 4: scott_trick[0]
+        # 5: contig[0]
+        # 6: rest[1]
+        # 7: k_tile[1]
+        # 8: warp_tile[1]
+        # 9: threads[1]
+        # 10: scott_trick[1]
+        # 11: contig[1]
+        data = data.view(*sizes)
+        # Resulting order, read off `perm` with the legend above:
+        # [rest[0], threads[0], rest[1], scott_trick[1], scott_trick[0], threads[1],
+        #  k_tile[1], k_tile[0], warp_tile[1], warp_tile[0], contig[0], contig[1]]
+        perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
+        perm = list(range(batch)) + [batch + p for p in perm]
+        data = data.permute(*perm).contiguous()
+        # These are views
+        data = data.flatten(-10, -1)
+        data = data.flatten(-3, -2)
+        assert data.is_contiguous()
+        assert data.shape[-2] == init_shape[-2] // 4
+        assert data.shape[-1] == init_shape[-1] * 4
+        # twiddle the bits
+        data = _pack_bits(data, self.mx_axis)
+        data = self._maybe_mT(data)
+        return data
+
+    def unswizzle_data(self, data):
+        """Reverse ``swizzle_data`` and trim the result to the recorded ``(K, N)``."""
+        data = self._maybe_mT(data)
+        data = _unpack_bits(data, self.mx_axis)
+        *batch, M, K = data.shape
+        # We have two times the elements if we already upcasted to bfloat16
+        mult = 2 if data.dtype == torch.bfloat16 else 1
+        assert M % 4 == 0, "M must be divisible by 4"
+        assert K % (4 * 8 * 2 * 2 * mult) == 0, (
+            f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
+        )
+        # We are loading 8 bf16 elements per thread to use ld.global.v4
+        # Every u8 represents 2 mxfp4 elements
+        u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
+        data = data.reshape(
+            *batch,
+            M // 4,
+            4,
+            K // (4 * 8 * 2 * 2 * mult),
+            2,
+            4,
+            8 // u8_kwidth,
+            2,
+            u8_kwidth * mult,
+        )
+        b = len(batch)
+        # Permutation of the eight non-batch axes; batch axes stay in front.
+        perm = [0, 6, 1, 3, 2, 5, 4, 7]
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.reshape(*batch, M * 4, K // 4)
+        data = self._maybe_mT(data)
+        # Drop the rows/columns introduced by padding in swizzle_data.
+        return data[..., : self.K, : self.N]
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are returned unchanged."""
+        return block_shape
+
+
+@triton.jit
+def _unshuffle_triton(x, mma_version: tl.constexpr):
+    """
+    Triton counterpart of the reshape/permute step of
+    ``HopperMXValueLayout.unswizzle_data`` for a 2D tile.
+
+    For an ``(M, K)`` input the result has shape ``(M * 4, K // 4)``.
+    ``mma_version`` must be 2 or 3. No transpose is applied here; the caller
+    in this file (``mxfp4_to_bf16_triton``) transposes around this call.
+    """
+    tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
+
+    # We have two times the elements if we already upcasted to bfloat16
+    mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
+    M: tl.constexpr = x.shape[0]
+    K: tl.constexpr = x.shape[1]
+    tl.static_assert(M % 4 == 0, "M must be divisible by 4")
+    tl.static_assert(
+        K % (4 * 8 * 2 * 2 * mult) == 0,
+        f"K must be divisible by {4 * 8 * 2 * 2 * mult}",
+    )
+
+    # We are loading 8 bf16 elements per thread to use ld.global.v4
+    # Every u8 represents 2 mxfp4 elements
+    u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
+    x = x.reshape(
+        M // 4,
+        4,
+        K // (4 * 8 * 2 * 2 * mult),
+        2,
+        4,
+        8 // u8_kwidth,
+        2,
+        u8_kwidth * mult,
+    )
+    # Same axis permutation as the host-side unswizzle, without batch dims.
+    x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
+    x = x.reshape(M * 4, K // 4)
+    return x
+
+
+@triton.jit
+def _unpack_fp4_to_bf16_triton(x):
+    """Expand bit-interleaved fp4 data to bfloat16 using inline PTX.
+
+    The assembly processes four input elements per invocation (``pack=4``,
+    one 32-bit register) and produces four 32-bit registers, each holding two
+    bf16 values. The masks and shifts match those of the host-side
+    ``_unpack_bits``. Every extracted pattern is multiplied by the ``bias``
+    constant with ``mul.bf16x2``.
+    """
+    # For now we implement just H100 support (mul.bf16x2)
+    # A100 support is possible via fma
+    r0, r1 = tl.inline_asm_elementwise(
+        r"""
+        {
+            .reg .b32 b, c, d<7>, scale;
+            .reg .b32 bias;
+            mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
+            // We add the missing bias to the scale directly
+            and.b32 $0, $4, 0b10000001110000001000000111000000;
+            mul.bf16x2 $0, $0, bias;
+            shl.b32 b, $4, 3;
+            and.b32 $1, b, 0b10000001110000001000000111000000;
+            mul.bf16x2 $1, $1, bias;
+            shl.b32 c, $4, 6;
+            and.b32 $2, c, 0b10000001110000001000000111000000;
+            mul.bf16x2 $2, $2, bias;
+            // Unpack last two elements
+            shl.b32 d0, $4, 1;
+            and.b32 d1, d0, 0b10000000000000001000000000000000;
+            shr.b32 d2, $4, 3;
+            and.b32 d3, d2, 0b00000001100000000000000110000000;
+            or.b32 d4, d1, d3;
+            shr.b32 d5, $4, 7;
+            and.b32 d6, d5, 0b00000000010000000000000001000000;
+            or.b32 $3, d4, d6;
+            mul.bf16x2 $3, $3, bias;
+        }
+        """,
+        constraints="=r,=r,=r,=r,r",
+        args=[x],
+        dtype=(tl.bfloat16, tl.bfloat16),
+        is_pure=True,
+        pack=4,
+    )
+    # Concat each pack of 4
+    # tl.join adds a trailing axis of size 2 holding (r0, r1).
+    x = tl.join(r0, r1)
+    x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
+    # Move the joined axis in front of the pack-of-4 axis, then flatten to 2D.
+    x = x.trans(0, 1, 3, 2)
+    x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
+    return x
+
+
+@triton.jit
+def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
+    """
+    Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
+    (x << 0) & 0b1000000111000000
+    (x << 3) & 0b1000000111000000
+    (x << 6) & 0b1000000111000000
+    ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
+
+    ``x`` must be a 2D uint8 tile whose second dimension is divisible by 4 and
+    ``mx_axis`` must be 0 or 1. The tile is unpacked to bfloat16, unshuffled
+    and multiplied element-wise by ``scale``, where each scale entry is
+    repeated 32 times along ``mx_axis``.
+    """
+    # upcast values to bfloat16
+    tl.static_assert(len(x.shape) == 2)
+    tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
+    tl.static_assert(x.shape[1] % 4 == 0)
+    tl.static_assert(x.dtype == tl.uint8)
+    # The helpers work along the last dimension, so transpose around them
+    # when the packed axis is the first one.
+    if mx_axis == 0:
+        x = x.trans()
+    x = _unpack_fp4_to_bf16_triton(x)
+    x = _unshuffle_triton(x, mma_version=3)
+    if mx_axis == 0:
+        x = x.trans()
+
+    # upcast scale to bfloat16
+    # Add bias missing from the bf16 upcasting sequence
+    # triton / LLVM generates terrible code for this sequence
+    # scale = scale.to(tl.uint16)
+    # scale = scale << 7
+    # scale = scale.to(tl.bfloat16, bitcast=True)
+    scale = tl.inline_asm_elementwise(
+        r"""
+        {
+            prmt.b32 $0, $2, 0, 0x5140;
+            shl.b32 $0, $0, 7;
+            prmt.b32 $1, $2, 0, 0x7362;
+            shl.b32 $1, $1, 7;
+        }
+        """,
+        constraints="=r,=r,r",
+        args=[scale],
+        dtype=tl.bfloat16,
+        is_pure=True,
+        pack=4,
+    )
+    # Broadcast scale
+    # Insert a new axis right after mx_axis, expand it to 32, then fold it
+    # away so the scale has the same shape as x.
+    scale = scale.expand_dims(mx_axis + 1)
+    scale = scale.broadcast_to(
+        scale.shape[: mx_axis + 1] + [32] + scale.shape[mx_axis + 2 :]
+    )
+    scale = scale.reshape(x.shape)
+
+    # Combine scale and x
+    x = x * scale
+    return x
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/strided.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/strided.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbfd9248fca219eb94dae358cafd7fac6e082cd1
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/strided.py
@@ -0,0 +1,17 @@
+from .base import Layout
+
+
+class StridedLayout(Layout):
+    """Layout that leaves data and block shapes untouched."""
+
+    # No layout identifier is assigned (the value is None despite the `str` annotation).
+    name: str = None
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+
+    def swizzle_data(self, data):
+        """Return ``data`` unchanged."""
+        return data
+
+    def unswizzle_data(self, data):
+        """Return ``data`` unchanged."""
+        return data
+
+    def swizzle_block_shape(self, block_shape):
+        """Return ``block_shape`` unchanged."""
+        return block_shape
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/testing.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ea4534b2d44e787e91b638458948233d4be092
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/testing.py
@@ -0,0 +1,215 @@
+import enum
+import functools
+import os
+import subprocess
+import sys
+import torch
+from compactor_vllm.triton_kernels.numerics import (
+ MAX_FINITE_FLOAT8E4B8,
+ MAX_FINITE_FLOAT8E4NV,
+ MAX_FINITE_FLOAT8E5,
+)
+
+
+def assert_equal(ref, tri):
+    """Assert that ``ref == tri``; tensors must match in every element."""
+    matches = ref == tri
+    if isinstance(ref, torch.Tensor):
+        # Reduce the element-wise comparison to a single truth value.
+        matches = torch.all(matches)
+    assert matches
+
+
+def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
+    """
+    Compare reference values against obtained values.
+
+    Args:
+        ref: reference tensor.
+        tri: tensor under test.
+        maxtol: maximum allowed relative error (default 2e-2).
+        rmstol: maximum allowed RMS of the relative error (default 4e-3).
+        description: prefix used in the printed report.
+        verbose: print the measured errors and thresholds.
+
+    When ``tri`` has a 1-byte dtype, ``ref`` is first converted to that dtype;
+    if both already share it, the comparison is exact. Infinite elements must
+    appear at the same positions with the same sign in both tensors.
+
+    Raises:
+        AssertionError: on shape mismatch, differing infinities, or when an
+            error exceeds its tolerance.
+    """
+    if tri.dtype.itemsize == 1:
+        ref_as_type = ref.to(tri.dtype)
+        if ref.dtype == tri.dtype:
+            assert torch.all(ref_as_type == tri)
+            return
+        ref = ref_as_type
+
+    # Check shapes before the empty-tensor shortcut so that an empty reference
+    # cannot silently "match" a tensor of a different shape.
+    assert ref.shape == tri.shape, (
+        f"Tensors must have same size {ref.shape=} {tri.shape=}"
+    )
+
+    if ref.numel() == 0:
+        return
+
+    if maxtol is None:
+        maxtol = 2e-2
+    if rmstol is None:
+        rmstol = 4e-3
+
+    # cast to float32:
+    ref = ref.to(torch.float32).detach()
+    tri = tri.to(torch.float32).detach()
+
+    # deal with infinite elements:
+    inf_mask_ref = torch.isinf(ref)
+    inf_mask_tri = torch.isinf(tri)
+    assert torch.equal(inf_mask_ref, inf_mask_tri), (
+        "Tensor must have same infinite elements"
+    )
+    # The masks only say where infinities are; +inf vs -inf must be rejected
+    # explicitly because both are replaced by 0 below.
+    assert torch.equal(ref[inf_mask_ref], tri[inf_mask_tri]), (
+        "Infinite elements must have the same sign"
+    )
+    refn = torch.where(inf_mask_ref, 0, ref)
+    trin = torch.where(inf_mask_tri, 0, tri)
+
+    # normalise so that RMS calculation doesn't overflow:
+    eps = 1.0e-30
+    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
+    refn *= multiplier
+    trin *= multiplier
+
+    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
+
+    # Error relative to max(|ref|, rms(ref)), so tiny reference values do not
+    # blow up the relative error.
+    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
+    max_err = torch.max(rel_err).item()
+    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
+
+    if verbose:
+        print(
+            "%s maximum relative error = %s (threshold = %s)"
+            % (description, max_err, maxtol)
+        )
+        print(
+            "%s RMS relative error = %s (threshold = %s)"
+            % (description, rms_err, rmstol)
+        )
+
+    if max_err > maxtol:
+        bad_idxs = torch.nonzero(rel_err > maxtol)
+        num_nonzero = bad_idxs.size(0)
+        # Cap the report at 1000 coordinates.
+        bad_idxs = bad_idxs[:1000]
+        print(
+            "%d / %d mismatched elements (shape = %s) at coords %s"
+            % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
+        )
+
+        bad_idxs = bad_idxs.unbind(-1)
+        print("ref values: ", ref[tuple(bad_idxs)].cpu())
+        print("tri values: ", tri[tuple(bad_idxs)].cpu())
+
+    assert max_err <= maxtol, (
+        f"{description} maximum relative error {max_err} exceeds {maxtol}"
+    )
+    assert rms_err <= rmstol, (
+        f"{description} RMS relative error {rms_err} exceeds {rmstol}"
+    )
+
+
+class ComputeSanitizerTool(enum.Enum):
+    """Tools that ``compute_sanitizer`` can run a test under."""
+
+    # Each value is passed to the command line as ``--tool=<value>``.
+    MEMCHECK = "memcheck"
+    RACECHECK = "racecheck"
+    SYNCCHECK = "synccheck"
+    INITCHECK = "initcheck"
+
+
+def compute_sanitizer(**target_kwargs):
+    """
+    Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
+    to expose potential memory access errors.
+
+    Recognised options in ``target_kwargs``:
+      * ``clear_torch_cache``: call ``torch.cuda.empty_cache()`` before running.
+      * ``tools_to_check``: list of ``ComputeSanitizerTool`` (default: memcheck).
+    Every remaining entry must also be present, with the same value, in the
+    keyword arguments of the test call for the sanitizer to be used.
+
+    The wrapped test must receive a ``request_fixture`` keyword argument; its
+    ``node.callspec.id`` is used to re-run exactly that test case in a
+    subprocess under ``compute-sanitizer``.
+    If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
+    Setting the environment variable ``SKIP_COMPUTE_SANITIZER=1`` runs the test
+    directly. Running tests under compute sanitizer requires launching subprocess
+    and is slow, so use sparingly
+    """
+
+    def decorator(test_fn):
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
+                test_fn(*args, **kwargs)
+                return
+
+            import psutil
+
+            # Work on a copy: popping from the shared closure dict would make
+            # these options vanish after the first call of the wrapped test.
+            options = dict(target_kwargs)
+            if options.pop("clear_torch_cache", False):
+                torch.cuda.empty_cache()
+            tools_to_check = options.pop(
+                "tools_to_check", [ComputeSanitizerTool.MEMCHECK]
+            )
+            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
+            # isinstance avoids the TypeError that `x in EnumClass` raises for
+            # non-members on older Python versions, and the list (rather than a
+            # generator) makes the message show the offending entries.
+            invalid_tools = [
+                tool
+                for tool in tools_to_check
+                if not isinstance(tool, ComputeSanitizerTool)
+            ]
+            assert not invalid_tools, f"{invalid_tools=}"
+
+            ppid_name = psutil.Process(os.getppid()).exe()
+            # Only sanitize calls whose kwargs contain all remaining options.
+            run_compute_sanitizer = options.items() <= kwargs.items()
+            if "run_sanitizer" in kwargs:
+                run_compute_sanitizer &= kwargs["run_sanitizer"]
+            # When the parent process already is compute-sanitizer we are the
+            # re-launched child: just run the test.
+            if not run_compute_sanitizer or "compute-sanitizer" in ppid_name:
+                test_fn(*args, **kwargs)
+                return
+
+            for tool in tools_to_check:
+                # get path of current file
+                path = os.path.realpath(test_fn.__globals__["__file__"])
+                env = {
+                    "PATH": os.environ["PATH"],
+                    "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
+                    "TORCH_SHOW_CPP_STACKTRACES": "1",
+                    "CUDA_LAUNCH_BLOCKING": "1",
+                }
+                if "CUDA_VISIBLE_DEVICES" in os.environ:
+                    env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
+                assert "request_fixture" in kwargs, (
+                    "memcheck'ed test must have a (possibly unused) `request` fixture"
+                )
+                test_id = kwargs["request_fixture"].node.callspec.id
+                cmd = f"{path}::{test_fn.__name__}[{test_id}]"
+                cmd = [
+                    "compute-sanitizer",
+                    "--target-processes=application-only",
+                    "--destroy-on-device-error=context",
+                    f"--tool={tool.value}",
+                    sys.executable,
+                    "-m",
+                    "pytest",
+                    "-vsx",
+                    cmd,
+                ]
+                for opt in ["--update_checksum", "--ignore_checksum_error"]:
+                    if opt in sys.argv:
+                        cmd.append(opt)
+                out = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.STDOUT,
+                    env=env,
+                )
+                raw_output = out.stdout
+                # Tolerate undecodable bytes: the output is only used for
+                # substring checks and for display.
+                if isinstance(raw_output, bytes):
+                    test_output = raw_output.decode(errors="replace")
+                else:
+                    test_output = str(raw_output)
+                sanitizer_ok = (
+                    "ERROR SUMMARY: 0 errors" in test_output
+                    or "RACECHECK SUMMARY: 0 hazards displayed" in test_output
+                )
+
+                fail = False
+                if not sanitizer_ok:
+                    print("compute-sanitizer returned an error")
+                    fail = True
+                elif out.returncode != 0:
+                    print(
+                        "The test failed due to some other reason: consider running without compute-sanitizer to verify."
+                    )
+                    print(f"{out.returncode=}")
+                    fail = True
+
+                if fail:
+                    print("*****************************************************")
+                    print("******************** TEST OUTPUT ********************")
+                    print("*****************************************************")
+                    print(test_output)
+                    print("*****************************************************")
+                    print("****************** TEST OUTPUT END ******************")
+                    print("*****************************************************")
+                    # Raise explicitly: `assert None` disappears under `python -O`.
+                    raise AssertionError(
+                        f"test failed under compute-sanitizer --tool={tool.value}"
+                    )
+
+        return wrapper
+
+    return decorator
+
+
+def compute_actual_scale(x, dtype):
+    """Return ``max(|x|)`` divided by the largest finite value of the fp8 ``dtype``.
+
+    Supported dtypes are ``torch.float8_e5m2``, ``torch.float8_e4m3fn`` and
+    ``torch.float8_e4m3fnuz``; any other dtype raises ``KeyError``.
+    """
+    finite_limits = {
+        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
+        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
+        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
+    }
+    limit = finite_limits[dtype]
+    largest_magnitude = x.abs().max()
+    return largest_magnitude / limit
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk.py
new file mode 100644
index 0000000000000000000000000000000000000000..95ff4481daba996dd0e0a2c4de051183a506fe40
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk.py
@@ -0,0 +1,157 @@
+import torch
+import triton
+from compactor_vllm.triton_kernels.topk_details._topk_forward import _topk_forward
+from compactor_vllm.triton_kernels.topk_details import _topk_backward
+from compactor_vllm.triton_kernels.tensor import Tensor, Bitmatrix
+from typing import Optional, Union
+
+
+def topk_forward(
+    x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None
+):
+    """Launch the ``_topk_forward`` kernel and return ``(y_vals, y_indx, bitmatrix)``.
+
+    A plain tensor ``x`` is wrapped in a ``Tensor`` whose row count is
+    ``n_rows`` when given. Only ``dim == 1`` with ``return_bitmatrix`` set is
+    supported, and the number of columns must be below 32768. When ``y_indx``
+    is supplied it is handed to the kernel instead of a freshly allocated
+    int16 buffer.
+    """
+
+    def ceil_div(numer, denom):
+        return (numer + denom - 1) // denom
+
+    BLOCK_M, BLOCK_N, BLOCK_S = 32, 32, 128
+
+    if not isinstance(x, Tensor):
+        active_rows = x.shape[0] if n_rows is None else n_rows
+        x = Tensor(
+            x,
+            shape=[active_rows, x.shape[1]],
+            shape_max=[x.shape[0], x.shape[1]],
+        )
+
+    assert len(x.shape) == 2
+    assert x.shape_max[-1] < 32768
+    assert dim == 1
+    assert return_bitmatrix
+
+    n_rows, n_cols = x.shape
+    n_rows_max, _ = x.shape_max
+    dev = x.device
+
+    # scratchpad tensors
+    # NOTE: these are not returned
+    y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
+    use_provided_indx = y_indx is not None
+    if not use_provided_indx:
+        y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
+
+    # The bitmatrix is allocated transposed (words x padded rows) and then
+    # viewed as rows x words, trimmed to n_rows_max rows.
+    n_cols_pad = ceil_div(n_cols, BLOCK_N) * BLOCK_N
+    n_cols_words = n_cols_pad // 32
+    padded_rows = ceil_div(n_rows_max, 32) * 32
+    bitmatrix_storage = torch.empty(
+        (n_cols_words, padded_rows), dtype=torch.uint32, device=dev
+    )
+    bitmatrix = torch.transpose(bitmatrix_storage, 0, 1)[:n_rows_max]
+
+    s_blocks = ceil_div(n_cols, BLOCK_S)
+    scratchpad = torch.empty((s_blocks * BLOCK_S,), dtype=torch.int32, device=dev)
+
+    # One launch covers both the row blocks and the scratchpad blocks.
+    grid = (max(ceil_div(n_rows_max, BLOCK_M), s_blocks),)
+    _topk_forward[grid](
+        # inputs
+        x,
+        x.stride(0),
+        # output [topk]
+        y_vals,
+        y_indx,
+        y_vals.stride(0),
+        use_provided_indx,
+        # output [bitmatrix]
+        bitmatrix,
+        bitmatrix.stride(0),
+        bitmatrix.stride(1),
+        # shapes
+        n_rows,
+        n_cols,
+        # thing to memset to zero
+        scratchpad,
+        BLOCK_S,
+        s_blocks,
+        # tunable parameters
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        # constants
+        APPLY_SOFTMAX=apply_softmax,
+        N_EXPTS_PAD=n_cols_pad,
+        N_EXPTS_ACT=k,
+    )
+
+    wrapped_bitmatrix = Bitmatrix(
+        bitmatrix,
+        shape=[n_rows, n_cols_words * 32],
+        shape_max=[n_rows_max, None],
+        scratchpad=scratchpad,
+    )
+    return y_vals, y_indx, wrapped_bitmatrix
+
+
+def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
+    """Launch the ``_topk_backward`` kernel and return the gradient w.r.t. ``x``.
+
+    One program is launched per row of ``dy_vals``, whose last dimension must
+    equal ``k``. The result has the shape and dtype of ``x``.
+    """
+    assert dy_vals.shape[-1] == k
+    n_expts_tot = x.shape[-1]
+    # The kernel indexes columns with a power-of-two sized range.
+    n_expts_pad = triton.next_power_of_2(n_expts_tot)
+    dx = torch.empty_like(x)
+    grid = (dy_vals.shape[0],)
+    _topk_backward[grid](
+        # top-k indices
+        y_indx,
+        y_indx.stride(0),
+        # output gradient values
+        dy_vals,
+        dy_vals.stride(0),
+        # input values
+        x,
+        x.stride(0),
+        # input gradient values
+        dx,
+        dx.stride(0),
+        # shapes
+        x.shape[0],
+        n_rows,
+        n_expts_tot,
+        APPLY_SOFTMAX=apply_softmax,
+        N_EXPTS_ACT=k,
+        N_EXPTS_PAD=n_expts_pad,
+    )
+    return dx
+
+
+class TopK(torch.autograd.Function):
+    """Autograd function pairing ``topk_forward`` with ``topk_backward``."""
+
+    @staticmethod
+    def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
+        y_vals, top_indx, bitmatrix = topk_forward(
+            x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows
+        )
+        # The input and the selected indices are needed again in backward.
+        ctx.save_for_backward(x, top_indx)
+        ctx.apply_softmax, ctx.k, ctx.n_rows = apply_softmax, k, n_rows
+        return y_vals, top_indx, bitmatrix
+
+    @staticmethod
+    def backward(ctx, dy_vals, _0, _1):
+        x, top_indx = ctx.saved_tensors
+        dx = topk_backward(
+            x, top_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax
+        )
+        # Only `x` receives a gradient; the other six forward inputs get None.
+        return (dx,) + (None,) * 6
+
+
+def topk(
+    x: Union[Tensor, torch.Tensor],
+    k: int,
+    apply_softmax: bool = True,
+    dim: int = 1,
+    return_bitmatrix: bool = True,
+    y_indx: Optional[torch.Tensor] = None,
+    n_rows: Optional[int] = None,
+):
+    """
+    Differentiable top-k selection along one dimension of a 2D tensor.
+
+    The input may be a `Tensor` or a `torch.Tensor`; the returned values and
+    indices are always `torch.Tensor` objects.
+
+    Parameters
+    ----------
+    x : Union[triton_kernels.Tensor, torch.Tensor]
+        Input of shape (n_tokens, n_expts).
+    k : int
+        How many top elements to retrieve per row.
+    apply_softmax : bool, default True
+        Whether softmax is applied to the input before computing top-k.
+    dim : int, default 1
+        Dimension along which top-k is computed.
+    return_bitmatrix : bool, default True
+        Whether to produce a bitmatrix of shape (n_tokens, cdiv(n_expts, 32)),
+        where bit [t, b] tells whether the b-th expert was selected for the
+        t-th token.
+    y_indx : torch.Tensor, optional
+        Pre-allocated tensor of shape (n_tokens, k) holding top-k indices.
+        If provided, the index computation is skipped and this tensor is used.
+    n_rows : int, optional
+        Number of rows to apply top-k on. If None, all rows of `x` are used.
+
+    Returns
+    -------
+    (expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
+    """
+    return TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_backward.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..eebe481771543a05cfab5741bf1a0c875248f70d
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_backward.py
@@ -0,0 +1,51 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _topk_backward(
+    Yi,
+    stride_ym,  # topk indices
+    DY,
+    stride_dym,  # output gradient values
+    X,
+    stride_xm,  # input values
+    DX,
+    stride_dxm,  # input gradient values
+    n_rows,
+    NRows,
+    n_expts_tot,
+    APPLY_SOFTMAX: tl.constexpr,
+    N_EXPTS_ACT: tl.constexpr,
+    N_EXPTS_PAD: tl.constexpr,
+):
+    # Backward pass of top-k: each program handles one row. It zeroes the
+    # first `n_expts_tot` entries of the row's input gradient and then writes
+    # the gradient of the N_EXPTS_ACT selected entries at their indices.
+    pid_m = tl.program_id(0)
+    # When NRows is given, the row count is read from memory and overrides
+    # the `n_rows` argument.
+    if NRows is not None:
+        n_rows = tl.load(NRows)
+    if pid_m >= n_rows:
+        return
+    # Advance all pointers to the row owned by this program.
+    Yi += pid_m * stride_ym
+    DY += pid_m * stride_dym
+    X += pid_m * stride_xm
+    DX += pid_m * stride_dxm
+    # --
+    offs_xn = tl.arange(0, N_EXPTS_PAD)
+    offs_yn = tl.arange(0, N_EXPTS_ACT)
+    mask_xn = offs_xn < n_expts_tot
+    # recompute softmax
+    # (over the selected entries only: X is gathered at the top-k indices)
+    y_indx = tl.load(Yi + offs_yn)
+    x = tl.load(X + y_indx)
+    x = x.to(tl.float32)
+    y = tl.softmax(x)
+    # compute input-gradient
+    dy = tl.load(DY + offs_yn)
+    dy = dy.to(tl.float32)
+    s = tl.sum(y * dy, 0)
+    # write-back input gradient
+    tl.store(DX + offs_xn, 0, mask=mask_xn)
+    # Barrier between zeroing the row and the scattered store below.
+    tl.debug_barrier()
+    if APPLY_SOFTMAX:
+        # Softmax backward: y * (dy - sum(y * dy)).
+        dx = y * (dy - s)
+    else:
+        dx = dy
+    tl.store(DX + y_indx, dx)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_forward.py b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf27ba999cca1a2b8fe63f1c386680c77ea4cec9
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/triton_kernels/topk_details/_topk_forward.py
@@ -0,0 +1,183 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def get_topmask_and_fullmask(x): # returns (top-bit mask, all-ones mask), both filled tensors with x's shape and dtype
+ tl.static_assert(
+ x.dtype.is_int_unsigned(), "floating-point value must be passed as bits"
+ )
+ tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth) # only the most significant bit set
+ fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1 # every bit set
+ tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
+ fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
+ return tm_arr, fm_arr
+
+
+@triton.jit
+def fpval_to_key(x): # turns float bit patterns (passed as unsigned ints) into keys meant to sort like the floats under unsigned compare
+ tm, fm = get_topmask_and_fullmask(x)
+ return x ^ tl.where((x & tm) != 0, fm, tm) # top bit set: flip every bit; top bit clear: flip only the top bit
+
+
+@triton.jit
+def key_to_fpval(x): # undoes fpval_to_key: maps a key back to the original float bit pattern
+ tm, fm = get_topmask_and_fullmask(x)
+ return x ^ tl.where((x & tm) == 0, fm, tm) # top bit clear: flip every bit; top bit set: flip only the top bit
+
+
+# stable top-k tie-breaks to value with smaller index: a smaller index yields a larger key
+@triton.jit
+def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr): # encodes a column index as N_EXPTS_PAD - indx
+ return N_EXPTS_PAD - indx
+
+
+@triton.jit
+def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr): # decodes a key made by indx_to_key (the same subtraction is its own inverse)
+ return N_EXPTS_PAD - indx
+
+
+@triton.jit
+def streaming_topk( # row-wise top-k over X, read in BLOCK_N-wide column chunks starting from the last chunk
+ X, # base pointer of the input matrix
+ stride_xm, # row stride of X
+ n_expts_tot, # number of valid columns; only the first load (the last chunk) is masked against it
+ offs_m, # row offsets handled by this program
+ mask_m, # row validity mask, combined with the column mask in the loads
+ N_EXPTS_PAD: tl.constexpr, # padded column count; N_EXPTS_PAD // BLOCK_N chunks are visited
+ N_EXPTS_ACT: tl.constexpr, # k, the number of entries kept per row
+ BLOCK_N: tl.constexpr, # columns loaded per chunk
+):
+ x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
+ x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}") # unsigned type with the same width as the input elements
+ if x_nbits < 16:
+ # this ensures that we leave at least 16 bits for expert index
+ # even if the input dtype is smaller than 16 bits:
+ y_nbits: tl.constexpr = 32
+ else:
+ y_nbits: tl.constexpr = x_nbits * 2
+ x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}") # wider unsigned type that holds value key and index key together
+ x_dtype: tl.constexpr = X.dtype.element_ty
+
+ # subtract 1 from loop iterations because we peel the first (masked) iteration:
+ loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
+ offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N) # column offsets of the last chunk
+ mask_n = offs_x_n[None, :] < n_expts_tot
+
+ # first iteration (last chunk, the only one that may contain padded columns):
+ X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
+ x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
+ x = fpval_to_key(x.to(x_utype, bitcast=True))
+ x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] # value key above bit 16, index key in the low 16 bits
+ acc = tl.topk(x, N_EXPTS_ACT, dim=1)
+
+ # subsequent iterations walk towards column 0, one chunk at a time:
+ for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
+ acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
+ X_ptrs -= BLOCK_N
+ offs_x_n -= BLOCK_N
+ x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
+ x = fpval_to_key(x.to(x_utype, bitcast=True))
+ x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
+ acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1)) # NOTE(review): relies on the ordering of acc vs tl.topk output to keep the best k -- confirm
+
+ # rotate expert index into upper 16 bits:
+ # 0000vvvvvvvviiii --> iiii0000vvvvvvvv
+ acc = (acc << (y_nbits - 16)) | (acc >> 16)
+ # sort in ascending order of expert (descending order of key)
+ acc = tl.sort(acc, dim=1, descending=True)
+ # iiii0000vvvvvvvv --> 0000iiii:
+ y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
+ y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
+ # iiii0000vvvvvvvv --> vvvvvvvv:
+ y_values_raw = acc.to(x_utype)
+ y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
+
+ return y_values, y_indices # selected values (input dtype) and their column indices, ordered by ascending index
+
+
+@triton.jit
+def _topk_forward( # per block of BLOCK_M rows: select (or reuse) top-k columns, optionally softmax them, store values/indices and a bitmask
+ X, # pointer to the inputs
+ stride_xm, # row stride of X
+ Yv, # pointer to the output top-k values
+ Yi, # pointer to the top-k indices (read when USE_PROVIDED_INDX, written otherwise)
+ stride_ym, # row stride shared by Yv and Yi
+ USE_PROVIDED_INDX: tl.constexpr, # when true, indices are loaded from Yi instead of being computed
+ Bits, # pointer to the output bitmatrix
+ stride_rm: tl.constexpr, # row stride of Bits
+ stride_rn: tl.constexpr, # stride between consecutive 32-bit words of Bits
+ n_rows, # row count, or a pointer to it (dereferenced below)
+ n_expts_tot, # number of valid columns of X
+ S, # pointer to an int32 buffer that this kernel zero-fills
+ BLOCK_S: tl.constexpr, # elements of S cleared by one program
+ s_blocks, # number of programs that take part in clearing S
+ APPLY_SOFTMAX: tl.constexpr, # when true, softmax is applied over the selected values
+ BLOCK_M: tl.constexpr, # rows handled per program
+ N_EXPTS_PAD: tl.constexpr, # padded column count (must be a multiple of BLOCK_N)
+ N_EXPTS_ACT: tl.constexpr, # k, the number of entries kept per row
+ BLOCK_N: tl.constexpr, # column chunk width (must be a multiple of 32)
+):
+ pid = tl.program_id(0)
+ if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
+ n_rows = tl.load(n_rows)
+
+ if pid < s_blocks: # the memset of S happens before the early exit, so it does not depend on n_rows
+ tl.store(
+ S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)
+ )
+
+ if pid * BLOCK_M >= n_rows:
+ # early exit:
+ return
+
+ tl.static_assert(BLOCK_N % 32 == 0)
+ tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
+ x_dtype: tl.constexpr = X.dtype.element_ty
+
+ # load logits
+ offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_y_n = tl.arange(0, N_EXPTS_ACT)
+ mask_m = offs_m[:, None] < n_rows
+ if USE_PROVIDED_INDX:
+ Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+ y_indices = tl.load(Yi_ptrs, mask=mask_m)
+ Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices # gather the inputs at the provided indices
+ y_values = tl.load(Xv_ptrs, mask=mask_m)
+ else:
+ y_values, y_indices = streaming_topk(
+ X,
+ stride_xm,
+ n_expts_tot,
+ offs_m,
+ mask_m, #
+ N_EXPTS_PAD,
+ N_EXPTS_ACT,
+ BLOCK_N,
+ )
+
+ # normalize selected values (softmax in float32, cast back to the input dtype)
+ if APPLY_SOFTMAX:
+ y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(
+ x_dtype
+ )
+
+ # write back
+ Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+ tl.store(Yv_ptrs, y_values, mask=mask_m)
+ if not USE_PROVIDED_INDX:
+ Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+ tl.store(Yi_ptrs, y_indices, mask=mask_m)
+
+ # pack into bitmatrix: bit (index % 32) of word (index // 32) is set for every selected index
+ y_div = y_indices // 32
+ y_rem = y_indices % 32
+ loop_iterations = N_EXPTS_PAD // BLOCK_N
+ for i in range(loop_iterations):
+ offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) # word offsets covered by this chunk
+ y2 = tl.where(
+ y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0
+ )
+ r = tl.reduce_or(y2, axis=1) # OR together the bits contributed by all selected entries of a row
+ BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
+ tl.store(BitsPtrs, r, mask=mask_m)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/utils/__init__.py b/vllm/compactor-vllm/src/compactor_vllm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/compactor-vllm/src/compactor_vllm/utils/arguments.py b/vllm/compactor-vllm/src/compactor_vllm/utils/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5f73b8bd324bdfadb85e778276bfa6da459c87
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/utils/arguments.py
@@ -0,0 +1,408 @@
+import itertools
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+from compactor_vllm.compression import CompressionMethod
+from compactor_vllm.compression.compression_config import BatchCompressionParams
+from compactor_vllm.config.engine_config import LLMConfig
+from compactor_vllm.utils.sequence import Sequence
+
+
+@dataclass
+class PrefillBatchArguments: # everything one prefill step needs, as built by PackedTensorArguments on every rank
+ B: int # number of sequences in the batch
+ N: int # total number of prompt tokens across the batch
+ do_compression: bool
+ compression_method: CompressionMethod
+ compression_chunk_size: int # -1 when chunked compression is disabled
+
+ seq_ids: torch.Tensor
+
+ input_ids: torch.Tensor
+ positions: torch.Tensor
+ cu_seqlens_q: torch.Tensor # cumulative prompt lengths, B + 1 entries starting at 0
+ cu_seqlens_k: torch.Tensor
+ max_seqlen_q: int
+ max_seqlen_k: int
+
+ batch_tokens_to_retain: Optional[torch.Tensor] # per-sequence retain budget
+ max_tokens_to_retain: Optional[int]
+ protected_first: Optional[List[int]]
+ protected_last: Optional[List[int]]
+
+ PHI: Optional[torch.Tensor] # random projection of shape (head_dim, sketch_dim)
+
+ # args needed for memory reservation
+ context_lens: torch.Tensor
+ max_new_tokens: torch.Tensor
+
+ # Matches the kvpress ``CompactorPress`` blending default (compression_ratio is used when blending is not set explicitly)
+ compression_ratio: float = 1.0
+
+
+class PackedTensorArguments:
+ def __init__(
+ self, rank: int, max_batched_tokens: int, config: LLMConfig, seed: int = 42
+ ) -> None:
+ hf_config = config.hf_config
+ self.rank = rank
+ self.device = torch.device(f"cuda:{rank}")
+ self.max_num_batches = config.max_num_seqs
+ self.max_batched_tokens = max_batched_tokens
+ self.num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+ self.world_size = config.tensor_parallel_size
+ self.page_size = int(config.kvcache_page_size)
+ self.head_dim = getattr(hf_config, "head_dim", None)
+ self.sketch_dim = config.leverage_sketch_size
+ self.model_dtype = hf_config.torch_dtype
+
+ # i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
+ self.i64_len_max = (
+ self.max_num_batches + 2 * self.max_batched_tokens + self.max_num_batches
+ )
+ self.packed_context_i64 = torch.empty(
+ self.i64_len_max, dtype=torch.int64, device=self.device
+ )
+
+ # i32 pack = [header (6): ... + compression_ratio*1e6] || [cu_q (BMAX+1)] || ...
+ # || [protected_first_tokens (BMAX)] || [protected_last_tokens (BMAX)]
+ self.i32_len_max = (
+ 6
+ + (self.max_num_batches + 1)
+ + (self.max_num_batches + 1)
+ + self.max_num_batches
+ + self.max_num_batches
+ + self.max_num_batches
+ + self.max_num_batches
+ )
+ self.packed_context_i32 = torch.empty(
+ self.i32_len_max, dtype=torch.int32, device=self.device
+ )
+
+ self.generator = torch.Generator(device=self.device).manual_seed(seed)
+ self.PHI = torch.randn(
+ (self.head_dim, self.sketch_dim),
+ device=self.packed_context_i32.device,
+ generator=self.generator,
+ ).to(self.model_dtype) * (1 / math.sqrt(self.sketch_dim))
+
+ def _master_build_prefill(
+ self, seqs: List[Sequence], batch_compression_params: BatchCompressionParams
+ ) -> PrefillBatchArguments:
+ B = len(seqs)
+ Ls = [x.prompt_len for x in seqs]
+
+ N = sum(Ls)
+ assert N <= self.max_batched_tokens
+ do_compression = any(x.compression_params.compression_ratio < 1.0 for x in seqs)
+ do_compression = (
+ do_compression
+ and batch_compression_params.compression_method != CompressionMethod.NONE
+ )
+ pack_slices_64 = self.packed_i64_slices(B, N)
+ pack_slices_32 = self.packed_i32_slices(B)
+
+ # max_retain = max(retain)
+ protected_first_list = [
+ x.compression_params.protected_first_tokens for x in seqs
+ ]
+ protected_last_list = [x.compression_params.protected_last_tokens for x in seqs]
+ retain = [
+ max(
+ int(
+ round(
+ x.compression_params.compression_ratio
+ * (L - s - e)
+ * self.num_kv_heads
+ )
+ ),
+ 1,
+ )
+ for s, e, L, x in zip(protected_first_list, protected_last_list, Ls, seqs)
+ ]
+ retain = torch.tensor(retain, dtype=torch.int32, device="cpu", pin_memory=True)
+ protected_first = torch.tensor(
+ protected_first_list, dtype=torch.int32, device="cpu", pin_memory=True
+ )
+ protected_last = torch.tensor(
+ protected_last_list, dtype=torch.int32, device="cpu", pin_memory=True
+ )
+ self.packed_context_i32[pack_slices_32["protected_first"]].copy_(
+ protected_first, non_blocking=True
+ )
+ self.packed_context_i32[pack_slices_32["protected_last"]].copy_(
+ protected_last, non_blocking=True
+ )
+ compression_chunk_size = (
+ batch_compression_params.chunk_size
+ if batch_compression_params.do_chunked_compression
+ else -1
+ )
+ min_compression_ratio = min(x.compression_params.compression_ratio for x in seqs)
+ cr_scaled = int(round(float(min_compression_ratio) * 1_000_000.0))
+ cr_scaled = max(min(cr_scaled, 2_000_000_000), -2_000_000_000)
+ header_host = torch.tensor(
+ [
+ B,
+ N,
+ 1 if do_compression else 0,
+ batch_compression_params.compression_method.value,
+ compression_chunk_size,
+ cr_scaled,
+ ],
+ dtype=torch.int32,
+ device="cpu",
+ pin_memory=True,
+ )
+
+ self.packed_context_i32[pack_slices_32["retain"]].copy_(
+ retain, non_blocking=True
+ )
+ self.packed_context_i32[pack_slices_32["header"]].copy_(
+ header_host, non_blocking=True
+ )
+ max_seq_qk = max(Ls)
+
+ cu = torch.tensor(
+ list(itertools.accumulate(Ls, initial=0)),
+ dtype=torch.int32,
+ device="cpu",
+ pin_memory=True,
+ )
+ self.packed_context_i32[pack_slices_32["cu_q"]].copy_(cu, non_blocking=True)
+ self.packed_context_i32[pack_slices_32["cu_k"]].copy_(cu, non_blocking=True)
+ self.packed_context_i32[pack_slices_32["context_lens"]].copy_(
+ cu.diff(), non_blocking=True
+ )
+
+ seq_ids = torch.tensor(
+ [x.seq_id for x in seqs], dtype=torch.int64, device="cpu", pin_memory=True
+ )
+ input_ids = torch.tensor(
+ [tid for x in seqs for tid in x.prompt_token_ids],
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=True,
+ )
+ self.packed_context_i64[pack_slices_64["seq_ids"]].copy_(
+ seq_ids, non_blocking=True
+ )
+ self.packed_context_i64[pack_slices_64["input_ids"]].copy_(
+ input_ids, non_blocking=True
+ )
+
+ positions = torch.cat(
+ [
+ torch.arange(L, dtype=torch.int64, device="cpu", pin_memory=True)
+ for L in Ls
+ ]
+ )
+ self.packed_context_i64[pack_slices_64["positions"]].copy_(
+ positions, non_blocking=True
+ )
+
+ max_new_tokens = torch.tensor(
+ [seq.sampling_params.max_new_tokens for seq in seqs],
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=True,
+ )
+ self.packed_context_i64[pack_slices_64["max_new_tokens"]].copy_(
+ max_new_tokens, non_blocking=True
+ )
+ # `prefill_store_topk_kv(..., PAD_TO_PAGE_SIZE=True)` may scan beyond the
+ # top-k prefix to fill per-head lengths up to a page boundary. Using a
+ # full ranking (top_k = max_seq_len * HKV) makes `torch.topk` degenerate
+ # into a full sort, which is very expensive for long contexts.
+ #
+ # Instead, request only a prefix that is large enough for:
+ # 1) the maximum "keep" budget in the batch, plus
+ # 2) a conservative extra window for page-padding candidates.
+ max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
+ full_budget = max_seq_len * self.num_kv_heads
+ keep_budget = int(retain.max().item())
+ pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
+ max_retain = min(full_budget, keep_budget + pad_search_budget)
+ dist.broadcast(self.packed_context_i64, src=0)
+ dist.broadcast(self.packed_context_i32, src=0)
+ prefill_args = PrefillBatchArguments(
+ B=B,
+ N=N,
+ do_compression=do_compression,
+ compression_method=batch_compression_params.compression_method,
+ compression_chunk_size=compression_chunk_size,
+ seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
+ input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
+ positions=self.packed_context_i64[pack_slices_64["positions"]],
+ cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
+ cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
+ max_seqlen_q=max_seq_qk,
+ max_seqlen_k=max_seq_qk,
+ batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
+ max_tokens_to_retain=max_retain,
+ PHI=self.PHI,
+ context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
+ max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
+ protected_first=protected_first_list,
+ protected_last=protected_last_list,
+ compression_ratio=min_compression_ratio,
+ )
+ return prefill_args
+
+ def _peer_receive_prefill(self) -> PrefillBatchArguments:
+ dist.broadcast(self.packed_context_i64, src=0)
+ dist.broadcast(self.packed_context_i32, src=0)
+ header = self.packed_context_i32[:6].tolist()
+ B, N = int(header[0]), int(header[1])
+ do_compression = bool(int(header[2]))
+ compression_method = CompressionMethod(int(header[3]))
+ compression_chunk_size = int(header[4])
+ compression_ratio = int(header[5]) / 1_000_000.0
+
+ pack_slices_64 = self.packed_i64_slices(B, N)
+ pack_slices_32 = self.packed_i32_slices(B)
+ max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
+ full_budget = max_seq_len * self.num_kv_heads
+ keep_budget = int(self.packed_context_i32[pack_slices_32["retain"]].max().item())
+ pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
+ max_retain = min(full_budget, keep_budget + pad_search_budget)
+ prefill_args = PrefillBatchArguments(
+ B=B,
+ N=N,
+ do_compression=do_compression,
+ compression_method=compression_method,
+ compression_chunk_size=compression_chunk_size,
+ seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
+ input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
+ positions=self.packed_context_i64[pack_slices_64["positions"]],
+ cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
+ cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
+ max_seqlen_q=int(self.packed_context_i32[pack_slices_32["cu_q"]].max()),
+ max_seqlen_k=int(self.packed_context_i32[pack_slices_32["cu_k"]].max()),
+ batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
+ max_tokens_to_retain=max_retain,
+ PHI=self.PHI,
+ context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
+ max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
+ protected_first=self.packed_context_i32[
+ pack_slices_32["protected_first"]
+ ].tolist(),
+ protected_last=self.packed_context_i32[
+ pack_slices_32["protected_last"]
+ ].tolist(),
+ compression_ratio=compression_ratio,
+ )
+ return prefill_args
+
+ @torch.inference_mode()
+ def build_prefill_args(
+ self,
+ seqs: Optional[List[Sequence]] = None,
+ batch_compression_params: Optional[BatchCompressionParams] = None,
+ ) -> PrefillBatchArguments:
+ if self.rank == 0:
+ return self._master_build_prefill(seqs, batch_compression_params)
+ return self._peer_receive_prefill()
+
+ def broadcast(self):
+ if self.world_size > 1:
+ return dist.broadcast(self.packed_context_i64, src=0)
+ return None
+
+ @staticmethod
+ def packed_i64_slices(B: int, N: int):
+ return {
+ "seq_ids": slice(0, B),
+ "input_ids": slice(B, B + N),
+ "positions": slice(B + N, B + 2 * N),
+ "max_new_tokens": slice(B + 2 * N, 2 * B + 2 * N),
+ }
+
+ @staticmethod
+ def packed_i32_slices(B: int):
+ h0, h1 = 0, 6
+ q0 = h1
+ q1 = q0 + (B + 1)
+ k0 = q1
+ k1 = k0 + (B + 1)
+ r0 = k1
+ r1 = r0 + B
+ c0 = r1
+ c1 = r1 + B
+
+ pf0 = c1
+ pf1 = c1 + B
+ pl0 = pf1
+ pl1 = pf1 + B
+ return {
+ "header": slice(h0, h1),
+ "cu_q": slice(q0, q1),
+ "cu_k": slice(k0, k1),
+ "retain": slice(r0, r1),
+ "context_lens": slice(c0, c1),
+ "protected_first": slice(pf0, pf1),
+ "protected_last": slice(pl0, pl1),
+ }
+
+
+@dataclass
+class DecodeBatchOutput: # pair of optional tensors produced by a decode step
+ output_tokens: Optional[torch.Tensor]
+ output_seq_ids: Optional[torch.Tensor] # presumably the sequence id for each entry of output_tokens -- TODO confirm
+
+
+@dataclass
+class DecodeBatchArguments: # accumulator of decode inputs; update() appends another group of rows along dim 0
+ batch_mapping: Optional[torch.Tensor] = None
+ token_ids: Optional[torch.Tensor] = None
+ positions: Optional[torch.Tensor] = None
+ max_ctx_lens: Optional[torch.Tensor] = None
+ seq_ids: Optional[torch.Tensor] = None
+ temps: Optional[torch.Tensor] = None
+ desired_batch_occupancy: int = -1
+ num_stashed_batches: int = 0
+
+ def update( # append the given tensors to the stored ones (cloning on first use) and return self
+ self,
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ seq_ids,
+ temps=None,
+ desired_batch_occupancy: int = None,
+ ):
+ if self.batch_mapping is not None:
+ self.batch_mapping = torch.cat([self.batch_mapping, batch_mapping], dim=0)
+ else:
+ self.batch_mapping = batch_mapping.clone() # clone so later changes to the caller's tensor do not alias
+ if self.token_ids is not None:
+ self.token_ids = torch.cat([self.token_ids, token_ids], dim=0)
+ else:
+ self.token_ids = token_ids.clone()
+ if self.positions is not None:
+ self.positions = torch.cat([self.positions, positions], dim=0)
+ else:
+ self.positions = positions.clone()
+ if self.max_ctx_lens is not None:
+ self.max_ctx_lens = torch.cat([self.max_ctx_lens, max_ctx_lens], dim=0)
+ else:
+ self.max_ctx_lens = max_ctx_lens.clone()
+ if self.seq_ids is not None:
+ self.seq_ids = torch.cat([self.seq_ids, seq_ids], dim=0)
+ else:
+ self.seq_ids = seq_ids.clone()
+
+ if self.temps is not None and temps is not None: # temps is optional: a None argument leaves the stored temps untouched
+ self.temps = torch.cat([self.temps, temps], dim=0)
+ elif temps is not None:
+ self.temps = temps.clone()
+
+ if desired_batch_occupancy is not None: # only overwritten when explicitly provided
+ self.desired_batch_occupancy = desired_batch_occupancy
+
+ return self
+
diff --git a/vllm/compactor-vllm/src/compactor_vllm/utils/context.py b/vllm/compactor-vllm/src/compactor_vllm/utils/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..f05a8144fb7f370ddc961cafb1531ea62c2c4508
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/utils/context.py
@@ -0,0 +1,97 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+from compactor_vllm.compression import CompressionMethod
+from compactor_vllm.config.engine_config import AttentionBackend
+
+
+@dataclass
+class CompressionContext: # compression settings and per-batch tensors carried inside Context
+ compression_method: CompressionMethod = CompressionMethod.COMPACTOR
+
+ compression_chunk_size: int = -1
+ batch_tokens_to_retain: torch.Tensor | None = None
+ max_tokens_to_retain: int = 0
+ context_lens: List[int] | None = None
+ PHI: torch.Tensor | None = None
+
+ # Compactor (optional hyperparameters aligned with kvpress ``CompactorPress``)
+ sketch_dimension: int = 48
+ sink_size_start: int = 8
+ sink_size_end: int = 4
+ compactor_blending: Optional[float] = None
+ # Same as kvpress: used when ``compactor_blending`` is unset (taken from the request's compression_ratio)
+ compression_ratio: Optional[float] = None
+
+ protected_first_tokens: List[int] | None = None
+ protected_last_tokens: List[int] | None = None
+
+ # CriticalAdaKV
+ wo_weight: Optional[torch.Tensor] = None
+ critical_ada_epsilon: float = 1e-4
+ critical_ada_first_stage_ratio: float = 0.5
+ critical_ada_alpha_safeguard: float = 0.2
+
+
+@dataclass
+class Context: # module-wide state installed by set_context(); field order matters because set_context passes values positionally
+ is_prefill: bool = False
+ do_compression: bool = False
+
+ cu_seqlens_q: torch.Tensor | None = None
+ cu_seqlens_k: torch.Tensor | None = None
+ max_seqlen_q: int = 0
+ max_seqlen_k: int = 0
+ batch_mapping: torch.Tensor | None = None
+ max_bh_len: int = 0
+
+ compression_context: CompressionContext | None = None
+ STORE_STREAM: torch.cuda.Stream | None = None
+
+ key_split: int | None = None
+ attention_backend: AttentionBackend = AttentionBackend.COMPACTOR_TRITON
+
+
+_CONTEXT = Context() # module-level singleton, rebound by set_context() and reset_context()
+
+
+def get_context(): # returns the Context object currently installed in this module
+ return _CONTEXT
+
+
+def set_context( # replaces the module-level Context with a new one built from these keyword-only arguments
+ *,
+ is_prefill,
+ do_compression=False,
+ cu_seqlens_q=None,
+ cu_seqlens_k=None,
+ max_seqlen_q=0,
+ max_seqlen_k=0,
+ batch_mapping=None,
+ max_bh_len=0,
+ compression_context: CompressionContext = None,
+ STORE_STREAM=None,
+ key_split=None,
+ attention_backend=AttentionBackend.COMPACTOR_TRITON,
+):
+ global _CONTEXT
+ _CONTEXT = Context( # values are passed positionally, so this order must track the field order of Context
+ is_prefill,
+ do_compression,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ batch_mapping,
+ max_bh_len,
+ compression_context,
+ STORE_STREAM,
+ key_split,
+ attention_backend,
+ )
+
+
+def reset_context(): # restores the module-level Context to its all-defaults state
+ global _CONTEXT
+ _CONTEXT = Context()
diff --git a/vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py b/vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e833b885ec2cc2372b1a267a7b361b535fd9d938
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
@@ -0,0 +1,35 @@
+from collections.abc import Callable
+
+import torch
+
+
+def maybe_execute_in_stream( # calls fn(*args, **kwargs) on STORE_STREAM when one is given, otherwise calls it directly; returns fn's result
+ fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
+):
+ if STORE_STREAM is not None:
+ tensors = [arg for arg in args if isinstance(arg, torch.Tensor)] # tensor inputs whose memory the side stream will touch
+ tensors += [val for val in kwargs.values() if isinstance(val, torch.Tensor)]
+ obj = getattr(fn, "__self__", None) # a bound tensor method (e.g. t.copy_) also uses its receiver
+ if isinstance(obj, torch.Tensor):
+ tensors.append(obj)
+ STORE_STREAM.wait_stream(torch.cuda.default_stream()) # the side stream waits for work already queued on the default stream
+ # Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
+ # The portable API is `torch.cuda.stream(stream)`.
+ stream_ctx = (
+ STORE_STREAM
+ if hasattr(STORE_STREAM, "__enter__")
+ else torch.cuda.stream(STORE_STREAM)
+ )
+ with stream_ctx:
+ output = fn(*args, **kwargs)
+ for t in tensors: # NOTE(review): indentation is lost in this view; whether these record_stream calls sit inside the `with` is not visible
+ t.record_stream(STORE_STREAM)
+ if isinstance(output, tuple): # only tuple and bare-tensor outputs are recorded; other containers are returned as-is
+ for o in output:
+ if isinstance(o, torch.Tensor):
+ o.record_stream(torch.cuda.default_stream())
+ elif isinstance(output, torch.Tensor):
+ output.record_stream(torch.cuda.default_stream())
+ return output
+ else:
+ return fn(*args, **kwargs)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py b/vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e5aa76bb5f2ca04a7dc3f5cba111448c854d10
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from itertools import count
+from typing import List
+
+from compactor_vllm.compression.compression_config import SequenceCompressionParams
+from compactor_vllm.config.sampling_params import SamplingParams
+
+
+class SequenceStatus(Enum): # lifecycle states of a Sequence; new sequences start as WAITING
+ WAITING = auto()
+ RUNNING = auto()
+ FINISHED = auto()
+
+
+@dataclass
+class Sequence:
+ """
+ Represents a single user request / sequence being generated.
+ """
+
+ _counter = count() # class-level id source (unannotated, so not a dataclass field)
+
+ prompt_token_ids: List[int]
+ completion_token_ids: List[int] = field(default_factory=list)
+ sampling_params: SamplingParams = field(default_factory=SamplingParams)
+ compression_params: SequenceCompressionParams = field(
+ default_factory=SequenceCompressionParams
+ )
+ status: SequenceStatus = SequenceStatus.WAITING
+
+ seq_id: int = field(default_factory=lambda: next(Sequence._counter), init=False) # assigned automatically; cannot be passed to __init__
+ num_tokens_processed: int = 0
+
+ @property
+ def num_prompt_tokens(self) -> int: # same value as prompt_len below
+ return len(self.prompt_token_ids)
+
+ @property
+ def num_generated_tokens(self) -> int: # same value as completion_len below
+ return len(self.completion_token_ids)
+
+ def add_new_token(self, token_id: int) -> None: # appends one generated token and advances num_tokens_processed
+ if len(self.completion_token_ids) == 0: # first generated token: the whole prompt counts as processed
+ self.num_tokens_processed += self.num_prompt_tokens
+ self.completion_token_ids.append(token_id)
+ self.num_tokens_processed += 1
+
+ def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int: # ratio * prompt tokens * heads, truncated, at least 1
+ n = int( # NOTE(review): truncates and ignores protected tokens, unlike the rounded budget in PackedTensorArguments -- confirm intended
+ self.compression_params.compression_ratio
+ * self.num_prompt_tokens
+ * num_kv_heads
+ )
+ return max(1, n)
+
+ def __getstate__(self): # pickling support: plain dict with copies of both token lists
+ return dict(
+ prompt_token_ids=list(self.prompt_token_ids),
+ completion_token_ids=list(self.completion_token_ids),
+ sampling_params=self.sampling_params,
+ compression_params=self.compression_params,
+ status=self.status,
+ seq_id=self.seq_id,
+ num_tokens_processed=self.num_tokens_processed,
+ )
+
+ def __setstate__(self, state): # restores every field, including the original seq_id (no new id is drawn)
+ self.prompt_token_ids = list(state["prompt_token_ids"])
+ self.completion_token_ids = list(state["completion_token_ids"])
+ self.sampling_params = state["sampling_params"]
+ self.compression_params = state["compression_params"]
+ self.status = state["status"]
+ self.seq_id = state["seq_id"]
+ self.num_tokens_processed = state["num_tokens_processed"]
+
+ @property
+ def prompt_len(self) -> int: # number of prompt tokens
+ return len(self.prompt_token_ids)
+
+ @property
+ def completion_len(self) -> int: # number of generated tokens so far
+ return len(self.completion_token_ids)
diff --git a/vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py b/vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..65a459c0bddeaf38d594177abc2e0bfb07533b8e
--- /dev/null
+++ b/vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import inspect
+from typing import Any, Callable, Mapping
+
+import torch
+
+
+def _filter_kwargs_for_callable( # keeps only the kwargs whose names appear in fn's signature
+ fn: Callable[..., Any], kwargs: Mapping[str, Any]
+) -> dict[str, Any]:
+ try:
+ params = inspect.signature(fn).parameters
+ except (TypeError, ValueError): # signature not introspectable: pass everything through unfiltered
+ return dict(kwargs)
+ return {k: v for k, v in kwargs.items() if k in params} # a **kwargs parameter in fn does not make other names match
+
+
+def autotune(*, configs, key, **kwargs): # configs and key are always forwarded; only the extra kwargs are filtered
+ """
+ Compatibility wrapper around `triton.autotune`.
+
+ Some Triton builds (e.g., custom vendor builds) may not support newer
+ keyword arguments like `cache_results`. This wrapper filters unsupported
+ kwargs based on the runtime `triton.autotune` signature.
+ """
+ import triton
+
+ filtered = _filter_kwargs_for_callable(triton.autotune, kwargs)
+ return triton.autotune(configs=configs, key=key, **filtered)
+
+
+def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
+ """
+ Call `triton.set_allocator(alloc_fn)` if present; otherwise no-op.
+
+ Returns True if the allocator was set.
+ """
+ import triton
+
+ setter = getattr(triton, "set_allocator", None) # absent on Triton builds without the allocator hook
+ if setter is None:
+ return False
+ setter(alloc_fn)
+ return True
+
+
+def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
+ """
+ Host-side CUDA capability check that works even when `tl.target_info` is absent.
+ """
+ if not torch.cuda.is_available(): # no CUDA at all: report "not supported" rather than raising
+ return False
+ if device is None:
+ try:
+ device = torch.cuda.current_device()
+ except Exception: # best effort: fall back to device 0 if the current device cannot be queried
+ device = 0
+ cap = torch.cuda.get_device_capability(device)
+ return cap >= (major, minor) # tuple comparison: major first, then minor
+
diff --git a/vllm/compactor-vllm/tests/test_store_kv.py b/vllm/compactor-vllm/tests/test_store_kv.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bbcbb37c4337990eea6b17097bff68218725a88
--- /dev/null
+++ b/vllm/compactor-vllm/tests/test_store_kv.py
@@ -0,0 +1,239 @@
+import collections
+import logging
+from dataclasses import dataclass
+from typing import List
+
+import pytest
+import torch
+import triton
+
+from compactor_vllm.compression.common import scores_to_retain_indices
+from src.compactor_vllm.kv_cache.store_kv_cache import prefill_store_topk_kv
+
+logger = logging.getLogger(__name__) # module-level logger
+
+
+@dataclass
+class Workload: # one parametrized configuration for the KV-store tests
+ name: str # used as the pytest id
+ batch_size: int
+ nk_heads: int
+ head_dim: int
+ frac: float # fraction of the (token, head) entries kept by top-k
+ page_size: int
+ cache_lens: List[int] # per-sequence cached context length
+
+
+WORKLOADS: List[Workload] = [ # cross product of every parameter list below; all sequences in a batch share one length
+ Workload(
+ name=f"batch_size={BATCH} kv_cache_len={cache_lens} "
+ f"FRAC={frac} HKV={NK_HEADS} HEAD_DIM={HEAD_DIM}",
+ batch_size=BATCH,
+ nk_heads=NK_HEADS,
+ head_dim=HEAD_DIM,
+ cache_lens=[cache_lens] * BATCH,
+ frac=frac,
+ page_size=ps,
+ )
+ for BATCH in [1, 2, 3, 8]
+ for frac in [0.10, 0.20, 0.30, 0.40]
+ for NK_HEADS in [2, 4, 8]
+ for HEAD_DIM in [32, 64, 128]
+ for cache_lens in [10, 20, 30, 70, 1000]
+ for ps in [128, 256] # NOTE(review): page_size is not part of `name`, so the two page sizes produce identical test ids
+]
+
+
+@pytest.mark.parametrize("workload", WORKLOADS, ids=lambda wl: wl.name)
+def test_prefill_store_topk_kv(workload: Workload): # stores top-k K/V rows into a paged cache and checks per-head lengths and contents
+ B = workload.batch_size
+ H = workload.nk_heads
+ D = workload.head_dim
+ TOP_K = int(workload.cache_lens[0] * workload.nk_heads * workload.frac) # budget counted over (token, head) pairs
+ PAGE_SIZE = workload.page_size
+
+ dtype = torch.float16
+ device = triton.runtime.driver.active.get_active_torch_device()
+
+ lens = torch.tensor(workload.cache_lens, dtype=torch.int32, device=device)
+ cu = torch.zeros(B + 1, dtype=torch.int32, device=device)
+ cu[1:] = torch.cumsum(lens, dim=0) # cumulative sequence lengths with a leading 0
+ N_total = int(cu[-1].item())
+
+ keys = torch.randn((N_total, H, D), dtype=dtype, device=device)
+ vals = torch.randn_like(keys)
+ scores_flat = torch.randn((N_total, H), dtype=torch.float32, device=device)
+
+ top_k_eff = max(0, min(TOP_K, int(lens.max().item()) * H))
+ max_k_len = cu.diff().max().item()
+ indices = scores_to_retain_indices(
+ scores_flat, cu, max_k_len, top_k_eff, H
+ ) # [B, TOP_K]
+
+ LP = max(1, (top_k_eff + PAGE_SIZE - 1) // PAGE_SIZE) # logical pages per (batch, head): enough for top_k_eff entries
+ N_LOGICAL_PAGES_MAX = LP
+ N_PAGES = B * H * LP + 32
+ S_LARGE = N_PAGES * PAGE_SIZE
+ k_cache = torch.empty((S_LARGE, D), dtype=dtype, device=device)
+ v_cache = torch.empty_like(k_cache)
+
+ page_table = torch.empty(
+ (B, H, N_LOGICAL_PAGES_MAX), dtype=torch.int32, device=device
+ )
+ phys = 0
+ for b in range(B): # give every (batch, head, logical page) its own physical page, in order
+ for h in range(H):
+ for lp in range(LP):
+ page_table[b, h, lp] = phys
+ phys += 1
+ assert phys <= N_PAGES, "Not enough physical pages"
+
+ local_lens = torch.zeros((B, H), dtype=torch.int32, device=device)
+ batch_mapping = torch.arange(B, dtype=torch.int32, device=device)
+ num_to_retain = torch.full((B,), top_k_eff, dtype=torch.int32, device=device)
+
+ prefill_store_topk_kv(
+ new_keys=keys,
+ new_vals=vals,
+ indices_topk=indices,
+ num_tokens_to_retain=num_to_retain,
+ page_table=page_table,
+ batch_mapping=batch_mapping,
+ bh_lens=local_lens,
+ PAGE_SIZE=PAGE_SIZE,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ PAD_TO_PAGE_SIZE=False,
+ TRITON_RESERVED_BATCH=-1,
+ )
+ torch.cuda.synchronize()
+
+ local_lens_cpu = local_lens.cpu()
+ page_table_cpu = page_table.cpu()
+ k_cache_cpu = k_cache.cpu()
+ v_cache_cpu = v_cache.cpu()
+ keys_cpu = keys.cpu()
+ vals_cpu = vals.cpu()
+ indices_cpu = indices.cpu()
+
+ for b in range(B): # check 1: each head's stored length equals how many selected indices fall on that head
+ hed = (indices_cpu[b] % H).numpy() # the test decodes an index as token * H + head
+ counts = collections.Counter(hed.tolist())
+ for h in range(H):
+ expected = counts.get(h, 0) # type: ignore
+ got = int(local_lens_cpu[b, h].item())
+ assert got == expected, (
+ f"Length mismatch at (b={b}, h={h}): got {got}, expected {expected}"
+ )
+
+ def rows_for_head(b, h, L):
+ """Return the list of cache row indices storing the first L logical positions for (b,h)."""
+ rows = []
+ for pos in range(L):
+ lp = pos // PAGE_SIZE
+ off = pos % PAGE_SIZE
+ phys = int(page_table_cpu[b, h, lp].item())
+ rows.append(phys * PAGE_SIZE + off)
+ return rows
+
+ for b in range(B): # check 2: cache rows hold exactly the selected K/V vectors, in any order
+ # which tokens per head were selected for this batch?
+ tok = (indices_cpu[b] // H).numpy()
+ hed = (indices_cpu[b] % H).numpy()
+ per_head = collections.defaultdict(list)
+ for t, h in zip(tok, hed):
+ per_head[int(h)].append(int(t))
+
+ for h in range(H):
+ L = int(local_lens_cpu[b, h].item())
+ if L == 0:
+ continue
+
+ # expected vectors (unordered) from source
+ toks_h = per_head.get(h, [])
+ assert len(toks_h) == L
+ expK = keys_cpu[toks_h, h, :].contiguous().view(L, -1)
+ expV = vals_cpu[toks_h, h, :].contiguous().view(L, -1)
+
+ # actual vectors read back from cache rows
+ rows = rows_for_head(b, h, L)
+ actK = k_cache_cpu[rows, :].contiguous().view(L, -1)
+ actV = v_cache_cpu[rows, :].contiguous().view(L, -1)
+
+ expK_tuples = [tuple(row) for row in expK.numpy().tolist()] # rows become hashable tuples for multiset comparison
+ actK_tuples = [tuple(row) for row in actK.numpy().tolist()]
+ expV_tuples = [tuple(row) for row in expV.numpy().tolist()]
+ actV_tuples = [tuple(row) for row in actV.numpy().tolist()]
+
+ assert collections.Counter(expK_tuples) == collections.Counter(
+ actK_tuples
+ ), f"K content mismatch at (b={b}, h={h})"
+ assert collections.Counter(expV_tuples) == collections.Counter(
+ actV_tuples
+ ), f"V content mismatch at (b={b}, h={h})"
+
+
+def test_prefill_store_topk_kv_pad_to_page_size(): # with PAD_TO_PAGE_SIZE=True every per-head length must be a multiple of PAGE_SIZE
+ torch.manual_seed(0)
+ B, H, D = 2, 2, 64
+ PAGE_SIZE = 128
+ RETAIN = 64 # deliberately smaller than one page, so padding has to add entries
+
+ dtype = torch.float16
+ device = triton.runtime.driver.active.get_active_torch_device()
+
+ lens = torch.full((B,), 256, dtype=torch.int32, device=device)
+ cu = torch.zeros(B + 1, dtype=torch.int32, device=device)
+ cu[1:] = torch.cumsum(lens, dim=0)
+ N_total = int(cu[-1].item())
+
+ keys = torch.randn((N_total, H, D), dtype=dtype, device=device)
+ vals = torch.randn_like(keys)
+ scores_flat = torch.randn((N_total, H), dtype=torch.float32, device=device)
+
+ max_k_len = int(lens.max().item())
+ max_sel = max_k_len * H # request a full ranking so padding candidates are available
+ indices = scores_to_retain_indices(scores_flat, cu, max_k_len, max_sel, H)
+
+ N_LOGICAL_PAGES_MAX = 2 # 2 pages * 128 = 256 slots, the full sequence length per head
+ N_PAGES = B * H * N_LOGICAL_PAGES_MAX + 32
+ S_LARGE = N_PAGES * PAGE_SIZE
+ k_cache = torch.empty((S_LARGE, D), dtype=dtype, device=device)
+ v_cache = torch.empty_like(k_cache)
+
+ page_table = torch.empty(
+ (B, H, N_LOGICAL_PAGES_MAX), dtype=torch.int32, device=device
+ )
+ phys = 0
+ for b in range(B): # give every (batch, head, logical page) its own physical page, in order
+ for h in range(H):
+ for lp in range(N_LOGICAL_PAGES_MAX):
+ page_table[b, h, lp] = phys
+ phys += 1
+ assert phys <= N_PAGES, "Not enough physical pages"
+
+ local_lens = torch.zeros((B, H), dtype=torch.int32, device=device)
+ batch_mapping = torch.arange(B, dtype=torch.int32, device=device)
+ num_to_retain = torch.full((B,), RETAIN, dtype=torch.int32, device=device)
+
+ prefill_store_topk_kv(
+ new_keys=keys,
+ new_vals=vals,
+ indices_topk=indices,
+ num_tokens_to_retain=num_to_retain,
+ page_table=page_table,
+ batch_mapping=batch_mapping,
+ bh_lens=local_lens,
+ PAGE_SIZE=PAGE_SIZE,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ PAD_TO_PAGE_SIZE=True,
+ cu_seqlens_k=cu,
+ TRITON_RESERVED_BATCH=-1,
+ )
+ torch.cuda.synchronize()
+
+ local_lens_cpu = local_lens.cpu()
+ lens_cpu = lens.cpu()
+ assert (local_lens_cpu % PAGE_SIZE == 0).all() # every head ends on a page boundary
+ assert (local_lens_cpu <= lens_cpu[:, None]).all() # and never holds more entries than the sequence has tokens
diff --git a/vllm/compactor-vllm/tests/test_triton_attention.py b/vllm/compactor-vllm/tests/test_triton_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eeeb444216df5178687a9cafa0f2b1218b12b38
--- /dev/null
+++ b/vllm/compactor-vllm/tests/test_triton_attention.py
@@ -0,0 +1,407 @@
+import logging
+import math
+from dataclasses import dataclass
+from typing import List
+
+import pytest
+import torch
+import triton
+from flash_attn.flash_attn_interface import (
+ flash_attn_varlen_func,
+ flash_attn_with_kvcache,
+)
+
+from compactor_vllm.attention.sparse_decode_kernel import head_sparse_decode_attention
+from compactor_vllm.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+
+logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class Workload:
+     """One parametrized attention test case (shapes plus per-sequence lengths)."""
+
+     # Identifier string; passed to pytest as the test id.
+     name: str
+     # Number of sequences in the batch.
+     batch_size: int
+     # Number of query heads.
+     nq_heads: int
+     # Number of key/value heads.
+     nk_heads: int
+     # Size of each attention head.
+     head_dim: int
+     # per-sequence cached context length
+     cache_lens: List[int]
+     # per-sequence new tokens this step (Q_app, K_app, V_app)
+     append_lens: List[int]
+
+
+ # Prefill/append sweep: every combination of batch size, cached length and
+ # appended length, with the same lengths repeated for each sequence in the batch.
+ WORKLOADS: List[Workload] = [
+     Workload(
+         name=(
+             f"batch_size={bsz} kv_cache_len={cached} append_len={appended} "
+             f"HQ={hq} HKV={hkv} HEAD_DIM={hdim}"
+         ),
+         batch_size=bsz,
+         nq_heads=hq,
+         nk_heads=hkv,
+         head_dim=hdim,
+         cache_lens=[cached] * bsz,
+         append_lens=[appended] * bsz,
+     )
+     for bsz in (1, 2, 3, 8)
+     for hq in (32,)
+     for hkv in (8,)
+     for hdim in (128,)
+     for cached in (0, 1, 70, 128, 8193)
+     for appended in (1, 2, 13, 8000)
+ ]
+
+ # Decode sweep: one new token per sequence against caches of varying length.
+ WORKLOADS_DECODE: List[Workload] = [
+     Workload(
+         # Trailing space keeps the two adjacent f-strings from fusing into
+         # "kv_cache_len=70HQ=32" in the generated test id.
+         name=f"batch_size={BATCH} kv_cache_len={cache_lens} "
+         f"HQ={NQ_HEADS} HKV={NK_HEADS} HEAD_DIM={HEAD_DIM}",
+         batch_size=BATCH,
+         nq_heads=NQ_HEADS,
+         nk_heads=NK_HEADS,
+         head_dim=HEAD_DIM,
+         cache_lens=[cache_lens] * BATCH,
+         append_lens=[1] * BATCH,
+     )
+     for BATCH in [1, 2, 3, 8]
+     for NQ_HEADS in [32]
+     for NK_HEADS in [8]
+     for HEAD_DIM in [128]
+     for cache_lens in [1, 2, 70, 128, 8000]
+ ]
+
+
+ def build_paged_cache_from_lengths(
+     B,
+     H_kv,
+     D,
+     PAGE_SIZE,
+     N_LOGICAL_PAGES_MAX,
+     L_cache_per_b,  # int32 [B], per-batch cache length
+     device,
+     dtype,
+ ):
+     """
+     Build a synthetic paged KV cache for tests.
+
+     Construct:
+       - seq_lens_bh[b, h] = L_cache_per_b[b]
+       - page_table[b, h, lp] giving physical page ids
+       - K_cache, V_cache filled with seeded random rows for valid cached
+         tokens; every other row stays zero
+
+     Physical layout:
+       physical_page_id = (b * H_kv + h) * N_LOGICAL_PAGES_MAX + lp
+       CACHE_SIZE = num_phys_pages * PAGE_SIZE
+
+     Returns:
+       (K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)
+     """
+     assert L_cache_per_b.shape[0] == B
+     max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
+     assert (L_cache_per_b <= max_len).all()
+
+     # Pull the lengths to the host once instead of one .item() per token.
+     cache_lens = [int(n) for n in L_cache_per_b.tolist()]
+
+     seq_lens_bh = (
+         torch.tensor(cache_lens, dtype=torch.int32, device=device)
+         .view(B, 1)
+         .expand(B, H_kv)
+         .contiguous()
+     )
+
+     num_phys_pages = B * H_kv * N_LOGICAL_PAGES_MAX
+     CACHE_SIZE = num_phys_pages * PAGE_SIZE
+
+     K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
+     V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
+
+     # Unique physical page per (b, h, lp), numbered in row-major order.
+     page_table = torch.arange(
+         num_phys_pages, device=device, dtype=torch.int32
+     ).view(B, H_kv, N_LOGICAL_PAGES_MAX)
+
+     # Fill cached tokens, one vectorized scatter per (batch, head).
+     g = torch.Generator(device=device).manual_seed(1234)
+     for b, Lc in enumerate(cache_lens):
+         if Lc == 0:
+             continue
+         pos = torch.arange(Lc, device=device)
+         logical_page = pos // PAGE_SIZE
+         offset = pos % PAGE_SIZE
+         for h in range(H_kv):
+             # int64 before multiplying so large caches cannot overflow int32.
+             rows = page_table[b, h][logical_page].to(torch.int64) * PAGE_SIZE + offset
+             K_cache[rows] = torch.randn(
+                 (Lc, D), device=device, dtype=dtype, generator=g
+             )
+             V_cache[rows] = torch.randn(
+                 (Lc, D), device=device, dtype=dtype, generator=g
+             )
+
+     return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
+
+
+ def materialize_kv_for_flash_mixed(
+     K_cache,
+     V_cache,
+     page_table,
+     L_cache_per_b,  # [B]
+     k_append_raw,  # [N, H_kv, D]
+     v_append_raw,  # [N, H_kv, D]
+     cu_seqlens_qk,  # [B+1]
+     H_kv,
+     PAGE_SIZE,
+ ):
+     """
+     Build (K_total, V_total, cu_seqlens_k) for flash_attn_varlen_func such that:
+
+     For each batch b:
+       seqlen_q[b] = L_app[b] = cu[b+1] - cu[b]
+       seqlen_k[b] = L_cache_per_b[b] + L_app[b]
+       Keys:
+         - first L_cache_per_b[b] positions from paged cache
+         - next L_app[b] positions from k_append_raw for that batch
+
+     Returns:
+       K_total, V_total: [sum(seqlen_k), H_kv, D], dtype/device of K_cache
+       cu_seqlens_k: int32 [B+1] cumulative key lengths
+     """
+     device = K_cache.device
+     dtype = K_cache.dtype
+     B = cu_seqlens_qk.numel() - 1
+     N, H_kv_raw, D = k_append_raw.shape
+     assert H_kv_raw == H_kv
+
+     # One host transfer per length tensor instead of one .item() per token.
+     q_bounds = [int(x) for x in cu_seqlens_qk.tolist()]
+     cache_lens = [int(x) for x in L_cache_per_b.tolist()]
+     assert len(cache_lens) == B
+     append_lens = [q_bounds[b + 1] - q_bounds[b] for b in range(B)]
+
+     k_bounds = [0]
+     for Lc, La in zip(cache_lens, append_lens):
+         k_bounds.append(k_bounds[-1] + Lc + La)
+     cu_seqlens_k = torch.tensor(k_bounds, device=device, dtype=torch.int32)
+
+     total_k = k_bounds[-1]
+     K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
+     V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
+
+     for b in range(B):
+         Lc = cache_lens[b]
+         La = append_lens[b]
+         start = k_bounds[b]
+
+         # cache segment: gather all heads of this sequence in one indexing op
+         if Lc > 0:
+             pos = torch.arange(Lc, device=page_table.device)
+             # rows[g, i] = flat cache row of token i for KV head g
+             rows = (
+                 page_table[b, :H_kv][:, pos // PAGE_SIZE].to(torch.int64) * PAGE_SIZE
+                 + pos % PAGE_SIZE
+             )
+             K_total[start : start + Lc] = K_cache[rows].transpose(0, 1)
+             V_total[start : start + Lc] = V_cache[rows].transpose(0, 1)
+
+         # appended segment: contiguous slice of the packed new tokens
+         if La > 0:
+             src = q_bounds[b]
+             dst = start + Lc
+             K_total[dst : dst + La] = k_append_raw[src : src + La]
+             V_total[dst : dst + La] = v_append_raw[src : src + La]
+
+     return K_total, V_total, cu_seqlens_k
+
+
+ @pytest.mark.parametrize("workload", WORKLOADS, ids=lambda wl: wl.name)
+ def test_causal_sparse_varlen_with_cache(workload: Workload):
+     """Compare the Triton paged varlen kernel against flash_attn_varlen_func.
+
+     The Triton kernel reads cached tokens from the paged cache plus the newly
+     appended K/V; the reference gets the same keys/values materialized into
+     one packed tensor per sequence.
+     """
+     dtype = torch.float16
+     device = triton.runtime.driver.active.get_active_torch_device()
+     DEFAULT_PAGE_SIZE = 256
+     N_LOGICAL_PAGES_MAX = 256
+     L_cache_per_b = torch.as_tensor(
+         workload.cache_lens, device=device, dtype=torch.int32
+     )
+     K_cache, V_cache, page_table, seq_lens_bh, _ = build_paged_cache_from_lengths(
+         B=workload.batch_size,
+         H_kv=workload.nk_heads,
+         D=workload.head_dim,
+         PAGE_SIZE=DEFAULT_PAGE_SIZE,
+         N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+         L_cache_per_b=L_cache_per_b,
+         device=device,
+         dtype=dtype,
+     )
+
+     # Packed (varlen) layout: cu_seqlens_qk[b] is the first new-token row of sequence b.
+     assert len(workload.append_lens) == workload.batch_size
+     cu = [0]
+     for L in workload.append_lens:
+         cu.append(cu[-1] + L)
+     cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
+     N = int(cu_seqlens_qk[-1].item())
+
+     q_raw = torch.randn(
+         N, workload.nq_heads, workload.head_dim, device=device, dtype=dtype
+     )
+     k_append_raw = torch.randn(
+         N, workload.nk_heads, workload.head_dim, device=device, dtype=dtype
+     )
+     v_append_raw = torch.randn_like(k_append_raw)
+
+     batch_mapping = torch.arange(workload.batch_size, device=device, dtype=torch.int32)
+
+     sm_scale = 1.0 / math.sqrt(workload.head_dim)
+     K_total, V_total, cu_seqlens_k = materialize_kv_for_flash_mixed(
+         K_cache=K_cache,
+         V_cache=V_cache,
+         page_table=page_table,
+         L_cache_per_b=L_cache_per_b,
+         k_append_raw=k_append_raw,
+         v_append_raw=v_append_raw,
+         cu_seqlens_qk=cu_seqlens_qk,
+         H_kv=workload.nk_heads,
+         PAGE_SIZE=DEFAULT_PAGE_SIZE,
+     )
+
+     max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
+     max_seqlen_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item())
+     max_seqlen_k_triton = seq_lens_bh.max().item()
+     out_triton = causal_sparse_varlen_with_cache(
+         q=q_raw,
+         k_cache=K_cache,
+         v_cache=V_cache,
+         k=k_append_raw,
+         v=v_append_raw,
+         seq_lens_bh=seq_lens_bh,
+         global_page_table=page_table,
+         batch_mapping=batch_mapping,
+         cu_seqlens_q=cu_seqlens_qk,
+         HKV=workload.nk_heads,
+         PAGE_SIZE=DEFAULT_PAGE_SIZE,
+         sm_scale=sm_scale,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k_cache=max_seqlen_k_triton,
+     )
+     out_flash = flash_attn_varlen_func(
+         q=q_raw,
+         k=K_total,
+         v=V_total,
+         cu_seqlens_q=cu_seqlens_qk,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k=max_seqlen_k,
+         dropout_p=0.0,
+         softmax_scale=sm_scale,
+         causal=True,
+     )
+     # Compute and log the error before asserting so a failure reports its magnitude.
+     max_diff = (out_triton - out_flash).abs().max().item()
+     logger.info(
+         f"[causal_sparse_varlen_with_cache: {workload.name}]: max abs diff={max_diff: .5f}"
+     )
+     assert torch.allclose(out_triton, out_flash, rtol=1e-6, atol=3e-3), (
+         f"max abs diff={max_diff: .5f}"
+     )
+
+
+ def materialize_kv_cache_for_flash_decode(
+     K_cache,
+     V_cache,
+     page_table,
+     L_cache_per_b,  # [B] int32
+     H_kv: int,
+     PAGE_SIZE: int,
+ ):
+     """
+     Build (K_flash, V_flash) suitable for flash_attn_with_kvcache, with shape:
+       (B, seqlen_cache_max, H_kv, D)
+
+     For each batch b:
+       - cache_seqlen[b] = L_cache_per_b[b]
+       - K_flash[b, :cache_seqlen[b], g] and V_flash[...] are filled from the paged KV cache.
+       - Tokens beyond cache_seqlen[b] (if any) are left as zeros; the caller is
+         expected to pass cache_seqlens so the reference ignores them.
+     """
+     device = K_cache.device
+     dtype = K_cache.dtype
+     B = L_cache_per_b.shape[0]
+     D = K_cache.shape[1]
+
+     seqlen_cache_max = int(L_cache_per_b.max().item())
+     K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
+     V_flash = torch.zeros_like(K_flash)
+
+     # Pull the lengths to the host once instead of one .item() per token.
+     cache_lens = [int(n) for n in L_cache_per_b.tolist()]
+     for b, Lc in enumerate(cache_lens):
+         if Lc == 0:
+             continue
+         pos = torch.arange(Lc, device=page_table.device)
+         # rows[g, i] = flat cache row of token i for KV head g
+         rows = (
+             page_table[b, :H_kv][:, pos // PAGE_SIZE].to(torch.int64) * PAGE_SIZE
+             + pos % PAGE_SIZE
+         )
+         K_flash[b, :Lc] = K_cache[rows].transpose(0, 1)
+         V_flash[b, :Lc] = V_cache[rows].transpose(0, 1)
+
+     return K_flash, V_flash
+
+
+ @pytest.mark.parametrize("workload", WORKLOADS_DECODE, ids=lambda wl: wl.name)
+ def test_sparse_decode_attention(workload: Workload):
+     """Compare the Triton paged decode kernel against flash_attn_with_kvcache.
+
+     One query token per sequence attends over the paged cache; the reference
+     receives the same cache materialized into a dense per-batch tensor.
+     """
+     dtype = torch.float16
+     device = triton.runtime.driver.active.get_active_torch_device()
+     DEFAULT_PAGE_SIZE = 256
+     N_LOGICAL_PAGES_MAX = 256
+
+     # per-sequence cache lengths (all equal for WORKLOADS_DECODE)
+     L_cache_per_b = torch.as_tensor(
+         workload.cache_lens, device=device, dtype=torch.int32
+     )
+
+     # build paged KV cache used by the Triton kernel
+     K_cache, V_cache, page_table, seq_lens_bh, _ = build_paged_cache_from_lengths(
+         B=workload.batch_size,
+         H_kv=workload.nk_heads,
+         D=workload.head_dim,
+         PAGE_SIZE=DEFAULT_PAGE_SIZE,
+         N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+         L_cache_per_b=L_cache_per_b,
+         device=device,
+         dtype=dtype,
+     )
+
+     B = workload.batch_size
+     HQ = workload.nq_heads
+     HKV = workload.nk_heads
+     D = workload.head_dim
+
+     # Triton kernel expects q: [B, HQ, D]
+     q_triton = torch.randn(B, HQ, D, device=device, dtype=dtype)
+     batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
+     sm_scale = 1.0 / math.sqrt(D)
+
+     out_triton = head_sparse_decode_attention(
+         q=q_triton,
+         k=K_cache,
+         v=V_cache,
+         seq_lens_bh=seq_lens_bh,
+         global_page_table=page_table,
+         batch_mapping=batch_mapping,
+         HKV=HKV,
+         PAGE_SIZE=DEFAULT_PAGE_SIZE,
+         sm_scale=sm_scale,
+     )  # [B, HQ, D]
+
+     # materialize contiguous KV cache with shape [B, seqlen_cache_max, HKV, D]
+     K_flash, V_flash = materialize_kv_cache_for_flash_decode(
+         K_cache=K_cache,
+         V_cache=V_cache,
+         page_table=page_table,
+         L_cache_per_b=L_cache_per_b,
+         H_kv=HKV,
+         PAGE_SIZE=DEFAULT_PAGE_SIZE,
+     )
+
+     # flash_attn_with_kvcache expects q: [B, seqlen_q, HQ, D]
+     q_flash = q_triton.unsqueeze(1)  # seqlen_q = 1
+
+     out_flash = flash_attn_with_kvcache(
+         q=q_flash,
+         k_cache=K_flash,
+         v_cache=V_flash,
+         cache_seqlens=L_cache_per_b,
+         softmax_scale=sm_scale,
+         causal=True,
+     ).squeeze(1)  # [B, 1, HQ, D] -> [B, HQ, D]
+
+     # Compute and log the error before asserting so a failure reports its magnitude.
+     max_diff = (out_triton - out_flash).abs().max().item()
+     logger.info(
+         f"[head_sparse_decode_attention: {workload.name}]: max abs diff={max_diff: .5f}"
+     )
+     assert torch.allclose(out_triton, out_flash, rtol=1e-6, atol=3e-3), (
+         f"max abs diff={max_diff: .5f}"
+     )
diff --git a/vllm/compactor-vllm/vllm_memory_comparison.png b/vllm/compactor-vllm/vllm_memory_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7c63a92777f9d12e6c4ebe7099dde2ec044eebf
Binary files /dev/null and b/vllm/compactor-vllm/vllm_memory_comparison.png differ
diff --git a/vllm/compactor-vllm/vllm_throughput_comparison.png b/vllm/compactor-vllm/vllm_throughput_comparison.png
new file mode 100644
index 0000000000000000000000000000000000000000..389a8812bd50e9a78dd1993e11c47655992bd99d
Binary files /dev/null and b/vllm/compactor-vllm/vllm_throughput_comparison.png differ
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 5909b304300751fdf2dc1200fbafcf910ce7e725..79c2a09f529818aa8b60fdd234b8ccf99f237c58 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
+import os
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -95,6 +96,7 @@ from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
+ from vllm.kvprune.integration.compression_params import CompressionParams
from vllm.v1.metrics.reader import Metric
logger = init_logger(__name__)
@@ -184,6 +186,15 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
+ kvprune_compression: If True, sets ``enforce_eager=True`` for the **v1**
+ engine only (no v1 CUDA graph capture). If ``None`` (default), read
+ ``VLLM_KVPRUNE_COMPRESSION_DEFAULT`` (``"0"`` = allow v1 graphs;
+ ``"1"`` = skip v1 graphs). This is independent of the compactor's
+ ``LLMConfig.enforce_eager`` (see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` /
+ ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``; default tries compactor graphs).
+ When True, v1's GPU KV pool defaults to **one** block (minimum allowed by
+ the scheduler) unless ``num_gpu_blocks_override`` is passed in ``**kwargs``
+ or ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS`` is set (``auto`` = profiled allocation).
enable_return_routed_experts: Whether to return routed experts.
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
@@ -240,6 +251,7 @@ class LLM:
offload_prefetch_step: int = 1,
offload_params: set[str] | None = None,
enforce_eager: bool = False,
+ kvprune_compression: bool | None = None,
enable_return_routed_experts: bool = False,
disable_custom_all_reduce: bool = False,
hf_token: bool | str | None = None,
@@ -339,6 +351,26 @@ class LLM:
"'examples/offline_inference/data_parallel.py'."
)
+ # v1 ``enforce_eager`` is independent of kvprune compactor ``LLMConfig.enforce_eager``.
+ if kvprune_compression is None:
+ _kvd = os.environ.get("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "0").strip().lower()
+ kvprune_compression = _kvd in ("1", "true", "yes")
+ if kvprune_compression:
+ enforce_eager = True
+ # Reserve minimal v1 GPU KV so compactor can use the rest of VRAM. v1
+ # scheduler requires num_gpu_blocks >= 1; profiling would allocate a
+ # large pool from gpu_memory_utilization. Override:
+ # VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS unset -> 1 block (default)
+ # VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto -> profiled (no override)
+ # VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=N -> max(1, int(N))
+ if "num_gpu_blocks_override" not in kwargs:
+ _v1_kv = os.environ.get("VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS", "").strip()
+ if _v1_kv.lower() in ("auto", "profile"):
+ pass
+ elif not _v1_kv:
+ kwargs["num_gpu_blocks_override"] = 1
+ else:
+ kwargs["num_gpu_blocks_override"] = max(1, int(_v1_kv))
engine_args = EngineArgs(
model=model,
runner=runner,
@@ -405,6 +437,9 @@ class LLM:
)
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
+ # Lazy compactor engine (``vllm.kvprune``) when :meth:`generate` uses compression.
+ self._kvprune_compactor_engine: Any = None
+ self._kvprune_compression_enabled = bool(kvprune_compression)
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
@@ -446,6 +481,7 @@ class LLM:
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
+ compression: "CompressionParams | Sequence[CompressionParams] | None" = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
@@ -473,6 +509,15 @@ class LLM:
of `prompts`, where each priority value corresponds to the prompt
at the same index.
tokenization_kwargs: Overrides for `tokenizer.encode`.
+ compression: Optional per-prompt KV compression (``vllm.kvprune``). If any
+ prompt has ``compression_ratio < 1.0``, the batch is run on the integrated
+ compactor engine with weights shared from this ``LLM``. Omit or use all
+ ``compression_ratio >= 1`` to use the standard v1 engine only.
+ Use ``kvprune_compression=True`` or ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=1``
+ so the v1 engine skips CUDA graph capture. Compactor decode graphs
+ default on (``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` default ``1``) with
+ eager fallback if capture fails; set ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``
+ to skip compactor graph capture entirely.
Returns:
A list of `RequestOutput` objects containing the
@@ -485,6 +530,41 @@ class LLM:
"Try passing `--runner generate` to use the model as a "
"generative model."
)
+ compression_eff = compression
+ if compression is None and getattr(self, "_kvprune_compression_enabled", False):
+ pc = self.llm_engine.vllm_config.parallel_config
+ if (
+ pc.tensor_parallel_size > 1
+ and pc.pipeline_parallel_size == 1
+ and pc.data_parallel_size == 1
+ ):
+ from vllm.kvprune.integration.compression_params import CompressionParams
+ from vllm.kvprune.integration.compressed_generate import (
+ _normalize_prompt_list,
+ )
+
+ _plist = _normalize_prompt_list(prompts)
+ compression_eff = [
+ CompressionParams(compression_ratio=1.0) for _ in _plist
+ ]
+
+ if compression_eff is not None:
+ from vllm.kvprune.integration.compressed_generate import (
+ try_compressed_generate,
+ )
+
+ compressed_out = try_compressed_generate(
+ self,
+ prompts,
+ sampling_params,
+ compression=compression_eff,
+ use_tqdm=use_tqdm,
+ lora_request=lora_request,
+ priority=priority,
+ tokenization_kwargs=tokenization_kwargs,
+ )
+ if compressed_out is not None:
+ return compressed_out
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
diff --git a/vllm/env_override.py b/vllm/env_override.py
index 181d000a68a75155942195bb97022df76cb96ab2..d704069540a2c35f51a9680cb0e7168d5ba911ca 100644
--- a/vllm/env_override.py
+++ b/vllm/env_override.py
@@ -4,6 +4,60 @@
import importlib.util
import os
+# KV-prune (compactor) shared-weight integration needs the v1 engine in-process
+# (`worker.get_model()` in the parent). Upstream defaults to multiprocess workers
+# (`VLLM_ENABLE_V1_MULTIPROCESSING=1`). If unset, default to in-process so
+# `LLM.generate(..., compression=...)` works without requiring env to be set
+# before `import vllm`. Set `VLLM_ENABLE_V1_MULTIPROCESSING=1` to restore
+# multiprocess workers.
+if "VLLM_ENABLE_V1_MULTIPROCESSING" not in os.environ:
+ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
+
+# In-process EngineCore (``VLLM_ENABLE_V1_MULTIPROCESSING=0``) shares the process with
+# user code; ``import vllm`` already runs ``import torch`` below. TP workers are then
+# created via multiprocessing. If we use ``fork`` after CUDA has been initialized in
+# the parent, PyTorch raises ``Cannot re-initialize CUDA in forked subprocess``.
+# ``_maybe_force_spawn()`` can miss this when CUDA is still uninitialized at the
+# moment ``get_mp_context()`` runs, so default to ``spawn`` for worker processes unless
+# the user set ``VLLM_WORKER_MULTIPROC_METHOD`` explicitly.
+os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+
+# Tensor-parallel workers use NCCL, which queries **NVML for topology** (independent of
+# PyTorch device counting). A faulty GPU on the host (e.g. ``nvidia-smi -L`` shows
+# ``Unable to determine the device handle`` for one PCI address) often causes
+# ``nvmlDeviceGetHandleByIndex(k) failed`` and then ``ncclCommInitRank`` errors.
+# Mitigations: fix or isolate the bad GPU; or **before** ``import vllm`` restrict the
+# container to healthy GPUs via UUID, e.g.
+# export NVIDIA_VISIBLE_DEVICES=GPU-xxxx,GPU-yyyy,...
+# (not only ``CUDA_VISIBLE_DEVICES=0,1,2,3``, which can still leave a dead GPU in
+# NVML's enumeration). ``VLLM_KVPRUNE_NCCL_SAFE=1`` only tweaks P2P/IB, not NVML.
+# For Docker, also consider ``--shm-size=10g`` or ``--ipc=host``.
+if os.environ.get("VLLM_KVPRUNE_NCCL_SAFE", "").strip().lower() in (
+ "1",
+ "true",
+ "yes",
+):
+ os.environ.setdefault("NCCL_P2P_DISABLE", "1")
+ os.environ.setdefault("NCCL_IB_DISABLE", "1")
+
+# KV-prune: default ``LLM(kvprune_compression=None)`` to skip v1 CUDA graph capture
+# (``enforce_eager=True`` on v1 only). Tests set ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=0``
+# in ``tests/conftest.py`` before importing vLLM.
+os.environ.setdefault("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "1")
+
+ # Opt-in: before the first compactor init, sleep(level=1) + wake_up discards the v1 KV
+ # (tests/conftest also sets 0). Off by default now that the kvprune path can use num_gpu_blocks_override=1 for v1.
+os.environ.setdefault("VLLM_KVPRUNE_RELEASE_V1_KV", "0")
+
+ # Optional: ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` (fa_triton / pdtriton / pdfa) or the legacy
+ # ``VLLM_KVPRUNE_ATTENTION_BACKEND``; see ``vllm/kvprune/integration/config_adapter.py``.
+# Optional: ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1`` experimental compactor decode CUDA graphs.
+#
+# When ``LLM(..., kvprune_compression=True)`` (or default-on via
+# ``VLLM_KVPRUNE_COMPRESSION_DEFAULT``), v1's ``num_gpu_blocks_override`` defaults
+# to 1 in ``entrypoints/llm.py`` so the primary engine does not reserve a full
+# profiled KV pool on the same GPU as the compactor. Use
+# ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto`` for profiled blocks, or a positive int.
def _get_torch_cuda_version():
"""Peripheral function to _maybe_set_cuda_compatibility_path().
diff --git a/vllm/envs.py b/vllm/envs.py
index caa2fb38afb6a8e25c460e8aa7ba2fa6d469d4d2..8d0d7ea76b3eca08d24fac2c6c5694ea878c8248 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -1030,6 +1030,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(
int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))
),
+ # KV-prune / compactor integration (see ``vllm/env_override.py``, ``vllm/kvprune/``).
+ "VLLM_KVPRUNE_ATTENTION_SCHEDULE": lambda: os.getenv(
+ "VLLM_KVPRUNE_ATTENTION_SCHEDULE", ""
+ ),
+ "VLLM_KVPRUNE_ATTENTION_BACKEND": lambda: os.getenv(
+ "VLLM_KVPRUNE_ATTENTION_BACKEND", ""
+ ),
+ "VLLM_KVPRUNE_COMPRESSION_DEFAULT": lambda: os.getenv(
+ "VLLM_KVPRUNE_COMPRESSION_DEFAULT", ""
+ ),
+ "VLLM_KVPRUNE_RELEASE_V1_KV": lambda: os.getenv("VLLM_KVPRUNE_RELEASE_V1_KV", ""),
+ "VLLM_KVPRUNE_NCCL_SAFE": lambda: os.getenv("VLLM_KVPRUNE_NCCL_SAFE", ""),
+ "VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS": lambda: os.getenv(
+ "VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS", ""
+ ),
"VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float(
os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")
),
@@ -1771,6 +1786,12 @@ def compile_factors() -> dict[str, object]:
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
"VLLM_WORKER_MULTIPROC_METHOD",
"VLLM_ENABLE_V1_MULTIPROCESSING",
+ "VLLM_KVPRUNE_ATTENTION_SCHEDULE",
+ "VLLM_KVPRUNE_ATTENTION_BACKEND",
+ "VLLM_KVPRUNE_COMPRESSION_DEFAULT",
+ "VLLM_KVPRUNE_RELEASE_V1_KV",
+ "VLLM_KVPRUNE_NCCL_SAFE",
+ "VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
"VLLM_CPU_MOE_PREPACK",
diff --git a/vllm/kvprune/__init__.py b/vllm/kvprune/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c351684196503dd1c0777c160a700f201f47f288
--- /dev/null
+++ b/vllm/kvprune/__init__.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+KV-cache pruning (compactor-style) under ``vllm.kvprune``.
+
+Use the standard :class:`~vllm.LLM` and pass ``compression=`` to :meth:`~vllm.LLM.generate`
+with :class:`CompressionParams` when any prompt needs ``compression_ratio < 1``. The compactor
+``LLMEngine`` + ``PagedKVCache`` shares weights with vLLM (no second checkpoint).
+
+Subpackages (``attention``, ``kv_cache``, ``compression``, …) implement the compactor
+engine.
+"""
+
+from vllm.kvprune.compression.compression_config import CompressionMethod
+from vllm.kvprune.integration import CompressionParams
+
+__all__ = [
+ "CompressionMethod",
+ "CompressionParams",
+]
diff --git a/vllm/kvprune/attention/__init__.py b/vllm/kvprune/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c5bb5b76552eb0d0f03cdfa04f36218699ba69
--- /dev/null
+++ b/vllm/kvprune/attention/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Sparse attention Triton kernels (varlen prefill, decode, compile helpers)."""
+
+from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
+
+__all__ = ["causal_sparse_varlen_with_cache"]
diff --git a/vllm/kvprune/attention/compile_kernels.py b/vllm/kvprune/attention/compile_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5b04d0a7992c4b1253ed986cb30427e8c511887
--- /dev/null
+++ b/vllm/kvprune/attention/compile_kernels.py
@@ -0,0 +1,261 @@
+import argparse
+import logging
+import math
+
+import torch
+from vllm.kvprune.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+
+logger = logging.getLogger(__name__)
+
+
+ def build_mock_paged_cache_from_lengths(
+     L_cache_per_b: torch.Tensor,
+     HKV: int,
+     D: int,
+     PAGE_SIZE: int,
+     N_LOGICAL_PAGES_MAX: int,
+     device,
+     dtype,
+ ):
+     """Build a random paged KV cache used as autotuning input.
+
+     Every (batch, head) pair gets ``N_LOGICAL_PAGES_MAX`` unique physical
+     pages, numbered ``(b * HKV + h) * N_LOGICAL_PAGES_MAX + lp``. The first
+     ``L_cache_per_b[b]`` token slots of each pair are filled with random
+     values; all remaining rows stay zero.
+
+     Returns:
+         ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where the
+         caches are ``[CACHE_SIZE, D]``, ``page_table`` is int32
+         ``[B, HKV, N_LOGICAL_PAGES_MAX]`` and ``seq_lens_bh`` is int32
+         ``[B, HKV]``.
+     """
+     B = len(L_cache_per_b)
+     max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
+     assert (L_cache_per_b <= max_len).all()
+
+     # Pull the lengths to the host once instead of one .item() per token.
+     cache_lens = [int(n) for n in L_cache_per_b.tolist()]
+
+     seq_lens_bh = (
+         torch.tensor(cache_lens, dtype=torch.int32, device=device)
+         .view(B, 1)
+         .expand(B, HKV)
+         .contiguous()
+     )
+
+     num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
+     CACHE_SIZE = num_phys_pages * PAGE_SIZE
+
+     K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
+     V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
+
+     # assign unique physical pages per (b, h, lp), in row-major order
+     page_table = torch.arange(
+         num_phys_pages, device=device, dtype=torch.int32
+     ).view(B, HKV, N_LOGICAL_PAGES_MAX)
+
+     for b, Lc in enumerate(cache_lens):
+         if Lc == 0:
+             continue
+         pos = torch.arange(Lc, device=device)
+         logical_page = pos // PAGE_SIZE
+         offset = pos % PAGE_SIZE
+         for h in range(HKV):
+             # int64 before multiplying so large caches cannot overflow int32.
+             rows = page_table[b, h][logical_page].to(torch.int64) * PAGE_SIZE + offset
+             K_cache[rows] = torch.randn((Lc, D), device=device, dtype=dtype)
+             V_cache[rows] = torch.randn((Lc, D), device=device, dtype=dtype)
+
+     return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
+
+
+ def autotune_causal_sparse_varlen_with_cache(
+     *,
+     max_length: int = 16384,
+     HKV: int = 8,
+     HQ: int = 32,
+     D: int = 128,
+     PAGE_SIZE: int = 128,
+     device: str = "cuda",
+     dtype=torch.float16,
+ ):
+     """
+     Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.
+
+     The kernel is invoked once for every (cache_len, append_len) pair drawn
+     from ``[0, 256, 512, 1024, ...]`` (powers of two below ``max_length``),
+     skipping the all-empty pair, with a fixed batch of 4 identical sequences.
+     Outputs are discarded; the calls exist only to exercise the kernel.
+
+     Args:
+         max_length: Upper bound (exclusive) for the power-of-two sweep.
+         HKV: Number of KV heads.
+         HQ: Number of query heads.
+         D: Head dimension; must be a power of two.
+         PAGE_SIZE: Tokens per physical page.
+         device: Torch device for all tensors.
+         dtype: Torch dtype for Q/K/V tensors.
+     """
+     import itertools
+
+     import tqdm
+
+     B = 4
+
+     # D must be a power of two (kernel requirement).
+     assert (D & (D - 1)) == 0
+
+     lengths_to_sweep = [0, 256]
+     i = 9
+     while (v := (1 << i)) < max_length:
+         lengths_to_sweep.append(v)
+         i += 1
+
+     # Number of logical PAGES needed to hold the longest swept cache. This is
+     # a page count, not a token count: multiplying by PAGE_SIZE again would
+     # inflate the mock cache allocation by a factor of PAGE_SIZE.
+     longest = max(max_length, lengths_to_sweep[-1])
+     N_LOGICAL_PAGES_MAX = (longest + PAGE_SIZE - 1) // PAGE_SIZE
+
+     combos = list(itertools.product(lengths_to_sweep, repeat=2))
+     logger.info(
+         "tuning kernels. this may take a few minutes, "
+         "but only needs to be run once per LLMConfig"
+     )
+
+     for cache_l, append_l in tqdm.tqdm(combos):
+         if cache_l + append_l == 0:
+             continue
+
+         L_cache_per_b = torch.tensor(
+             [cache_l] * B,
+             device=device,
+             dtype=torch.int32,
+         )
+         assert (L_cache_per_b <= PAGE_SIZE * N_LOGICAL_PAGES_MAX).all()
+         K_cache, V_cache, page_table, seq_lens_bh, _ = (
+             build_mock_paged_cache_from_lengths(
+                 L_cache_per_b=L_cache_per_b,
+                 HKV=HKV,
+                 D=D,
+                 PAGE_SIZE=PAGE_SIZE,
+                 N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+                 device=device,
+                 dtype=dtype,
+             )
+         )
+
+         # Packed (varlen) layout for the appended tokens.
+         L_app_list = [append_l] * B
+         cu = [0]
+         for L in L_app_list:
+             cu.append(cu[-1] + L)
+         cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
+         N = int(cu_seqlens_qk[-1].item())
+
+         max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
+         max_seqlen_k = seq_lens_bh.max().item()
+         q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
+         k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
+         v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
+
+         # Identity batch mapping (local batch index == global)
+         batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
+
+         sm_scale = 1.0 / math.sqrt(D)
+
+         causal_sparse_varlen_with_cache(
+             q=q_raw,
+             k_cache=K_cache,
+             v_cache=V_cache,
+             k=k_append_raw,
+             v=v_append_raw,
+             seq_lens_bh=seq_lens_bh,
+             global_page_table=page_table,
+             batch_mapping=batch_mapping,
+             cu_seqlens_q=cu_seqlens_qk,
+             HKV=HKV,
+             PAGE_SIZE=PAGE_SIZE,
+             sm_scale=sm_scale,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_k_cache=max_seqlen_k,
+         )
+
+
+ def _parse_args() -> argparse.Namespace:
+     """Parse the autotune script's command-line options from ``sys.argv``."""
+     parser = argparse.ArgumentParser(
+         # Each fragment ends with a space so the implicit string
+         # concatenation does not fuse words across lines.
+         description="Autotune Triton kernels. "
+         "Results are cached, so this should only need to be run once per configuration. "
+         "This script doesn't need to be run, as the kernels will be autotuned at runtime "
+         "if no cached autotuning data exists. Running this beforehand will prevent run-time "
+         "autotuning, which will accelerate compactor-vllm at inference time."
+     )
+     parser.add_argument(
+         "--max-length",
+         type=int,
+         default=16384,
+         help="Maximum total sequence length to consider.",
+     )
+     parser.add_argument(
+         "--HKV",
+         type=int,
+         default=8,
+         help="Number of KV heads.",
+     )
+     parser.add_argument(
+         "--HQ",
+         type=int,
+         default=32,
+         help="Number of query heads.",
+     )
+     parser.add_argument(
+         "--D",
+         type=int,
+         default=128,
+         help="Per-head hidden dimension (must be power of 2).",
+     )
+     parser.add_argument(
+         "--page-size",
+         type=int,
+         default=128,
+         help="Page size (tokens per physical page).",
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda",
+         help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
+     )
+     parser.add_argument(
+         "--dtype",
+         type=str,
+         default="float16",
+         help="Dtype for tensors: one of {float16, fp16, bfloat16, bf16, float32, fp32}.",
+     )
+     parser.add_argument(
+         "--log-level",
+         type=str,
+         default="INFO",
+         choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
+         help="Logging level.",
+     )
+     return parser.parse_args()
+
+
+ def _resolve_dtype(dtype_str: str):
+     """Translate a dtype name (case-insensitive) into the matching torch dtype.
+
+     Accepted spellings: float16/fp16/half, bfloat16/bf16, float32/fp32.
+
+     Raises:
+         ValueError: if the name is not one of the accepted spellings.
+     """
+     aliases = {
+         "float16": torch.float16,
+         "fp16": torch.float16,
+         "half": torch.float16,
+         "bfloat16": torch.bfloat16,
+         "bf16": torch.bfloat16,
+         "float32": torch.float32,
+         "fp32": torch.float32,
+     }
+     resolved = aliases.get(dtype_str.lower())
+     if resolved is None:
+         raise ValueError(f"Unsupported dtype: {dtype_str}")
+     return resolved
+
+
+ def main():
+     """Script entry point: parse CLI options, configure logging, run the autotune sweep."""
+     args = _parse_args()
+     # This must be the first basicConfig call in the process: basicConfig does
+     # nothing once the root logger has handlers, which would silently ignore
+     # --log-level.
+     logging.basicConfig(
+         level=getattr(logging, args.log_level.upper()),
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+     )
+
+     dtype = _resolve_dtype(args.dtype)
+     logger.info(
+         "Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
+         "device=%s, dtype=%s",
+         args.max_length,
+         args.HKV,
+         args.HQ,
+         args.D,
+         args.page_size,
+         args.device,
+         dtype,
+     )
+
+     autotune_causal_sparse_varlen_with_cache(
+         max_length=args.max_length,
+         HKV=args.HKV,
+         HQ=args.HQ,
+         D=args.D,
+         PAGE_SIZE=args.page_size,
+         device=args.device,
+         dtype=dtype,
+     )
+
+
+ if __name__ == "__main__":
+     # Logging is configured inside main() from --log-level; configuring it
+     # here first would turn that call into a no-op.
+     main()
diff --git a/vllm/kvprune/attention/fa_paged_bridge.py b/vllm/kvprune/attention/fa_paged_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fa539bee7bdc94e0e8aef7c8318c558e0b7d27b
--- /dev/null
+++ b/vllm/kvprune/attention/fa_paged_bridge.py
@@ -0,0 +1,244 @@
+# SPDX-License-Identifier: Apache-2.0
+"""FlashAttention paths over compactor paged KV (materialize + FA ops).
+
+Used when :class:`~vllm.kvprune.config.engine_config.KvpruneAttentionSchedule`
+selects FlashAttention for prefill and/or decode while KV **writes** remain on
+Triton (``prefill_store_*``, ``decode_store_kv``).
+
+**Why compactor-vllm looked fine but kvprune ``fa_triton`` + compression did not**
+
+compactor-vllm ``layers/attention.py`` (prefill)::
+
+ use_flash_prefill = (backend == FLASH) or (COMPACTOR_TRITON and not do_compression)
+ if use_flash_prefill:
+ flash_attn_varlen_func(q, k, v, ...) # dense packed Q/K/V, one length per batch
+ elif COMPACTOR_TRITON:
+ causal_sparse_varlen_with_cache(..., seq_lens_bh=...) # paged KV, **per-(b,h)** lengths
+
+So **with compression** (``do_compression``), compactor-vllm **never** runs FlashAttention on
+paged top-K KV; it always uses Triton ``causal_sparse_varlen_with_cache``.
+
+kvprune ``fa_triton`` (``FA_PREFILL_TRITON_DECODE``) keeps the intended split: **FA prefill**
++ **Triton decode**. For compressed prefill it calls :func:`flash_prefill_from_paged`, which
+builds a dense ``[total_k, H_kv, D]`` tensor and calls ``flash_attn_varlen_func``. That layout
+assumes **one cache prefix length per batch row shared by all KV heads** (same ``Lc`` for every
+``g`` when copying from ``k_cache``). Top-K retention instead updates ``bh_lens`` with
+**different** counts per head (``seq_lens_bh`` shape ``[B, HKV]``). Taking ``max(dim=1)``
+(older code) used one ``Lc`` per batch but still filled ``K_total[offset+i, g]`` for every head
+``g`` — heads with **shorter** real cache were **over-read**, corrupting attention.
+
+We therefore **require** ``seq_lens_bh[b, :]`` to be constant in ``h`` for each ``b`` before
+materializing for FA (see :func:`_require_uniform_kv_lens_per_batch_for_fa_materialize`). If your
+retention policy yields unequal per-head lengths, use ``pdtriton`` (Triton prefill) for that
+run, or disable compression while using ``fa_triton``.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+
+if TYPE_CHECKING:
+ pass
+
+
+def _require_uniform_kv_lens_per_batch_for_fa_materialize(
+    seq_lens_bh: torch.Tensor, *, caller: str
+) -> None:
+    """Check that every batch row has one cached-KV length shared by all KV heads.
+
+    The dense ``[total_k, H_kv, D]`` layout handed to FlashAttention varlen has a
+    single K length per batch row, so per-head lengths must agree.
+
+    :param seq_lens_bh: 2-D tensor indexed ``[batch, kv_head]`` of cached lengths.
+    :param caller: Name of the calling function, used as the message prefix.
+    :raises ValueError: If ``seq_lens_bh`` is not 2-D.
+    :raises RuntimeError: If any row's minimum and maximum length differ.
+    """
+    if seq_lens_bh.ndim != 2:
+        raise ValueError(f"{caller}: expected seq_lens_bh [B, HKV], got {seq_lens_bh.shape}")
+    shortest = seq_lens_bh.min(dim=1).values
+    longest = seq_lens_bh.max(dim=1).values
+    # Rows are uniform exactly when their per-row min and max coincide.
+    if torch.equal(shortest, longest):
+        return
+    raise RuntimeError(
+        f"{caller}: FlashAttention materialization needs identical cached KV lengths "
+        "across KV heads for each batch row (seq_lens_bh[b, :] constant in h). "
+        f"Got per-batch min/max mismatch: min={shortest.tolist()} max={longest.tolist()}. "
+        "Typical top-K compression uses different counts per head; compactor-vllm uses "
+        "Triton causal_sparse_varlen_with_cache in that case, not FA on materialized paged KV. "
+        "Use schedule ``pdtriton`` (Triton prefill + Triton decode), or disable compression "
+        "for this model run with ``fa_triton``."
+    )
+
+
+def materialize_kv_for_flash_prefill(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    L_cache_per_b: torch.Tensor,
+    k_append: torch.Tensor,
+    v_append: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    H_kv: int,
+    PAGE_SIZE: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Build packed K/V for :func:`flash_attn_varlen_func` (cache prefix + append).
+
+    For every batch row ``b`` the output holds the ``L_cache_per_b[b]`` cached
+    tokens (gathered through the page table) followed by the row's appended
+    tokens ``k_append[cu_seqlens_q[b]:cu_seqlens_q[b + 1]]``.
+
+    :param k_cache: Paged key buffer indexed by physical row
+        (``page_id * PAGE_SIZE + offset``); each row is one ``[D]`` vector.
+    :param v_cache: Paged value buffer with the same layout as ``k_cache``.
+    :param page_table: Indexed ``[true_batch, kv_head, logical_page]``; yields
+        the physical page id.
+    :param batch_mapping: ``[B]`` map from launch-batch index to the row of
+        ``page_table``.
+    :param L_cache_per_b: ``[B]`` cached-prefix length per batch row (shared by
+        all KV heads).
+    :param k_append: New keys, ``[N, H_kv, D]`` packed per ``cu_seqlens_q``.
+    :param v_append: New values, same shape as ``k_append``.
+    :param cu_seqlens_q: ``[B + 1]`` cumulative lengths of the appended tokens.
+    :param H_kv: Number of KV heads; must equal ``k_append.shape[1]``.
+    :param PAGE_SIZE: Tokens per physical page.
+    :return: ``(K_total, V_total, cu_seqlens_k)`` where the first two are
+        ``[total_k, H_kv, D]`` in the cache dtype and ``cu_seqlens_k`` is the
+        int32 ``[B + 1]`` cumulative K length.
+    """
+    device = k_cache.device
+    dtype = k_cache.dtype
+    B = cu_seqlens_q.numel() - 1
+    _, H_kv_raw, D = k_append.shape
+    assert H_kv_raw == H_kv
+
+    # Pull all per-batch scalars to the host in three transfers instead of one
+    # device sync per (batch, head, token) as a naive element loop would do.
+    cache_lens = [int(x) for x in L_cache_per_b.tolist()]
+    q_bounds = [int(x) for x in cu_seqlens_q.tolist()]
+    true_rows = [int(x) for x in batch_mapping.tolist()]
+
+    k_bounds = [0]
+    for b in range(B):
+        k_bounds.append(k_bounds[-1] + cache_lens[b] + (q_bounds[b + 1] - q_bounds[b]))
+    cu_seqlens_k = torch.tensor(k_bounds, device=device, dtype=torch.int32)
+    total_k = k_bounds[-1]
+
+    K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
+    V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
+
+    in_page = torch.arange(PAGE_SIZE, device=device, dtype=torch.int64)
+    for b in range(B):
+        Lc = cache_lens[b]
+        q_start, q_stop = q_bounds[b], q_bounds[b + 1]
+        dst = k_bounds[b]
+
+        if Lc > 0:
+            n_pages = -(-Lc // PAGE_SIZE)  # ceil division
+            pages = page_table[true_rows[b], :H_kv, :n_pages].to(
+                device=device, dtype=torch.int64
+            )
+            # Physical row of every cached token, per head: [H_kv, Lc].
+            rows = (pages.unsqueeze(-1) * PAGE_SIZE + in_page).reshape(H_kv, -1)[:, :Lc]
+            # Gather gives [H_kv, Lc, D]; the packed layout wants [Lc, H_kv, D].
+            K_total[dst : dst + Lc] = k_cache[rows].transpose(0, 1)
+            V_total[dst : dst + Lc] = v_cache[rows].transpose(0, 1)
+
+        # Appended tokens already use the [token, head, D] layout.
+        app_dst = dst + Lc
+        app_len = q_stop - q_start
+        K_total[app_dst : app_dst + app_len] = k_append[q_start:q_stop]
+        V_total[app_dst : app_dst + app_len] = v_append[q_start:q_stop]
+
+    return K_total, V_total, cu_seqlens_k
+
+
+def flash_prefill_from_paged(
+    q: torch.Tensor,
+    k_append: torch.Tensor,
+    v_append: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    *,
+    seq_lens_bh_before: torch.Tensor,
+    global_page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    max_seqlen_q: int,
+    PAGE_SIZE: int,
+    HKV: int,
+    sm_scale: float | None,
+) -> torch.Tensor:
+    """Run causal FlashAttention varlen prefill over paged KV plus appended tokens.
+
+    The per-head cached lengths are first validated to be uniform within each
+    batch row; the cached prefix and the appended K/V are then packed into one
+    dense tensor pair and passed to ``flash_attn_varlen_func``.
+
+    :param seq_lens_bh_before: ``[B, HKV]`` cached lengths before this step.
+    :param sm_scale: Softmax scale forwarded as-is (``None`` is passed through).
+    :return: The output of ``flash_attn_varlen_func``.
+    :raises RuntimeError: If cached lengths differ across heads in a batch row.
+    """
+    _require_uniform_kv_lens_per_batch_for_fa_materialize(
+        seq_lens_bh_before, caller="flash_prefill_from_paged"
+    )
+    # Rows are uniform at this point, so the per-row max is the shared length.
+    cache_len_per_batch = seq_lens_bh_before.max(dim=1).values.to(torch.int32)
+    packed_k, packed_v, cu_seqlens_k = materialize_kv_for_flash_prefill(
+        k_cache=k_cache,
+        v_cache=v_cache,
+        page_table=global_page_table,
+        batch_mapping=batch_mapping,
+        L_cache_per_b=cache_len_per_batch,
+        k_append=k_append,
+        v_append=v_append,
+        cu_seqlens_q=cu_seqlens_q,
+        H_kv=HKV,
+        PAGE_SIZE=PAGE_SIZE,
+    )
+    k_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
+    max_seqlen_k = int(k_lens.max().item())
+    return flash_attn_varlen_func(
+        q,
+        packed_k,
+        packed_v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=sm_scale,
+        causal=True,
+    )
+
+
+def materialize_kv_cache_for_flash_decode(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    L_cache_per_b: torch.Tensor,
+    H_kv: int,
+    PAGE_SIZE: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Dense ``[B, S, H_kv, D]`` cache for :func:`flash_attn_func` decode.
+
+    ``S`` is the largest entry of ``L_cache_per_b`` (0 for an empty batch).
+    Rows shorter than ``S`` are left zero-filled past their own length.
+
+    :param k_cache: Paged key buffer ``[CACHE_SIZE, D]`` indexed by physical row
+        (``page_id * PAGE_SIZE + offset``).
+    :param v_cache: Paged value buffer with the same layout as ``k_cache``.
+    :param page_table: Indexed ``[true_batch, kv_head, logical_page]``; yields
+        the physical page id.
+    :param batch_mapping: ``[B]`` map from launch-batch index to the row of
+        ``page_table``.
+    :param L_cache_per_b: ``[B]`` cached length per batch row (shared by all
+        KV heads).
+    :param H_kv: Number of KV heads.
+    :param PAGE_SIZE: Tokens per physical page.
+    :return: ``(K_flash, V_flash)``, both ``[B, S, H_kv, D]`` in the cache dtype.
+    """
+    device = k_cache.device
+    dtype = k_cache.dtype
+    B = L_cache_per_b.shape[0]
+    D = k_cache.shape[1]
+
+    # One host transfer per tensor instead of one device sync per element.
+    cache_lens = [int(x) for x in L_cache_per_b.tolist()]
+    true_rows = [int(x) for x in batch_mapping.tolist()]
+    # default=0 keeps an empty batch from raising on the max of nothing.
+    seqlen_cache_max = max(cache_lens, default=0)
+
+    K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
+    V_flash = torch.zeros_like(K_flash)
+
+    in_page = torch.arange(PAGE_SIZE, device=device, dtype=torch.int64)
+    for b, Lc in enumerate(cache_lens):
+        if Lc <= 0:
+            continue
+        n_pages = -(-Lc // PAGE_SIZE)  # ceil division
+        pages = page_table[true_rows[b], :H_kv, :n_pages].to(
+            device=device, dtype=torch.int64
+        )
+        # Physical row of every cached token, per head: [H_kv, Lc].
+        rows = (pages.unsqueeze(-1) * PAGE_SIZE + in_page).reshape(H_kv, -1)[:, :Lc]
+        # Gather gives [H_kv, Lc, D]; the dense layout wants [Lc, H_kv, D].
+        K_flash[b, :Lc] = k_cache[rows].transpose(0, 1)
+        V_flash[b, :Lc] = v_cache[rows].transpose(0, 1)
+
+    return K_flash, V_flash
+
+
+def flash_decode_from_paged(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    *,
+    seq_lens_bh: torch.Tensor,
+    global_page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    PAGE_SIZE: int,
+    HKV: int,
+    sm_scale: float | None,
+) -> torch.Tensor:
+    """Single-token decode through ``flash_attn_func`` over a densified paged cache.
+
+    Expects the new K/V row to be in the cache already (``decode_store_kv`` ran
+    first), so ``seq_lens_bh`` counts every key the query should see.
+
+    :param q: Query tensor with three dimensions ``[B, HQ, D]``.
+    :param seq_lens_bh: ``[B, HKV]`` cached lengths; must be uniform per row.
+    :param sm_scale: Softmax scale; ``1 / sqrt(D)`` when ``None``.
+    :return: Attention output with the singleton query axis removed again.
+    :raises RuntimeError: If cached lengths differ across heads in a batch row.
+    """
+    _require_uniform_kv_lens_per_batch_for_fa_materialize(
+        seq_lens_bh, caller="flash_decode_from_paged"
+    )
+    # Rows are uniform at this point, so the per-row max is the shared length.
+    cache_len_per_batch = seq_lens_bh.max(dim=1).values.to(torch.int32)
+    dense_k, dense_v = materialize_kv_cache_for_flash_decode(
+        k_cache=k_cache,
+        v_cache=v_cache,
+        page_table=global_page_table,
+        batch_mapping=batch_mapping,
+        L_cache_per_b=cache_len_per_batch,
+        H_kv=HKV,
+        PAGE_SIZE=PAGE_SIZE,
+    )
+    _, _, head_dim = q.shape
+    scale = 1.0 / math.sqrt(head_dim) if sm_scale is None else sm_scale
+    # The lone query may attend to every materialized key, so no causal mask.
+    attn = flash_attn_func(
+        q.unsqueeze(1),
+        dense_k,
+        dense_v,
+        softmax_scale=scale,
+        causal=False,
+    )
+    return attn.squeeze(1)
diff --git a/vllm/kvprune/attention/sparse_decode_kernel.py b/vllm/kvprune/attention/sparse_decode_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..a574622c7b6198604ccc1865698e962283fbe053
--- /dev/null
+++ b/vllm/kvprune/attention/sparse_decode_kernel.py
@@ -0,0 +1,405 @@
+import functools
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.kvprune.utils.triton_compat import (
+ autotune as triton_autotune,
+ maybe_set_allocator,
+)
+
+
+def head_sparse_decode_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seq_lens_bh: torch.Tensor,
+    global_page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    HKV: int,
+    PAGE_SIZE: int,
+    sm_scale: float = None,
+    key_split: int = None,
+):
+    """
+    Decode-time head-sparse attention over a paged KV cache.
+
+    This is a wrapper around the Triton decode kernel used during incremental
+    generation. For each batch, we read the cached keys
+    and values from a global paged KV buffer, apply attention with one
+    new query token, and return the attention output.
+
+    The KV cache is stored in a single global K/V tensor of shape
+    ``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
+    (batch, kv_head, token_idx) is mapped to a physical row in the cache by:
+
+    1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
+    2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.
+
+    Grouped-query attention (GQA / MQA) is supported by passing more query
+    heads than KV heads (``HQ`` must be a multiple of ``HKV``).
+
+    Args:
+    :param q: Query tensor of shape ``[B, HQ, D]`` or ``[B, HQ, 1, D]``
+        containing the new decode tokens for each sequence in the launch batch.
+        Must be contiguous.
+    :param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
+        backing buffer for all (batch, head) KV pages.
+    :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
+    :param seq_lens_bh: Tensor of shape ``[B, HKV]`` (int32) giving, for each
+        local batch index and KV head, the number of valid cached tokens
+        in the paged KV cache.
+    :param global_page_table: Tensor of shape
+        ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
+        ``(true_batch_idx, kv_head, logical_page)`` to a physical page id
+        in the global cache.
+    :param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
+        index used by this call to the true batch row used to index
+        ``global_page_table``.
+    :param HKV: Number of KV heads.
+    :param PAGE_SIZE: Number of tokens stored per physical KV page. Must be
+        divisible by 32.
+    :param sm_scale: Optional scaling factor applied to the attention logits
+        before softmax. If ``None``, ``1 / sqrt(D)`` is used.
+    :param key_split: Optional number of splits along the key sequence length.
+        If > 1, the kernel will process the KV sequence in ``key_split``
+        chunks and combine them in a second reduction kernel. If ``None``
+        or 0, a heuristic is used.
+
+    Returns:
+    :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
+        device and dtype as ``q``.
+    """
+
+    with torch.cuda.device(q.device):
+        if q.ndim == 3:
+            B, HQ, D = q.shape
+        else:
+            assert q.ndim == 4
+            B, HQ, S, D = q.shape
+            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
+            q = q.squeeze(-2)
+
+        CACHE_SIZE = k.shape[0]
+        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
+        GROUP_M = HQ // HKV
+        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"
+
+        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
+        # Element type the kernels compute and store in; anything that is
+        # neither fp8-e5m2 nor bfloat16 is treated as float16.
+        if FP8:
+            kernel_dtype = tl.float8e5
+        elif q.dtype == torch.bfloat16:
+            kernel_dtype = tl.bfloat16
+        else:
+            kernel_dtype = tl.float16
+
+        seq_lens_bh = seq_lens_bh.to(torch.int32)
+        assert B <= 32767, "too many batches"
+        assert global_page_table.shape[1] == HKV
+        assert q.is_contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        global_page_table = global_page_table.contiguous()
+        batch_mapping = batch_mapping.contiguous()
+        assert (D & (D - 1)) == 0, "D must be a power of 2"
+        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
+
+        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale
+        # ``None`` and 0 both mean "pick for me". A literal 0 must not reach the
+        # kernel: it would give a zero-sized grid and a division by zero when
+        # the per-split length is computed.
+        if not key_split:
+            # Bucket max_seq_len to the smallest power of two strictly greater
+            # than the longest sequence so the lru_cache on the heuristic sees
+            # few distinct keys.
+            key_split = num_splits_heuristic(
+                B * HKV,
+                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
+                num_sms=torch.cuda.get_device_properties(
+                    q.device
+                ).multi_processor_count,
+                max_splits=12,
+            )
+
+        maybe_set_allocator(
+            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
+        )
+
+        # stage 1 scratch
+        # NOTE(review): stage 1 returns immediately for a (batch, kv_head) whose
+        # cached length is 0 and writes nothing for it, so these uninitialized
+        # buffers are only fully defined when every length is > 0 -- confirm
+        # callers guarantee that.
+        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
+        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)
+        # processes all queries for a KV head together
+        # pointers are lowercase, CONSTANTS are upper
+        grid1 = (B, HKV, key_split)
+        _varkv_stage1_groupM[grid1](
+            q=q,
+            k=k,
+            v=v,
+            mid_o=mid_o,
+            mid_lse=mid_lse,
+            page_table_bhl=global_page_table,
+            batch_mapping=batch_mapping,
+            seq_lens_bh=seq_lens_bh.contiguous(),
+            SM_SCALE=sm_scale,
+            B=B,
+            HKV=HKV,
+            HQ=HQ,
+            CACHE_SIZE=CACHE_SIZE,
+            STRIDE_LBS=mid_lse.stride(0),
+            STRIDE_LS=mid_lse.stride(1),
+            STRIDE_LH=mid_lse.stride(2),
+            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+            D=D,
+            KEY_SPLIT=key_split,
+            GROUP_M=GROUP_M,
+            DTYPE=kernel_dtype,
+            PAGE_SIZE=PAGE_SIZE,
+        )
+
+        # With a single split the stage-1 output is already the final answer.
+        if key_split == 1:
+            return mid_o.squeeze(1).contiguous()
+
+        # reduce partial results across splits
+        output = torch.empty_like(q)
+        grid2 = (B, HQ)
+        _varkv_stage2_reduce[grid2](
+            mid_o=mid_o,
+            mid_lse=mid_lse,
+            output=output,
+            STRIDE_LBS=mid_lse.stride(0),
+            STRIDE_LS=mid_lse.stride(1),
+            STRIDE_LH=mid_lse.stride(2),
+            STRIDE_OBS=output.stride(0),
+            STRIDE_OH=output.stride(1),
+            B=B,
+            HQ=HQ,
+            D=D,  # type: ignore
+            KEY_SPLIT=key_split,  # type: ignore
+            DTYPE=kernel_dtype,
+        )
+        return output
+
+
+# Split-count selection modelled on the FlashAttention split heuristic.
+@functools.lru_cache(maxsize=128)
+def num_splits_heuristic(
+    total_mblocks: int,
+    max_seq_len: int,
+    num_sms: int,
+    max_splits: int,
+) -> int:
+    """Choose how many key-axis splits to launch for decode attention.
+
+    Returns 1 when the launch already covers at least 80% of the SMs or the
+    sequence is at most 1024 tokens. Otherwise each candidate split count is
+    scored by wave efficiency (``waves / ceil(waves)``), stopping once a split
+    would cover 512 tokens or fewer, and the smallest count scoring at least
+    75% of the best score is returned.
+
+    :param total_mblocks: Number of independent work items before splitting.
+    :param max_seq_len: Longest key sequence to be processed.
+    :param num_sms: Number of streaming multiprocessors on the device.
+    :param max_splits: Upper bound on the split count.
+    :return: The chosen number of splits (at least 1).
+    """
+    if max_seq_len <= 1024 or total_mblocks >= 0.8 * num_sms:
+        return 1
+
+    efficiencies = []
+    for n_splits in range(1, min(max_splits, num_sms) + 1):
+        if (max_seq_len / n_splits) <= 512:
+            break
+        waves = float(total_mblocks * n_splits) / float(num_sms)
+        efficiencies.append(waves / math.ceil(waves) if waves > 0 else 0.0)
+
+    cutoff = 0.75 * max([0.0, *efficiencies])
+    candidates = (
+        n_splits
+        for n_splits, score in enumerate(efficiencies, start=1)
+        if score >= cutoff
+    )
+    return next(candidates, 1)
+
+
+def prune_invalid_configs(configs, _, **kwargs):
+    """Keep only autotune configs whose ``BLOCK_N`` does not exceed ``PAGE_SIZE``.
+
+    :param configs: Candidate configs; each exposes a ``kwargs`` mapping.
+    :param _: Unused positional argument supplied by the autotuner.
+    :param kwargs: Kernel launch arguments; ``PAGE_SIZE`` is required.
+    :return: The surviving configs, in their original order.
+    """
+    page_size = kwargs["PAGE_SIZE"]
+    survivors = []
+    for candidate in configs:
+        # A config without BLOCK_N counts as 0 and is therefore always kept.
+        if candidate.kwargs.get("BLOCK_N", 0) <= page_size:
+            survivors.append(candidate)
+    return survivors
+
+
+# Stage 1 of split-K decode attention. One program handles one
+# (batch, kv_head, key-split) triple: it loads all GROUP_M query heads that
+# share this KV head, streams over the paged keys/values of its split with an
+# online softmax, and writes a normalized partial output to ``mid_o`` plus the
+# log-sum-exp of its logits to ``mid_lse``.
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
+ num_warps=w,
+ num_stages=s,
+ )
+ for BLOCK_N in [32, 64, 128]
+ for MIN_BLOCK_KV in [8]
+ for s in [2, 3, 4]
+ for w in [4, 8]
+ for ws in [True, False]
+ ],
+ key=[
+ "HKV",
+ "GROUP_M",
+ "D",
+ "PAGE_SIZE", # "B"
+ ],
+ cache_results=True,
+ prune_configs_by={"early_config_prune": prune_invalid_configs},
+)
+@triton.jit
+def _varkv_stage1_groupM(
+ q, # [B, HQ, D] contiguous
+ k, # GLOBAL cache: [CACHE_SIZE, D], contiguous
+ v, # GLOBAL cache: [CACHE_SIZE, D], contiguous
+ mid_o,
+ mid_lse,
+ page_table_bhl, # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
+ batch_mapping, # int32 [B] maps local pid_b -> true batch index
+ seq_lens_bh, # int32 [B*H_kv] valid tokens per (b,h)
+ SM_SCALE,
+ B,
+ HKV,
+ HQ,
+ CACHE_SIZE, # CACHE_SIZE = N_PAGES * PAGE_SIZE
+ STRIDE_LBS,
+ STRIDE_LS,
+ STRIDE_LH,
+ # constexprs
+ N_LOGICAL_PAGES_MAX: tl.constexpr, # page table width per (b,h)
+ D: tl.constexpr,
+ KEY_SPLIT: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ DTYPE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ MIN_BLOCK_KV: tl.constexpr,
+ WARPSPEC: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+):
+ pid_b = tl.program_id(0) # batch
+ pid_kvh = tl.program_id(1) # kv head
+ pid_s = tl.program_id(2) # split
+
+ # valid length L for this (b,h)
+ bh_stride = HKV
+ L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
+ # NOTE(review): this early exit stores nothing to mid_o / mid_lse for the
+ # (b, h) pair, so those rows keep whatever the caller's buffers held.
+ if L == 0:
+ return
+
+ tl.assume(L > 0)
+
+ # split sizing on logical token axis [0..L)
+ # Each split covers ceil(L / KEY_SPLIT) tokens rounded up to a multiple of
+ # MIN_BLOCK_KV; trailing splits may therefore be empty.
+ base = tl.cdiv(L, KEY_SPLIT)
+ per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
+ split_start = pid_s * per_split_len
+ split_end = tl.minimum(split_start + per_split_len, L)
+
+ # query heads mapped to this kv head
+ base_qh = pid_kvh * GROUP_M
+ # Query rows are padded to at least 16; rows >= GROUP_M are masked out.
+ # Presumably this satisfies tl.dot's minimum tile size -- TODO confirm.
+ GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
+ offs_m = tl.arange(0, GROUP_M_PAD)
+ mask_m = offs_m < GROUP_M
+ offs_d = tl.arange(0, D)
+
+ # load Q tile [M, D]
+ q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
+ q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE) # [M, D]
+
+ # streaming softmax state per query
+ # e_max: running row max, e_sum: running sum of exp(logit - e_max),
+ # acc: running un-normalized weighted sum of V.
+ e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
+ e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
+ acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)
+
+ if split_end > split_start:
+ # logical pages covering [split_start, split_end)
+ lp0 = split_start // PAGE_SIZE
+ lp1 = tl.cdiv(split_end, PAGE_SIZE) # exclusive
+
+ mapped_b = tl.load(batch_mapping + pid_b)
+ tl.assume(mapped_b >= 0)
+ # page table base for this (b,h)
+ pt_stride = N_LOGICAL_PAGES_MAX
+ pt_base = (mapped_b * HKV + pid_kvh) * pt_stride
+
+ for lp in tl.range(lp0, lp1):
+ phys = tl.load(
+ page_table_bhl + pt_base + lp, cache_modifier=".cg"
+ ) # physical page id
+ # bounds within the logical page
+ # Only the first and last page of the split can be partially covered.
+ local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
+ local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)
+
+ page_base = phys * PAGE_SIZE
+ # Compiler hint only. NOTE(review): it holds when PAGE_SIZE is a
+ # multiple of BLOCK_N; the config pruning enforces BLOCK_N <= PAGE_SIZE,
+ # not divisibility -- confirm PAGE_SIZE is always a power of two.
+ page_base = tl.multiple_of(page_base, BLOCK_N)
+ for s in tl.range(local_start, local_end, BLOCK_N):
+ s = tl.multiple_of(s, MIN_BLOCK_KV)
+ offs_bn = tl.arange(0, BLOCK_N)
+ key_idx = page_base + s + offs_bn
+ k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
+ # The load mask only guards the end of the cache buffer; positions
+ # past local_end are excluded below by setting their logits to -inf.
+ k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
+ qk = tl.dot(q, k_blk.T) * SM_SCALE # [M, BN]
+
+ offs_n = s + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < local_end
+ qk = tl.where(mask_n[None, :], qk, -float("inf"))
+
+ # Online-softmax update: rescale the old state to the new row max.
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max) # [M]
+ re_scale = tl.exp(e_max - n_e_max) # [M]
+ acc = acc * re_scale[:, None] # [M, D]
+ v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
+ v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
+ p = tl.exp(qk - n_e_max[:, None]) # [M, BN]
+ acc = tl.dot(p.to(DTYPE), v_blk, acc)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # write mid outputs [M, D] for this split
+ # mid_o is addressed as a contiguous [B, KEY_SPLIT, HQ, D] buffer.
+ tmp = (acc / e_sum[:, None]).to(DTYPE)
+ row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
+ mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
+ tl.store(mid_ptrs, tmp, mask=mask_m[:, None])
+
+ ml_ptrs = (
+ mid_lse
+ + pid_b * STRIDE_LBS
+ + pid_s * STRIDE_LS
+ + (base_qh + offs_m) * STRIDE_LH
+ )
+ # Padded rows have e_sum == 0; substitute 1.0 so tl.log stays finite
+ # (those rows are masked out of the store anyway).
+ safe_sum = tl.where(mask_m, e_sum, 1.0)
+ tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
+ else:
+ # empty split
+ # Store zeros and an LSE of -inf so this split gets zero weight when
+ # the partial results are combined.
+ zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
+ row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
+ mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
+ tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
+ ml_ptrs = (
+ mid_lse
+ + pid_b * STRIDE_LBS
+ + pid_s * STRIDE_LS
+ + (base_qh + offs_m) * STRIDE_LH
+ )
+ tl.store(ml_ptrs, -float("inf"), mask=mask_m)
+
+
+# Stage 2 of split-K decode attention. One program per (batch, query head)
+# merges the KEY_SPLIT partial outputs in ``mid_o`` into the final output row,
+# weighting each partial by exp(lse_split - max_lse) and normalizing by the
+# sum of those weights.
+@triton.jit
+def _varkv_stage2_reduce(
+ mid_o,
+ mid_lse,
+ output,
+ STRIDE_LBS,
+ STRIDE_LS,
+ STRIDE_LH,
+ STRIDE_OBS,
+ STRIDE_OH,
+ B,
+ HQ,
+ D: tl.constexpr,
+ KEY_SPLIT: tl.constexpr,
+ DTYPE: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_h = tl.program_id(1)
+ offs_d = tl.arange(0, D)
+
+ # across split LSE combine
+ # e_max: running max of the per-split LSE values, e_sum: running sum of
+ # weights relative to e_max, acc: running weighted sum of partial outputs.
+ e_sum = 0.0
+ e_max = -float("inf")
+ acc = tl.zeros([D], dtype=tl.float32)
+
+ for s in tl.range(KEY_SPLIT):
+ # mid_o is addressed as a contiguous [B, KEY_SPLIT, HQ, D] buffer.
+ row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
+ tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
+ tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
+ tlogic = tl.load(tl_ptr)
+
+ # A split whose LSE is -inf gets weight exp(-inf - n_e_max) = 0 once a
+ # finite maximum has been seen.
+ n_e_max = tl.maximum(e_max, tlogic)
+ old_scale = tl.exp(e_max - n_e_max)
+ acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
+ e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
+ e_max = n_e_max
+
+ o = (acc / e_sum).to(DTYPE)
+ o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
+ tl.store(o_ptr, o)
diff --git a/vllm/kvprune/attention/sparse_varlen_kernel.py b/vllm/kvprune/attention/sparse_varlen_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c306859caed3077351e8bdeb95ba4ba2a20a4ba
--- /dev/null
+++ b/vllm/kvprune/attention/sparse_varlen_kernel.py
@@ -0,0 +1,600 @@
+import logging
+import math
+
+import torch
+import triton
+import triton.language as tl
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+from vllm.kvprune.utils.triton_compat import (
+ autotune as triton_autotune,
+ cuda_capability_geq,
+ maybe_set_allocator,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _causal_appended_only_exact(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    *,
+    sm_scale: float,
+    max_seqlen_q: int,
+) -> torch.Tensor:
+    """Causal varlen attention over the appended q/k/v only (no cached prefix).
+
+    With no cached KV prefix, queries and keys cover the same packed token
+    ranges, so the query cumulative lengths and maximum length double as the
+    key-side ones and the call reduces to ``flash_attn_varlen_func`` with
+    ``causal=True``.
+
+    :param q: Packed queries.
+    :param k: Packed appended keys for the same tokens as ``q``.
+    :param v: Packed appended values for the same tokens as ``q``.
+    :param cu_seqlens_q: Cumulative sequence lengths, used for both Q and K.
+    :param sm_scale: Softmax scale.
+    :param max_seqlen_q: Longest sequence in the batch, used for both Q and K.
+    :return: The output of ``flash_attn_varlen_func``.
+    """
+    # Q and K share their segmentation because nothing precedes the new tokens.
+    shared_cu_seqlens = cu_seqlens_q
+    shared_max_len = max_seqlen_q
+    return flash_attn_varlen_func(
+        q,
+        k,
+        v,
+        cu_seqlens_q=shared_cu_seqlens,
+        cu_seqlens_k=shared_cu_seqlens,
+        max_seqlen_q=shared_max_len,
+        max_seqlen_k=shared_max_len,
+        softmax_scale=sm_scale,
+        causal=True,
+    )
+
+
+def causal_sparse_varlen_with_cache(
+    q,
+    k,
+    v,
+    k_cache,
+    v_cache,
+    seq_lens_bh,
+    global_page_table,
+    batch_mapping,
+    cu_seqlens_q,
+    max_seqlen_q: int,
+    max_seqlen_k_cache: int,
+    HKV: int,
+    PAGE_SIZE: int,
+    sm_scale=None,
+):
+    """
+    Causal prefill attention over a paged KV cache plus a block of newly
+    appended tokens in a packed batch format.
+
+    This function wraps the Triton kernel
+    ``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
+    a batch of variable-length sequences, where:
+
+    - Past keys/values are stored in a paged global KV cache
+      (``k_cache``, ``v_cache``) with a (per-layer) page table.
+
+    - New tokens for this step are given as K/V blocks
+      (``k``, ``v``), together with a packed query block ``q``.
+
+    - The result is equivalent to applying causal attention over the
+      concatenation of:
+      [ cached KV prefix || (K_app, V_app) for this step ]
+      for each sequence in the batch.
+
+    When ``max_seqlen_k_cache == 0`` there is no cached prefix and the call is
+    routed to ``flash_attn_varlen_func`` over the appended tokens only.
+
+    Grouped-query attention (GQA / MQA) is supported by allowing more query
+    heads than KV heads: ``HQ`` must be divisible by ``HKV``.
+
+    Args:
+    :param q:
+        Query tensor of shape ``[N, HQ, D]`` (float16 / bfloat16 / float32).
+        ``N`` is the total number of new tokens across the batch
+        (i.e. ``N = sum_b seqlen_q[b]``), packed according to
+        ``cu_seqlens_q``. ``HQ`` is the number of query heads, ``D`` the
+        head dimension (must be a power of two).
+    :param k:
+        New key tensor of shape ``[N, HKV, D]`` for the same tokens as
+        ``q``. These are the K values appended to the cache for this
+        prefill step.
+    :param v:
+        New value tensor of shape ``[N, HKV, D]`` for the same tokens as
+        ``q``.
+    :param k_cache:
+        Global key cache backing buffer of shape ``[CACHE_SIZE, D]``.
+        Keys for all cached tokens and heads are stored here; the mapping
+        from (batch, head, token index) to a row in this buffer is
+        given by ``global_page_table``.
+    :param v_cache:
+        Global value cache of shape ``[CACHE_SIZE, D]``. Must have the
+        same layout as ``k_cache`` (same ``CACHE_SIZE`` and ``D``).
+    :param seq_lens_bh:
+        Tensor of shape ``[B, HKV]`` (int32) giving, for each local batch
+        index and KV head, the number of cached tokens already present
+        in the paged KV cache before this prefill step.
+    :param global_page_table:
+        Tensor of shape ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32)
+        mapping ``(true_batch_idx, kv_head, logical_page)`` to a physical
+        page id in the global KV cache. A physical page id ``p`` refers to
+        the slice:
+        ``k_cache[p * PAGE_SIZE : (p + 1) * PAGE_SIZE]``.
+    :param batch_mapping:
+        Tensor of shape ``[B]`` (int16 / int32) mapping the local batch
+        index used in this kernel launch to the global batch index used
+        to index ``global_page_table``. This allows the same global cache
+        to be shared across multiple microbatches. It is converted to
+        int16 before the launch.
+    :param cu_seqlens_q:
+        Tensor of shape ``[B + 1]`` (int32) with cumulative sequence
+        lengths for the *new* tokens (q/k/v) in packed form. For batch
+        element ``b``:
+        ``seqlen_q[b] = cu_seqlens_q[b + 1] - cu_seqlens_q[b]``.
+        The total number of tokens satisfies
+        ``N = cu_seqlens_q[-1]``.
+    :param max_seqlen_q:
+        Maximum new query sequence length across the batch, i.e.
+        ``max_b seqlen_q[b]``.
+    :param max_seqlen_k_cache:
+        Maximum cached sequence length across (batch, KV head), i.e.
+        ``max_{b,h} seq_lens_bh[b, h]``.
+    :param HKV:
+        Number of KV heads. Must divide ``HQ``.
+    :param PAGE_SIZE:
+        Number of tokens stored per physical page in the paged KV cache.
+        ``CACHE_SIZE`` must be divisible by ``PAGE_SIZE``.
+    :param sm_scale:
+        Optional scaling factor applied to the attention logits before
+        softmax. If ``None``, defaults to ``1.0 / sqrt(D)``.
+    :returns torch.Tensor:
+        Attention output of shape ``[N, HQ, D]``, with the same dtype and
+        device as ``q``. The output is laid out in the same packed
+        varlen format as the input queries, i.e. the first
+        ``seqlen_q[0]`` rows correspond to batch 0, the next
+        ``seqlen_q[1]`` rows to batch 1, etc.
+    """
+    assert q.ndim == 3, "q should be [N, HQ, D]"
+    N, HQ, D = q.shape
+    assert (D & (D - 1)) == 0, "D must be power of two"
+
+    B = cu_seqlens_q.numel() - 1
+    assert B > 0
+    assert HQ % HKV == 0, (
+        "Number of query heads must be divisible by the number of KV heads"
+    )
+    # Resolve the default scale once; both paths below need it.
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+    if max_seqlen_k_cache == 0:
+        # Zero-prefix compressed prefill on DCU produced repeated-character output in
+        # the Triton on-band appended branch; use exact varlen FA for this subcase.
+        return _causal_appended_only_exact(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            sm_scale=sm_scale,
+            max_seqlen_q=max_seqlen_q,
+        )
+    H_g = HQ // HKV
+    # Allocate the output directly in the kernel layout
+    # [HKV, N, QUERY_GROUP_SIZE, D]; this avoids copying an uninitialized
+    # buffer through view/permute/contiguous.
+    out = torch.empty((HKV, N, H_g, D), dtype=q.dtype, device=q.device)
+    # view Q as [HKV, N, QUERY_GROUP_SIZE, D]
+    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3).contiguous()
+
+    # K_app/V_app: [N, HKV, D] -> [HKV, N, D]
+    k_app = k.view(N, HKV, D).permute(1, 0, 2).contiguous()
+    v_app = v.view(N, HKV, D).permute(1, 0, 2).contiguous()
+
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
+    seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
+    # NOTE(review): int16 silently wraps global batch indices above 32767;
+    # confirm the page table never has more rows than that.
+    batch_mapping = batch_mapping.to(dtype=torch.int16, device=q.device)
+
+    N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
+    CACHE_SIZE = k_cache.shape[0]
+    assert v_cache.shape[0] == CACHE_SIZE
+    assert k_cache.shape[1] == D and v_cache.shape[1] == D
+    assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0
+
+    k_cache = k_cache.contiguous()
+    v_cache = v_cache.contiguous()
+    global_page_table = global_page_table.contiguous()
+
+    # strides for Q [G, N, QUERY_GROUP_SIZE, D]
+    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
+    STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
+    # [G, N, D]
+    STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
+    STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()
+
+    # OUT [G, N, QUERY_GROUP_SIZE, D]
+    STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()
+    # launch grid
+    maybe_set_allocator(
+        lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
+    )
+    assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
+        "final dimension must be contiguous"
+    )
+
+    def grid(META):
+        # One program per (kv head, batch row, BLOCK_M-sized query tile).
+        return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])
+
+    # Autotune key must reflect the **total** K length seen by the kernel:
+    # cached prefix + appended tokens from the current prefill chunk.
+    #
+    # Using only `max_seqlen_k_cache` is wrong for the first compressed prefill
+    # step in `pdtriton`: the cache prefix is 0, but the kernel actually attends
+    # over the entire appended prompt (`seq_len_append`). On DCU this can cause
+    # Triton to autotune/select a kernel as if K==1 while executing on a long K,
+    # which has been observed to produce incorrect outputs. We still clamp to 1
+    # to avoid `next_power_of_2(0)`.
+    _k_max_autotune = max(int(max_seqlen_k_cache) + int(max_seqlen_q), 1)
+    AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
+    AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)
+    _causal_head_sparse_varlen_with_cache[grid](
+        Q=q,
+        K_cache=k_cache,
+        V_cache=v_cache,
+        K_app=k_app,
+        V_app=v_app,
+        cu_seqlens_qk=cu_seqlens_q,
+        seq_lens_bh=seq_lens_bh,
+        page_table=global_page_table,
+        batch_mapping=batch_mapping,
+        OUT=out,
+        HKV=HKV,
+        QUERY_GROUP_SIZE=H_g,
+        PAGE_SIZE=PAGE_SIZE,
+        N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+        STRIDE_Q_G=STRIDE_Q_G,
+        STRIDE_Q_N=STRIDE_Q_N,
+        STRIDE_Q_H=STRIDE_Q_H,
+        STRIDE_KC=STRIDE_KC,
+        STRIDE_VC=STRIDE_VC,
+        STRIDE_KA_G=STRIDE_KA_G,
+        STRIDE_KA_N=STRIDE_KA_N,
+        STRIDE_VA_G=STRIDE_VA_G,
+        STRIDE_VA_N=STRIDE_VA_N,
+        STRIDE_OUT_G=STRIDE_OUT_G,
+        STRIDE_OUT_N=STRIDE_OUT_N,
+        STRIDE_OUT_H=STRIDE_OUT_H,
+        sm_scale=sm_scale,
+        D=D,
+        AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
+        AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
+    )
+    # permute breaks contiguity; view() requires a single contiguous span.
+    return out.permute(1, 0, 2, 3).reshape(N, HQ, D)
+
+
+# Hand-picked candidates, one row per config:
+# (BLOCK_N, BLOCK_M, WARPSPEC, num_warps, num_stages).
+autotune_configs_cc9 = [
+    triton.Config(
+        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": warpspec},
+        num_warps=warps,
+        num_stages=stages,
+    )
+    for block_n, block_m, warpspec, warps, stages in (
+        (64, 64, True, 16, 3),
+        (64, 64, True, 8, 3),
+        (64, 32, True, 8, 4),
+        (64, 32, True, 8, 3),
+        (64, 32, False, 4, 3),
+        (64, 16, True, 8, 3),
+        (64, 16, True, 8, 4),
+        (64, 16, False, 4, 4),
+        (32, 32, True, 8, 4),
+        (32, 32, False, 8, 4),
+        (32, 16, False, 8, 3),
+        (32, 16, False, 4, 4),
+    )
+]
+
+# Full cross product of the small search space below, always with
+# BLOCK_M = 64 and WARPSPEC = True.
+autotune_configs_cc8 = [
+    triton.Config(
+        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": True},
+        num_warps=warps,
+        num_stages=stages,
+    )
+    for block_n in (16, 32)
+    for block_m in (64,)
+    for warps in (4, 8)
+    for stages in (2, 3)
+]
+
+
+def prune_invalid_configs(configs, _, **kwargs):
+    """Drop autotune configs that pair ``BLOCK_N == 32`` with ``num_stages == 4``.
+
+    ``num_stages`` is a launch option stored as an attribute of the config
+    object, not a meta-parameter in ``conf.kwargs``; looking it up only in
+    ``kwargs`` never matches and the filter would keep everything. The attribute
+    is read first, with ``kwargs`` kept as a fallback for config-like objects
+    that lack it.
+
+    :param configs: Candidate configs; each exposes a ``kwargs`` mapping.
+    :param _: Unused positional argument supplied by the autotuner.
+    :param kwargs: Kernel launch arguments (unused).
+    :return: The surviving configs, in their original order.
+    """
+
+    def _is_rejected(conf) -> bool:
+        stages = getattr(conf, "num_stages", conf.kwargs.get("num_stages"))
+        return conf.kwargs.get("BLOCK_N") == 32 and stages == 4
+
+    return [conf for conf in configs if not _is_rejected(conf)]
+
+
+def get_autotune_configs():
+    """Pick the autotune config list for the current device.
+
+    :return: ``autotune_configs_cc9`` when ``cuda_capability_geq(9, 0)`` is
+        true, otherwise ``autotune_configs_cc8``.
+    """
+    meets_cc9 = cuda_capability_geq(9, 0)
+    return autotune_configs_cc9 if meets_cc9 else autotune_configs_cc8
+
+
+@triton_autotune(
+    configs=get_autotune_configs(),
+    key=[
+        "HKV",
+        "QUERY_GROUP_SIZE",
+        "D",
+        "PAGE_SIZE",
+        "AUTOTUNE_MAX_K_LEN",
+        "AUTOTUNE_MAX_Q_LEN",
+    ],
+    cache_results=True,
+)
+@triton.jit
+def _causal_head_sparse_varlen_with_cache(
+    Q,  # [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
+    K_cache,
+    V_cache,  # [CACHE_SIZE, D]
+    K_app,
+    V_app,  # [HKV, N, D]
+    cu_seqlens_qk,  # [B+1]
+    seq_lens_bh,  # [B, HKV]
+    page_table,  # [B_total, HKV, N_LOGICAL_PAGES_MAX]
+    batch_mapping,  # [B], maps local b -> global batch index
+    OUT,  # [HKV, N, QUERY_GROUP_SIZE, D]
+    #
+    HKV: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    PAGE_SIZE: tl.constexpr,
+    N_LOGICAL_PAGES_MAX,
+    STRIDE_Q_G,
+    STRIDE_Q_N,
+    STRIDE_Q_H,
+    STRIDE_KC,
+    STRIDE_VC,
+    STRIDE_KA_G,
+    STRIDE_KA_N,
+    STRIDE_VA_G,
+    STRIDE_VA_N,
+    STRIDE_OUT_G,
+    STRIDE_OUT_N,
+    STRIDE_OUT_H,
+    sm_scale,
+    #
+    D: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    WARPSPEC: tl.constexpr,
+    AUTOTUNE_MAX_Q_LEN: tl.constexpr,  # used for autotune key
+    AUTOTUNE_MAX_K_LEN: tl.constexpr,  # used for autotune key
+):
+    # Attention for one (kv head, batch row, query tile) program over two key
+    # sources, combined with a single running softmax:
+    #   1) the per-head paged cache (K_cache / V_cache via page_table), where
+    #      every cached key is visible to every query in the tile;
+    #   2) the newly appended keys (K_app / V_app), with a causal mask.
+    # All QUERY_GROUP_SIZE query heads that share this kv head are processed
+    # together as BLOCK_M * QUERY_GROUP_SIZE rows.
+    # The last dimension of Q, OUT, K/V cache and K/V app is addressed with
+    # unit stride (offs_d is added without a stride factor).
+    # NOTE(review): WARPSPEC is an autotune meta-parameter but is not
+    # referenced anywhere in this body.
+    TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
+    pid_g = tl.program_id(0)  # kv_head id in [0, HKV)
+    pid_b = tl.program_id(1)  # batch id
+    pid_m = tl.program_id(2)  # query-tile id within batch
+
+    # batch segment [off_b, off_b1) in N
+    off_b = tl.load(cu_seqlens_qk + pid_b)
+    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
+    seq_len_append = off_b1 - off_b
+
+    q_start = off_b + pid_m * BLOCK_M
+    q_end = tl.minimum(q_start + BLOCK_M, off_b1)
+    # number of queries in this tile for this batch
+    M = q_end - q_start
+    # tiles that start past the end of this sequence have nothing to do
+    if M <= 0:
+        return
+
+    # cached length for (b, kv_head=pid_g)
+    L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)
+    # row indices flattened over [QUERY_GROUP_SIZE, BLOCK_M]:
+    # row_m is the position inside the tile, row_h the query head in the group
+    offs_row = tl.arange(0, TOTAL_N_QUERIES)
+    row_m = offs_row % BLOCK_M
+    row_h = offs_row // BLOCK_M
+    # valid rows: only those with row_m < M
+    row_mask = row_m < M
+
+    # global query index per row
+    q_idx = q_start + row_m
+    offs_d = tl.arange(0, D)
+    # Q tile: [TOTAL_N_QUERIES, D]
+    # Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
+    q_ptrs = (
+        Q
+        + pid_g * STRIDE_Q_G
+        + q_idx[:, None] * STRIDE_Q_N
+        + row_h[:, None] * STRIDE_Q_H
+        + offs_d[None, :]
+    )
+    q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+
+    # running softmax state per row: max, normaliser, un-normalised output
+    e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
+    acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)
+
+    offs_block_n = tl.arange(0, BLOCK_N)
+    # Convert natural-log softmax scale into log2 domain for exp2-based updates.
+    # Use the full log2(e) constant; this is mathematically equivalent to exp and
+    # not the source of the zero-prefix bug, but avoids avoidable rounding loss.
+    qk_scale = sm_scale * 1.4426950408889634
+
+    # 1) attend over cached K/V
+    if L_cache > 0:
+        # map local (b) to global batch index
+        mapped_b = tl.load(batch_mapping + pid_b)
+        pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
+        # iterate logical pages
+        num_lp = tl.cdiv(L_cache, PAGE_SIZE)
+        for lp in tl.range(0, num_lp):
+            # can overflow in 32 bits so upcast
+            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
+            page_start = phys * PAGE_SIZE
+            # how many valid tokens in this page for this (b,g)
+            remain = L_cache - lp * PAGE_SIZE
+            page_len = tl.minimum(PAGE_SIZE, remain)
+            # iterate over this page in BLOCK_N chunks
+            for ks in tl.range(0, page_len, BLOCK_N):
+                offs_n = ks + offs_block_n
+                mask_n = offs_n < page_len
+
+                key_idx = page_start + offs_n
+                k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]
+
+                k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
+                qk = tl.dot(q, k.T) * qk_scale  # [TOTAL_N_QUERIES, BN]
+                # masked positions get a large finite negative rather than -inf
+                qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
+
+                # softmax update
+                cur_max = tl.max(qk, 1)
+                n_e_max = tl.maximum(e_max, cur_max)
+                re_scale = tl.math.exp2(e_max - n_e_max)
+                p = tl.math.exp2(qk - n_e_max[:, None])
+
+                v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
+                v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
+
+                acc = acc * re_scale[:, None]
+                acc = tl.dot(p.to(v.dtype), v, acc)
+
+                e_sum = e_sum * re_scale + tl.sum(p, 1)
+                e_max = n_e_max
+
+    # 2) attend over appended K_app/V_app (causal)
+    # appended tokens for batch b are in [off_b, off_b1)
+    # query tile is [q_start, q_end)
+    # for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
+    if q_end > off_b:
+        # exactly one appended token
+        if seq_len_append == 1:
+            # single key/value vector: use an elementwise product instead of tl.dot
+            ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
+            k = tl.load(ka_ptrs)  # [D]
+            qk = tl.sum(q * k[None, :], 1) * qk_scale
+            qk = tl.where(row_mask, qk, -1.0e6)
+            n_e_max = tl.maximum(e_max, qk)
+            re_scale = tl.math.exp2(e_max - n_e_max)
+            p = tl.math.exp2(qk - n_e_max)
+            va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
+            v = tl.load(va_ptrs)  # [D]
+            acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
+            # e_max is not updated here: this is the last softmax step on this path
+            e_sum = e_sum * re_scale + p
+        else:
+            # off-band: k in [off_b, q_start)
+            # for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
+            # so no causal mask needed.
+            off_band_start = off_b
+            off_band_end = q_start
+
+            if off_band_end > off_band_start:
+                for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
+                    offs_n = ks + offs_block_n
+                    mask_n = offs_n < off_band_end
+
+                    ka_ptrs = (
+                        K_app
+                        + pid_g * STRIDE_KA_G
+                        + offs_n[:, None] * STRIDE_KA_N
+                        + offs_d[None, :]
+                    )
+                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
+
+                    qk = tl.dot(q, k.T) * qk_scale
+                    qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
+
+                    cur_max = tl.max(qk, 1)
+                    n_e_max = tl.maximum(e_max, cur_max)
+
+                    re_scale = tl.math.exp2(e_max - n_e_max)
+                    p = tl.math.exp2(qk - n_e_max[:, None])
+
+                    va_ptrs = (
+                        V_app
+                        + pid_g * STRIDE_VA_G
+                        + offs_n[:, None] * STRIDE_VA_N
+                        + offs_d[None, :]
+                    )
+                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
+
+                    acc = acc * re_scale[:, None]
+                    acc = tl.dot(p.to(v.dtype), v, acc)
+
+                    e_sum = e_sum * re_scale + tl.sum(p, 1)
+                    e_max = n_e_max
+
+            # on-band remaining k
+            on_band_start = tl.maximum(q_start, off_b)
+            if on_band_start < q_end:
+                for ks in tl.range(on_band_start, q_end, BLOCK_N):
+                    offs_n = ks + tl.arange(0, BLOCK_N)
+                    mask_n = offs_n < q_end
+
+                    ka_ptrs = (
+                        K_app
+                        + pid_g * STRIDE_KA_G
+                        + offs_n[:, None] * STRIDE_KA_N
+                        + offs_d[None, :]
+                    )
+
+                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
+
+                    qk = tl.dot(q, k.T) * qk_scale
+
+                    # DCU/ROCm: using a single fused boolean expression here can lead
+                    # to early query rows in the tile behaving as if they could attend
+                    # to later appended keys in the same on-band block. That shows up
+                    # as token-0 output deviating from V[0] while the last token in the
+                    # batch remains almost exact. Apply the three masks explicitly.
+                    #
+                    # Use local positions within the current query tile for the causal
+                    # relation: all off-band keys (< q_start) were already handled
+                    # above, so the on-band block only needs a lower-triangular mask
+                    # relative to q_start.
+                    qk = tl.where(row_mask[:, None], qk, -1.0e6)
+                    qk = tl.where(mask_n[None, :], qk, -1.0e6)
+                    local_q = row_m
+                    local_k = offs_n - q_start
+                    caus_mask = local_k[None, :] <= local_q[:, None]
+                    qk = tl.where(caus_mask, qk, -1.0e6)
+
+                    cur_max = tl.max(qk, 1)
+                    n_e_max = tl.maximum(e_max, cur_max)
+                    re_scale = tl.math.exp2(e_max - n_e_max)
+                    p = tl.math.exp2(qk - n_e_max[:, None])
+
+                    va_ptrs = (
+                        V_app
+                        + pid_g * STRIDE_VA_G
+                        + offs_n[:, None] * STRIDE_VA_N
+                        + offs_d[None, :]
+                    )
+                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
+
+                    acc = acc * re_scale[:, None]
+                    acc = tl.dot(p.to(v.dtype), v, acc)
+
+                    e_sum = e_sum * re_scale + tl.sum(p, 1)
+                    e_max = n_e_max
+
+    # 3) write outputs
+    # normalise by the softmax denominator and cast back to the query dtype;
+    # rows outside the tile (row_mask False) are never stored
+    o = (acc / e_sum[:, None]).to(q.dtype)
+    out_ptrs = (
+        OUT
+        + pid_g * STRIDE_OUT_G
+        + q_idx[:, None] * STRIDE_OUT_N
+        + row_h[:, None] * STRIDE_OUT_H
+        + offs_d[None, :]
+    )
+    tl.store(out_ptrs, o, mask=row_mask[:, None])
diff --git a/vllm/kvprune/benchmark/__init__.py b/vllm/kvprune/benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f8699a480b2a9ab11562d2e9fcb7f546eb8f9b4
--- /dev/null
+++ b/vllm/kvprune/benchmark/__init__.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Benchmark helpers for kv-prune / compactor kernels.
+
+Upstream snapshot (``compactor-vllm/src/compactor_vllm/benchmark``) contained **only**
+an empty ``__init__.py`` — no additional ``.py`` scripts. Those files are merged here
+as-is; there is nothing else to list under that directory in upstream.
+
+Use :data:`BENCHMARK_REGISTRY` to register microbenchmarks or CLI entrypoints you
+add under ``vllm.kvprune.benchmark``.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
+# Files copied from upstream ``compactor_vllm/benchmark/`` (paths relative to
+# that directory).
+UPSTREAM_BENCHMARK_FILES: tuple[str, ...] = ("__init__.py",)
+
+# Registry of benchmarks: name -> benchmark callable or import path string
+# (e.g. "mymod:main"). Starts empty; entries are added with ``register_benchmark``.
+BENCHMARK_REGISTRY: dict[str, Callable[..., Any] | str] = {}
+
+
+def list_upstream_benchmark_files() -> tuple[str, ...]:
+    """Expose :data:`UPSTREAM_BENCHMARK_FILES`, the upstream ``benchmark/`` file names."""
+    upstream_files = UPSTREAM_BENCHMARK_FILES
+    return upstream_files
+
+
+def register_benchmark(name: str, target: Callable[..., Any] | str) -> None:
+    """Store ``target`` in :data:`BENCHMARK_REGISTRY` under ``name``.
+
+    ``target`` is either a callable or a ``"module:attr"`` import path. An
+    existing entry with the same name is overwritten.
+    """
+    BENCHMARK_REGISTRY.update({name: target})
+
+
+def iter_registered_benchmarks() -> list[tuple[str, Callable[..., Any] | str]]:
+    """Snapshot :data:`BENCHMARK_REGISTRY` as a new list of ``(name, target)`` pairs."""
+    return [*BENCHMARK_REGISTRY.items()]
+
+
+# Names exported by ``from vllm.kvprune.benchmark import *``.
+__all__ = [
+    "BENCHMARK_REGISTRY",
+    "UPSTREAM_BENCHMARK_FILES",
+    "iter_registered_benchmarks",
+    "list_upstream_benchmark_files",
+    "register_benchmark",
+]
diff --git a/vllm/kvprune/compression/__init__.py b/vllm/kvprune/compression/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aee2d4640f8b46955788ac25c6987b4868afde35
--- /dev/null
+++ b/vllm/kvprune/compression/__init__.py
@@ -0,0 +1,41 @@
+from vllm.kvprune.compression.common import (
+ BaseCompressionMethod,
+ NoCompression,
+)
+from vllm.kvprune.compression.criticalkv import CriticalAdaKVCompression
+from vllm.kvprune.compression.compactor import CompactorCompression
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.compression.snapkv import SnapKVCompression
+
+# Maps each ``CompressionMethod`` member to the class that implements its
+# ``pre_rope_scoring`` / ``post_rope_scoring`` hooks. The dispatch helpers below
+# index this table with ``context.compression_context.compression_method``.
+COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
+    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
+    CompressionMethod.COMPACTOR: CompactorCompression,
+    CompressionMethod.SNAPKV: SnapKVCompression,
+    CompressionMethod.NONE: NoCompression,
+}
+
+
+def apply_prerope_compression(q, k, v, context):
+    """Run the pre-RoPE scoring hook of the compression method selected in ``context``.
+
+    The method is read from ``context.compression_context.compression_method``
+    and resolved through ``COMPRESSION_REGISTRY``; a method missing from the
+    registry raises ``KeyError``.
+    """
+    compressor = COMPRESSION_REGISTRY[context.compression_context.compression_method]
+    return compressor.pre_rope_scoring(q, k, v, context=context)
+
+
+def apply_postrope_compression(q, k, v, prerope_scores, context):
+    """Run the post-RoPE scoring hook of the compression method selected in ``context``.
+
+    ``prerope_scores`` is forwarded unchanged as the hook's fourth positional
+    argument. The method is read from
+    ``context.compression_context.compression_method`` and resolved through
+    ``COMPRESSION_REGISTRY``; a method missing from the registry raises
+    ``KeyError``.
+    """
+    compressor = COMPRESSION_REGISTRY[context.compression_context.compression_method]
+    return compressor.post_rope_scoring(q, k, v, prerope_scores, context=context)
+
+
+# Names exported by ``from vllm.kvprune.compression import *``.
+__all__ = [
+    "apply_prerope_compression",
+    "apply_postrope_compression",
+    "CompressionMethod",
+    "BatchCompressionParams",
+    "SequenceCompressionParams",
+    "COMPRESSION_REGISTRY"
+]
diff --git a/vllm/kvprune/compression/common.py b/vllm/kvprune/compression/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ce3f1b01ebf81b405f7b46f8792999b71c7557
--- /dev/null
+++ b/vllm/kvprune/compression/common.py
@@ -0,0 +1,324 @@
+from abc import ABC, abstractmethod
+import os
+from typing import Optional
+
+import torch
+
+from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
+
+
+class BaseCompressionMethod(ABC):
+    """
+    Abstract interface for KV cache compression methods.
+
+    A compression method is implemented as a pair of optional scoring phases
+    that run before and after rotary position embedding (RoPE) is applied:
+
+    1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.
+
+    2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
+        - refine / reweight the pre-RoPE scores, or
+        - compute new, potentially position-aware scores.
+
+    Concrete subclasses are expected to implement both
+    static methods and return a single tensor of scores (or ``None`` if the
+    phase is a no-op), which the caller can then feed into the shared
+    “scores → top-k indices → KV extraction” pipeline.
+    """
+
+    @staticmethod
+    @abstractmethod
+    def pre_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """
+        Compute per-token importance scores from pre-RoPE queries/keys.
+
+        Args:
+            :param q:
+                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
+            :param k:
+                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
+            :param v:
+                Value tensor. Shape ``[total_tokens, HKV, D]``.
+            :param context:
+                vllm.kvprune.utils.context.Context object carrying additional metadata,
+                such as batch mappings or temporary buffers.
+
+        Returns:
+            :return Optional[torch.Tensor]:
+                A tensor of scores (e.g. per-token, per-head importance values)
+                to be passed to ``post_rope_scoring`` or directly into the
+                top-k selection step. If this phase is a no-op, implementations
+                should return ``None``. Shape ``[total_tokens, HKV]``.
+        """
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """
+        Compute or refine importance scores from post-RoPE queries/keys.
+
+        This method is called after rotary embeddings have been applied. It can
+        optionally use both the post-RoPE Q/K and any scores produced by
+        ``pre_rope_scoring`` to produce final scores used for token selection.
+
+        Common patterns include:
+            * Using ``pre_rope_scores`` as a base signal and applying a
+              position-aware correction.
+            * Only computing scores that depend on absolute or relative positions.
+            * Simply passing through ``pre_rope_scores`` unchanged.
+
+        Args:
+            :param q:
+                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
+            :param k:
+                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
+            :param v:
+                Value tensor. Shape ``[total_tokens, HKV, D]``.
+            :param pre_rope_scores:
+                Optional scores returned by ``pre_rope_scoring``. May be
+                ``None`` if the pre-RoPE phase returned None.
+            :param context:
+                vllm.kvprune.utils.context.Context object carrying additional metadata,
+                such as batch mappings or temporary buffers.
+
+        Returns:
+            :return Optional[torch.Tensor]:
+                Final importance scores to be consumed by the compression
+                pipeline (for top-k token selection). If this phase is a
+                no-op, implementations may return ``pre_rope_scores``. If
+                None is returned, no compression will be applied.
+        """
+        pass
+
+
+class NoCompression(BaseCompressionMethod):
+    """
+    Trivial compression method that disables KV cache compression.
+
+    The pre-RoPE phase produces no scores and the post-RoPE phase returns
+    whatever it is handed, so a pipeline that runs both phases ends up with
+    ``None``.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Ignore all inputs and return ``None`` (no scores)."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Return ``pre_rope_scores`` unchanged.
+
+        The parameter is annotated ``Optional`` to match the base-class
+        signature: ``pre_rope_scoring`` above always yields ``None``.
+        """
+        return pre_rope_scores
+
+
+def extract_and_store_top_kv(
+    scores: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_k_len: int,
+    top_k: int,
+    H: int,
+    new_keys: torch.Tensor,  # [N_total, H, D]
+    new_vals: torch.Tensor,  # [N_total, H, D]
+    num_tokens_to_retain: torch.Tensor,  # [B] int32
+    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
+    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
+    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
+    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
+    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
+    PAGE_SIZE: int,
+    PAD_TO_PAGE_SIZE: bool = True,
+    K_TILE: int = 16,
+    padding: float = -float("inf"),
+):
+    """Pick the (token, head) pairs to keep and write their K/V into the cache.
+
+    Selection and storage live in one helper so the caller can run both as a
+    single unit of work (for example on one stream). Steps:
+
+    1. Read the padding strategy from ``VLLM_KVPRUNE_PADDING_MODE``
+       (default ``"per_head"``; ``"global_scan"`` is the legacy ordering).
+    2. Clamp ``num_tokens_to_retain`` to ``seq_len * H`` for each batch row.
+    3. Build candidate indices with ``scores_to_retain_indices``.
+    4. Hand the candidates to ``prefill_store_topk_kv``.
+
+    Returns nothing.
+    """
+    mode_env = os.environ.get("VLLM_KVPRUNE_PADDING_MODE", "per_head")
+    padding_mode = mode_env.strip().lower()
+
+    # A sequence of length L offers at most L * H (token, head) pairs.
+    seq_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
+    pair_capacity = seq_lens.to(
+        device=num_tokens_to_retain.device, dtype=num_tokens_to_retain.dtype
+    )
+    pair_capacity = pair_capacity * H
+    num_tokens_to_retain = torch.minimum(num_tokens_to_retain, pair_capacity)
+
+    selection = scores_to_retain_indices(
+        scores,
+        cu_seqlens_k=cu_seqlens_k,
+        max_k_len=max_k_len,
+        top_k=top_k,
+        H=H,
+        num_tokens_to_retain=num_tokens_to_retain,
+        page_size=PAGE_SIZE,
+        padding_mode=padding_mode,
+        padding=padding,
+    )
+    indices_topk, candidate_counts = selection
+
+    prefill_store_topk_kv(
+        new_keys=new_keys,
+        new_vals=new_vals,
+        indices_topk=indices_topk,
+        candidate_counts=candidate_counts,
+        num_tokens_to_retain=num_tokens_to_retain,
+        page_table=page_table,
+        batch_mapping=batch_mapping,
+        bh_lens=bh_lens,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        cu_seqlens_k=cu_seqlens_k,
+        PAGE_SIZE=PAGE_SIZE,
+        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
+        K_TILE=K_TILE,
+    )
+
+
+def scores_to_retain_indices(
+    scores: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_k_len: int,
+    top_k: int,
+    H: int,
+    num_tokens_to_retain: torch.Tensor,
+    page_size: int,
+    padding_mode: str = "per_head",
+    padding: float = -float("inf"),
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Build candidate token-head indices for compression writes.
+
+    For each batch element, this helper returns:
+
+    1. a prefix of the true global top-k ``(token, head)`` pairs, and
+    2. a suffix of additional padding candidates according to ``padding_mode``:
+       - ``per_head``: choose each head's highest-scoring remaining tokens.
+       - ``global_scan``: keep the legacy global ranking order (every pair of
+         the sequence, sorted by descending score) and let the store kernel
+         scan forward until it finds enough entries for that head.
+
+    The page-alignment requirement comes from the paged KV cache, but the
+    padding candidates themselves do not need to be discovered inside the
+    Triton store kernel. Choosing them here avoids the older "scan the global
+    candidate list until you stumble across enough entries for this head"
+    behavior, which could distort the retained set even though the page-table
+    / reclaim logic only cares about the final per-head counts.
+
+    Sequence bounds, retain counts and per-head padding needs are each moved
+    to the host with a single ``tolist()`` instead of one ``.item()`` per
+    element, so the number of blocking device reads no longer grows with
+    ``B * H``.
+
+    Args:
+        :param scores:
+            Tensor of shape ``[N_total, HKV]`` containing scores for each
+            (token, head) pair in packed varlen format.
+        :param cu_seqlens_k:
+            1-D tensor of shape ``[B + 1]`` with cumulative key sequence
+            lengths for each batch element. The total number of tokens
+            satisfies ``N_total = cu_seqlens_k[-1]``.
+        :param max_k_len:
+            Kept for API compatibility; unused.
+        :param top_k:
+            Kept for API compatibility with the caller. The retained prefix is
+            determined by ``num_tokens_to_retain``; the tail is built from
+            per-head padding needs.
+        :param H:
+            Number of key heads; must match ``scores.shape[1]``.
+        :param num_tokens_to_retain:
+            1-D tensor with the true number of token-head pairs to keep for
+            each batch element before page padding. Values larger than
+            ``seq_len * H`` are clamped.
+        :param page_size:
+            Page size of the KV cache. Determines how many extra candidates
+            are needed per head to reach page alignment.
+        :param padding_mode:
+            ``per_head`` for per-head optimal padding candidates, or
+            ``global_scan`` for the legacy "scan the global ranking" behavior.
+        :param padding:
+            Kept for backward compatibility; no longer used.
+
+    Returns:
+        A tuple ``(indices, counts)`` where:
+
+        - ``indices`` is ``[B, MAX_SEL]`` int64, containing global flattened
+          ``token * H + head`` indices, zero-padded on the right. When no row
+          has any candidate the shape is ``[B, 1]``.
+        - ``counts`` is ``[B]`` int32, the number of valid candidates for each
+          batch row inside ``indices``.
+
+    Raises:
+        ValueError: if ``padding_mode`` is neither ``per_head`` nor
+            ``global_scan``.
+    """
+    del max_k_len, top_k, padding
+
+    if padding_mode not in ("per_head", "global_scan"):
+        raise ValueError(
+            "Unsupported VLLM_KVPRUNE_PADDING_MODE. "
+            f"Expected 'per_head' or 'global_scan', got {padding_mode!r}."
+        )
+
+    device = scores.device
+    # One host transfer per tensor instead of a blocking .item() per row.
+    bounds = cu_seqlens_k.tolist()
+    retain = num_tokens_to_retain.tolist()
+    B = len(bounds) - 1
+
+    row_indices: list[torch.Tensor] = []
+    for b in range(B):
+        s, e = int(bounds[b]), int(bounds[b + 1])
+        seq_len = e - s
+        total_pairs = seq_len * H
+        keep = min(int(retain[b]), total_pairs)
+        if total_pairs == 0 or keep == 0:
+            row_indices.append(torch.empty(0, dtype=torch.int64, device=device))
+            continue
+
+        seq_scores = scores[s:e, :]  # [L, H]
+        flat_scores = seq_scores.reshape(-1)
+
+        if padding_mode == "global_scan":
+            row = torch.argsort(flat_scores, dim=0, descending=True)
+        else:
+            prefix = torch.topk(
+                flat_scores, k=keep, dim=0, largest=True, sorted=True
+            ).indices
+
+            # Mark the pairs already chosen so padding never picks them again.
+            selected_flat = torch.zeros(total_pairs, dtype=torch.bool, device=device)
+            selected_flat[prefix] = True
+            selected_mask = selected_flat.view(seq_len, H)
+
+            # Extra entries each head needs to reach a multiple of page_size,
+            # capped by how many unselected tokens that head still has.
+            head_counts = torch.bincount(prefix % H, minlength=H)
+            need_per_head = (page_size - (head_counts % page_size)) % page_size
+            max_extra_per_head = seq_len - head_counts
+            needs = torch.minimum(need_per_head, max_extra_per_head).tolist()
+
+            tails: list[torch.Tensor] = []
+            for h, need in enumerate(needs):
+                if need <= 0:
+                    continue
+                rem_scores_h = seq_scores[:, h].masked_fill(
+                    selected_mask[:, h], -torch.inf
+                )
+                tail_tok = torch.topk(
+                    rem_scores_h, k=need, dim=0, largest=True, sorted=True
+                ).indices
+                tails.append(tail_tok * H + h)
+
+            row = torch.cat([prefix, *tails], dim=0) if tails else prefix
+
+        # Shift sequence-local flat indices into the packed [N_total * H] space.
+        row_indices.append(row + s * H)
+
+    # Counts are known on the host; build the tensor once instead of writing
+    # one device element per batch row.
+    sizes = [int(row.numel()) for row in row_indices]
+    candidate_counts = torch.tensor(sizes, dtype=torch.int32, device=device)
+
+    max_sel = max(sizes, default=0)
+    if max_sel == 0:
+        return (
+            torch.zeros((B, 1), dtype=torch.int64, device=device),
+            candidate_counts,
+        )
+
+    indices = torch.zeros((B, max_sel), dtype=torch.int64, device=device)
+    for b, row in enumerate(row_indices):
+        if row.numel():
+            indices[b, : row.numel()] = row
+    return indices, candidate_counts
diff --git a/vllm/kvprune/compression/compactor.py b/vllm/kvprune/compression/compactor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4bc22be1fe6132884e5a93925c807a30e0fda14
--- /dev/null
+++ b/vllm/kvprune/compression/compactor.py
@@ -0,0 +1,722 @@
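+"""
+Compactor compression: algorithmically aligned with kvpress ``CompactorPress`` /
+``LeverageScorePress`` / ``NonCausalAttnPress`` (Cholesky leverage scores, right
+Gaussian sketch, non-causal chunked attention without 1/sqrt(d) scaling, ×||V||,
+avg_pool, global z-score, blending, and head/tail sink padding).
+
+Non-causal chunked attention and the ``×||V||`` + ``avg_pool1d(k=3)`` step use
+Triton on CUDA; non-CUDA devices fall back to PyTorch.
+The leverage-score path uses batched ``torch.matmul``; tensors are made
+``.contiguous()`` after transposes and before entering linear-algebra routines.
+On CUDA ``cholesky_solve`` is used; on HIP/ROCm, for the small sketch dimension
+``k``, ``linalg.inv(G+λI) @ X^T`` replaces ``cholesky_solve`` to avoid rocBLAS TRSM
+launch-bounds warnings and unstable behavior on some stacks.
+The same applies to the non-causal PyTorch fallback.
+"""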
+
+from __future__ import annotations
+
+import math
+from typing import List, Optional
+
+import torch
+import triton
+import triton.language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+
+
+def resolve_kvpress_compactor_blending(compression_context) -> float:
+    """Resolve the blending weight used by the Compactor post-RoPE score.
+
+    Mirrors kvpress ``CompactorPress.score``: use ``compactor_blending`` when it
+    is set, otherwise ``compression_ratio``, otherwise ``0.35``. A ``None``
+    context also yields ``0.35``.
+    """
+    default_blending = 0.35
+    if compression_context is not None:
+        for attr in ("compactor_blending", "compression_ratio"):
+            value = getattr(compression_context, attr, None)
+            if value is not None:
+                return float(value)
+    return default_blending
+
+
+class CompactorCompression(BaseCompressionMethod):
+    """Compactor scoring; the default ``chunk_size=256`` follows kvpress
+    ``CompactorPress`` / ``NonCausalAttnPress``."""
+
+    chunk_size: int = 256
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Score the pre-RoPE keys with ``kvpress_leverage_scores_packed``.
+
+        Only ``k`` is used; the call is routed through
+        ``maybe_execute_in_stream`` with ``context.STORE_STREAM``.
+        """
+        ctx = context.compression_context
+        return maybe_execute_in_stream(
+            kvpress_leverage_scores_packed,
+            k,
+            context.cu_seqlens_q,
+            ctx,
+            STORE_STREAM=context.STORE_STREAM,
+        )
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Compute the post-RoPE scores with ``kvpress_compactor_post_rope``.
+
+        The blending weight comes from ``resolve_kvpress_compactor_blending``
+        and the chunk size from the class attribute; the call is routed through
+        ``maybe_execute_in_stream`` with ``context.STORE_STREAM``.
+        """
+        ctx = context.compression_context
+        blend_weight = resolve_kvpress_compactor_blending(ctx)
+        return maybe_execute_in_stream(
+            kvpress_compactor_post_rope,
+            q,
+            k,
+            v,
+            context.cu_seqlens_q,
+            pre_rope_scores,
+            ctx,
+            context.max_seqlen_q,
+            chunk_size=CompactorCompression.chunk_size,
+            blending=float(blend_weight),
+            STORE_STREAM=context.STORE_STREAM,
+        )
+
+
+# ---------------------------------------------------------------------------
+# Cholesky leverage scores (kvpress ``LeverageScorePress``)
+# ---------------------------------------------------------------------------
+
+
+def chol_with_jitter(
+    G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5
+) -> torch.Tensor:
+    """Lower Cholesky factor of ``G + jitter * I``, raising the jitter on failure.
+
+    Each failed attempt replaces the jitter with ``1e-2`` if it was zero, else
+    ten times its value (never below ``1e-8``), and tries again.
+
+    :param G: square matrix or batch of square matrices.
+    :param jitter: diagonal shift used for the first attempt.
+    :param max_tries: maximum number of factorisation attempts.
+    :return: the lower-triangular factor from the first attempt in which every
+        matrix in the batch factorised successfully.
+    :raises RuntimeError: if no attempt succeeds.
+    """
+    eye = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
+    shift = float(jitter)
+    attempts_left = max_tries
+    while attempts_left > 0:
+        shifted = (G + shift * eye).contiguous()
+        factor, info = torch.linalg.cholesky_ex(shifted, upper=False)
+        # info is non-zero for every batch entry whose factorisation failed
+        if bool((info == 0).all()):
+            return factor
+        shift = max(1e-8, 1e-2 if shift == 0.0 else 10.0 * shift)
+        attempts_left -= 1
+    raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
+
+
+def compute_leverage_scores_mid(
+    key_states: torch.Tensor, sketch_dimension: int
+) -> torch.Tensor:
+    """
+    Sketched leverage scores, following kvpress
+    ``LeverageScorePress.compute_leverage_scores``.
+
+    Input is ``[L, H, D]``; the result is ``[L, H]`` float32 and is not
+    z-scored. Internally the keys are rearranged to ``(1, H, L, D)`` to match
+    kvpress's ``(B, H, S, D)`` layout, and intermediates are made contiguous
+    before the batched GEMMs and linear-algebra calls.
+
+    The Gaussian sketch is drawn with ``torch.randn``, so the output depends
+    on the current RNG state.
+    """
+    head_dim, sketch_dim = key_states.shape[-1], sketch_dimension
+    num_heads = key_states.shape[1]
+    projection = torch.randn(
+        1,
+        num_heads,
+        head_dim,
+        sketch_dim,
+        device=key_states.device,
+        dtype=key_states.dtype,
+    ) * (1.0 / math.sqrt(sketch_dim))
+
+    keys = key_states.transpose(0, 1).unsqueeze(0).contiguous()
+    # centre every head over the token axis before sketching
+    centered = (keys - keys.mean(dim=-2, keepdim=True)).contiguous()
+    projection = projection.contiguous()
+    sketched = torch.matmul(centered, projection).to(torch.float32).contiguous()
+    sketched_t = sketched.transpose(-2, -1).contiguous()
+    gram = torch.matmul(sketched_t, sketched)
+    gram_sym = 0.5 * (gram + gram.transpose(-2, -1)).contiguous()
+    # HIP: avoid batched cholesky_solve -> rocBLAS TRSM (launch_bounds noise / edge cases).
+    # The sketch dimension is typically modest; inv is cubic in it but batched over heads.
+    if torch.version.hip is not None:
+        identity = torch.eye(
+            gram_sym.shape[-1],
+            device=gram_sym.device,
+            dtype=gram_sym.dtype,
+            requires_grad=False,
+        )
+        regularized = gram_sym + 1e-2 * identity
+        solved = torch.linalg.inv(regularized) @ sketched_t
+    else:
+        chol = chol_with_jitter(gram_sym, jitter=1e-2, max_tries=5)
+        solved = torch.cholesky_solve(sketched_t, chol, upper=False)
+    solved_t = solved.transpose(-2, -1).contiguous()
+    # row-wise inner product of the sketch with its solve, clamped at zero
+    leverage = (sketched * solved_t).sum(dim=-1).clamp_min(0)
+    return leverage.squeeze(0).transpose(0, 1).contiguous()
+
+
+def kvpress_leverage_scores_packed(
+    key_states: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    compression_ctx,
+) -> torch.Tensor:
+    """Leverage scores for a packed varlen batch of keys, z-scored globally.
+
+    For each sequence the first ``sink_size_start`` and last ``sink_size_end``
+    tokens are left at score 0; the remaining middle span is scored with
+    ``compute_leverage_scores_mid``. The raw scores of all sequences are then
+    concatenated, normalised together by ``_zscore_flat_f32_global`` and
+    written back to their positions.
+
+    Sequence boundaries are read with a single ``tolist()`` rather than two
+    ``.item()`` calls per sequence, so there is one blocking device read per
+    call.
+
+    :param key_states: keys of shape ``[N, Hkv, D]``.
+    :param cu_seqlens: 1-D cumulative sequence lengths (``B + 1`` entries).
+    :param compression_ctx: object optionally providing ``sketch_dimension``
+        (default 48), ``sink_size_start`` (default 8) and ``sink_size_end``
+        (default 4).
+    :return: float32 tensor of shape ``[N, Hkv]``.
+    """
+    device = key_states.device
+    N, Hkv, _D = key_states.shape
+    sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48))
+    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
+    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
+
+    out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
+    mids_flat: list[torch.Tensor] = []
+    mid_ranges: list[tuple[int, int]] = []
+
+    bounds = [int(x) for x in cu_seqlens.tolist()]
+    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
+        L = k_end - k_beg
+        if L == 0:
+            continue
+        # Short sequences: the leading sink takes priority over the trailing one.
+        left_keep = min(sink_start, L)
+        right_keep = min(sink_end, max(0, L - left_keep))
+        mid_start = k_beg + left_keep
+        mid_end = k_end - right_keep
+        if mid_start >= mid_end:
+            continue
+        k_mid = key_states[mid_start:mid_end, :, :].contiguous()
+        raw = compute_leverage_scores_mid(k_mid, sketch_dim)
+        mids_flat.append(raw.reshape(-1))
+        mid_ranges.append((mid_start, mid_end))
+
+    if not mids_flat:
+        return out
+
+    z = _zscore_flat_f32_global(torch.cat(mids_flat, dim=0))
+    offset = 0
+    for (mid_start, mid_end), r in zip(mid_ranges, mids_flat):
+        n = r.numel()
+        out[mid_start:mid_end, :] = z[offset : offset + n].view(
+            mid_end - mid_start, Hkv
+        )
+        offset += n
+    return out
+
+
+# ---------------------------------------------------------------------------
+# Non-causal chunked attention (kvpress ``NonCausalAttnPress.non_causal_chunked_attn``) — Triton
+# ---------------------------------------------------------------------------
+
+
+def _non_causal_chunked_attn_pytorch(
+    q: torch.Tensor, k: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+    """Reference implementation, intended to match kvpress operator by operator.
+
+    ``q`` and ``k`` are ``[L, H, d]``. The sequence is zero-padded to a multiple
+    of ``chunk_size``, each chunk attends only to itself (no causal mask, no
+    ``1/sqrt(d)`` scaling), and the softmax weights are summed over the query
+    axis. The result is the attention mass each key received, shape ``[L, H]``
+    in float32.
+    """
+    assert chunk_size > 0 and q.shape == k.shape
+    seq_len, num_heads, head_dim = q.shape
+    q4 = q.permute(1, 0, 2).unsqueeze(0).contiguous()
+    k4 = k.permute(1, 0, 2).unsqueeze(0).contiguous()
+    padded_len = math.ceil(seq_len / chunk_size) * chunk_size
+    pad_len = padded_len - seq_len
+
+    if pad_len > 0:
+        q_fill = torch.zeros(
+            1, num_heads, pad_len, head_dim, device=q.device, dtype=q.dtype
+        )
+        k_fill = torch.zeros(
+            1, num_heads, pad_len, head_dim, device=k.device, dtype=k.dtype
+        )
+        q4 = torch.cat([q4, q_fill], dim=2)
+        k4 = torch.cat([k4, k_fill], dim=2)
+        last_chunk_start = (seq_len // chunk_size) * chunk_size
+    else:
+        last_chunk_start = ((seq_len - 1) // chunk_size) * chunk_size
+
+    # True at the padded positions of the last chunk (all False without padding).
+    is_pad = torch.arange(last_chunk_start, padded_len, device=q.device) >= seq_len
+    pad_mask = is_pad.view(1, 1, chunk_size).expand(1, num_heads, chunk_size)
+
+    num_chunks = padded_len // chunk_size
+    q_chunks = q4.contiguous().view(1, num_heads, num_chunks, chunk_size, head_dim)
+    k_chunks = k4.contiguous().view(1, num_heads, num_chunks, chunk_size, head_dim)
+    dots = torch.matmul(q_chunks, k_chunks.transpose(-2, -1).contiguous())
+
+    # In-place edits on a view of the last chunk: padded query rows are zeroed,
+    # then padded key columns are set to -1e-9.
+    # NOTE(review): -1e-9 is a tiny negative, not a large one, so padded keys
+    # still get near-uniform softmax weight; kept as-is to match the reference.
+    last_chunk = dots[:, :, -1]
+    last_chunk.masked_fill_(pad_mask.unsqueeze(-1), 0)
+    last_chunk.masked_fill_(pad_mask.unsqueeze(-2), -1e-9)
+
+    attn = torch.softmax(dots.to(torch.float32), dim=-1)
+    received = attn.sum(dim=-2).view(1, num_heads, padded_len)[..., :seq_len]
+    return received.squeeze(0).transpose(0, 1).contiguous()
+
+
+@triton.jit
+def _non_causal_chunk_row_kernel(
+    Q_ptr,
+    K_ptr,
+    Out_ptr,
+    stride_qh,
+    stride_qs,
+    stride_qd,
+    stride_kh,
+    stride_ks,
+    stride_kd,
+    stride_oh,
+    stride_os,
+    S,
+    S_pad,
+    num_chunks,
+    CHUNK_SIZE: tl.constexpr,
+    D: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    ND: tl.constexpr,
+):
+    """
+    Each program handles one head, one chunk and one query row.
+    It applies softmax(dim=-1) to that row of logits, then atomic_adds the
+    probability of every key column j into the output (equivalent to summing
+    over the query axis).
+
+    ``S`` is the unpadded sequence length. ``S_pad`` and ``num_chunks`` are
+    accepted but not referenced in the body.
+    """
+    h = tl.program_id(0)
+    c = tl.program_id(1)
+    iq = tl.program_id(2)
+
+    # global sequence index of this program's query row
+    g_i = c * CHUNK_SIZE + iq
+
+    offs_j = tl.arange(0, CHUNK_SIZE)
+    logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
+
+    # accumulate q·k over the feature dimension in ND slices of BLOCK_D
+    for db in range(ND):
+        offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
+        mask_d = offs_d < D
+        q_off = (
+            h * stride_qh + g_i * stride_qs + offs_d * stride_qd
+        )
+        # Loads are masked on the feature axis only; there is no sequence-bound
+        # mask, so Q and K must hold a full chunk of rows for every chunk.
+        qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
+
+        g_j = c * CHUNK_SIZE + offs_j
+        k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
+        kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
+        logits += tl.sum(qd[None, :] * kj, axis=1)
+
+    row_invalid = g_i >= S
+    g_j_all = c * CHUNK_SIZE + offs_j
+    col_invalid = g_j_all >= S
+
+    # Padded query rows get all-zero logits; for real rows, padded key columns
+    # are set to -1e-9.
+    logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
+    logits = tl.where(
+        row_invalid,
+        logits,
+        tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
+    )
+
+    # numerically stable softmax over the chunk's keys
+    m = tl.max(logits)
+    logits = logits - m
+    exp_v = tl.exp(logits)
+    denom = tl.sum(exp_v)
+    p = exp_v / denom
+
+    # Programs for different query rows of the same chunk target the same
+    # output slots, hence the atomic add; padded columns are not written.
+    out_base = h * stride_oh + g_j_all * stride_os
+    tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)
+
+
def _non_causal_chunked_attn_triton(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """CUDA/Triton implementation of the chunked non-causal attention score.

    Implements the same algorithm as ``_non_causal_chunked_attn_pytorch``.

    Args:
        q: CUDA tensor ``[L, H, d]``.
        k: CUDA tensor with the same shape as ``q``.
        chunk_size: positive chunk length; it becomes the ``tl.arange`` extent
            inside the kernel.

    Returns:
        Contiguous float32 tensor ``[L, H]``.
    """
    assert q.is_cuda and k.is_cuda and q.shape == k.shape
    L, H, d = q.shape
    assert chunk_size > 0

    # Round the sequence length up to a whole number of chunks.
    num_chunks = math.ceil(L / chunk_size)
    S_pad = num_chunks * chunk_size
    pad_len = S_pad - L
    if pad_len > 0:
        # Zero rows appended at the tail so every chunk is full.
        q = torch.cat([q, q.new_zeros((pad_len, H, d))], dim=0)
        k = torch.cat([k, k.new_zeros((pad_len, H, d))], dim=0)

    # Head-major, contiguous float32 copies for the kernel.
    Q = q.transpose(0, 1).contiguous().to(dtype=torch.float32)
    K = k.transpose(0, 1).contiguous().to(dtype=torch.float32)
    out_acc = torch.zeros(H, S_pad, device=q.device, dtype=torch.float32)

    S = int(L)
    BLOCK_D = 32 if d <= 128 else 64
    ND = (d + BLOCK_D - 1) // BLOCK_D
    # One program per (head, chunk, query row inside the chunk).
    launch_grid = (H, num_chunks, chunk_size)
    strides = (*Q.stride(), *K.stride(), *out_acc.stride())
    _non_causal_chunk_row_kernel[launch_grid](
        Q,
        K,
        out_acc,
        *strides,
        S,
        S_pad,
        int(num_chunks),
        CHUNK_SIZE=chunk_size,
        D=d,
        BLOCK_D=BLOCK_D,
        ND=ND,
        num_warps=4,
    )
    # Drop the padded columns and return token-major scores.
    return out_acc[:, :S].transpose(0, 1).contiguous()
+
+
def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Map ``q``, ``k`` of shape ``[L, H, d]`` to scores ``[L, H]``.

    No ``1/sqrt(d)`` scaling is applied. The Triton implementation is used when
    both inputs live on CUDA, the PyTorch implementation otherwise.
    """
    on_gpu = q.is_cuda and k.is_cuda
    impl = _non_causal_chunked_attn_triton if on_gpu else _non_causal_chunked_attn_pytorch
    return impl(q, k, chunk_size)
+
+
+# ---------------------------------------------------------------------------
+# ×||V|| + avg_pool1d(k=3) — Triton(CUDA)
+# ---------------------------------------------------------------------------
+
+
@triton.jit
def _mul_vnorm_avgpool3_kernel(
    A_ptr,
    V_ptr,
    OUT_ptr,
    stride_al,
    stride_ah,
    stride_vl,
    stride_vh,
    stride_vd,
    stride_ol,
    stride_oh,
    L,
    D: tl.constexpr,
):
    """Fused ``a * ||v||`` followed by a width-3 average along the token axis.

    One program per ``(l, h)``. Triton does not support nested ``def``s, so the
    per-position ``t_at`` logic is unrolled once for each of ``l-1``, ``l`` and
    ``l+1``. Out-of-range neighbours contribute 0 and the sum is always scaled
    by ``1/3``.
    """
    l = tl.program_id(0)
    h = tl.program_id(1)
    # All D features of a value vector are loaded at once.
    offs = tl.arange(0, D)

    # ---- position l - 1 ----
    pos_m1 = l - 1
    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
    # Clamp the index to 0 when out of range so the address stays valid.
    ps_m1 = tl.where(inb_m1, pos_m1, 0)
    a_m1 = tl.load(
        A_ptr + ps_m1 * stride_al + h * stride_ah,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    v_m1 = tl.load(
        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    # a * L2 norm of v, or 0 for an out-of-range position.
    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)

    # ---- position l ----
    inb_0 = (l >= 0) & (l < L)
    ps0 = tl.where(inb_0, l, 0)
    a0 = tl.load(
        A_ptr + ps0 * stride_al + h * stride_ah,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    v0 = tl.load(
        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)

    # ---- position l + 1 ----
    pos_p1 = l + 1
    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
    ps_p1 = tl.where(inb_p1, pos_p1, 0)
    a_p1 = tl.load(
        A_ptr + ps_p1 * stride_al + h * stride_ah,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    v_p1 = tl.load(
        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)

    # Divisor is 3 even at the sequence edges (missing neighbours count as 0).
    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
+
+
+def _mul_vnorm_avgpool3_fused(
+ a: torch.Tensor, v: torch.Tensor, out: torch.Tensor | None = None
+) -> torch.Tensor:
+ assert a.dim() == 2 and v.dim() == 3 and a.shape[0] == v.shape[0] and a.shape[1] == v.shape[1]
+ L, H, D = v.shape
+ a = a.contiguous()
+ v = v.contiguous()
+ if a.dtype != torch.float32:
+ a = a.float()
+ if out is None:
+ out = torch.empty((L, H), device=v.device, dtype=torch.float32)
+ if L == 0 or H == 0:
+ return out
+ grid = (L, H)
+ _mul_vnorm_avgpool3_kernel[grid](
+ a,
+ v,
+ out,
+ a.stride(0),
+ a.stride(1),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ out.stride(0),
+ out.stride(1),
+ L,
+ D=D,
+ num_warps=4,
+ )
+ return out
+
+
+def _maybe_mul_vnorm_avgpool3_fused(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ if not a.is_cuda or not v.is_cuda:
+ import torch.nn.functional as F
+
+ s = a * v.norm(dim=-1)
+ return (
+ F.avg_pool1d(s.transpose(0, 1).unsqueeze(0), kernel_size=3, padding=1, stride=1)
+ .squeeze(0)
+ .transpose(0, 1)
+ )
+ return _mul_vnorm_avgpool3_fused(a, v)
+
+
@triton.jit
def _zscore_elem_1d_kernel(
    X_ptr,
    OUT_ptr,
    n,
    mean,
    inv_std,
    BLOCK: tl.constexpr,
):
    """Element-wise ``(x - mean) * inv_std`` over a flat buffer of ``n`` elements.

    ``mean`` and ``inv_std`` are scalars supplied by the caller; each program
    processes one block of ``BLOCK`` consecutive elements.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard the final, possibly partial, block.
    mask = offs < n
    x = tl.load(X_ptr + offs, mask=mask, other=0.0)
    tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)
+
+
+def _zscore_flat_f32_global(x: torch.Tensor) -> torch.Tensor:
+ """
+ 与 kvpress ``(t - t.mean()) / t.std()`` 一致的一维全局 z-score。
+ ``mean/std`` 用 PyTorch;CUDA 上缩放阶段用 Triton 逐元素写入。
+ """
+ if x.numel() == 0:
+ return x
+ mu = x.mean()
+ sig = x.std().clamp_min(1e-6)
+ inv = 1.0 / sig
+ if not x.is_cuda:
+ return (x - mu) * inv
+ x = x.contiguous()
+ out = torch.empty_like(x)
+ n = x.numel()
+ BLOCK = 1024
+ grid = (triton.cdiv(n, BLOCK),)
+ _zscore_elem_1d_kernel[grid](
+ x,
+ out,
+ n,
+ float(mu.item()),
+ float(inv.item()),
+ BLOCK=BLOCK,
+ num_warps=4,
+ )
+ return out
+
+
+def _attn_scores_kvpress_middle(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ sink_start: int,
+ sink_end: int,
+ chunk_size: int,
+ do_zscore: bool = True,
+) -> torch.Tensor:
+ """仅中间子序列上的非因果分 + ×||V|| + avg_pool;输出全长 ``[N, Hkv]``,非中间为 0。"""
+ N, HQ, D = q.shape
+ Hkv = k.shape[1]
+ G = HQ // Hkv
+ device = q.device
+ attn_out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
+ parts: list[torch.Tensor] = []
+
+ for b in range(cu_seqlens.numel() - 1):
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)
+ right_keep = min(sink_end, max(0, L - left_keep))
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if mid_start >= mid_end:
+ continue
+ q_m = q[mid_start:mid_end, :, :].contiguous()
+ k_m = k[mid_start:mid_end, :, :].contiguous()
+ v_m = v[mid_start:mid_end, :, :].contiguous()
+ # HF ``repeat_kv`` 约定:``[batch, num_kv_heads, seq_len, head_dim]``
+ k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous() # [1, Hkv, Lm, D]
+ k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous() # [Lm, HQ, D]
+ A = non_causal_chunked_attn(q_m, k_rep, chunk_size)
+ Lm, HQa = A.shape
+ assert HQa == HQ
+ A = A.view(Lm, Hkv, G).mean(dim=-1)
+ scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
+ parts.append(scores.reshape(-1))
+
+ if not parts:
+ return attn_out
+
+ flat_a = torch.cat(parts, dim=0)
+ if do_zscore:
+ z_a = _zscore_flat_f32_global(flat_a)
+ else:
+ z_a = flat_a
+ offset = 0
+ for b in range(cu_seqlens.numel() - 1):
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ L = k_end - k_beg
+ if L == 0:
+ continue
+ left_keep = min(sink_start, L)
+ right_keep = min(sink_end, max(0, L - left_keep))
+ mid_start = k_beg + left_keep
+ mid_end = k_end - right_keep
+ if mid_start >= mid_end:
+ continue
+ n = (mid_end - mid_start) * Hkv
+ attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(
+ mid_end - mid_start, Hkv
+ )
+ offset += n
+ return attn_out
+
+
def non_causal_attn_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qk: torch.Tensor,
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: float = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: torch.Tensor = None,
    accum_blending: float = None,
) -> torch.Tensor:
    """
    Non-causal attention score following the kvpress non-causal branch.

    ``sm_scale`` and ``max_seqlen_qk`` are **ignored**: dot products are not
    multiplied by ``1/sqrt(d)``. Sink sizes are fixed at 8 leading and 4
    trailing tokens per sequence.

    Steps:
      1. Score the middle sub-sequences; with ``normalize=True`` they are
         concatenated and z-scored jointly.
      2. If ``accum_scores`` is given, add ``accum_blending * accum_scores``
         (blending weight defaults to 0.5).
      3. If both protected lists and ``context_lens`` are given, set the
         protected leading/trailing rows of every sequence to ``inf``. The
         protected counts are clamped to ``[0, sequence length]`` so a short
         sequence can never mark rows that belong to a neighbouring sequence.

    Returns:
        Tensor ``[N, Hkv]`` of scores.
    """
    del sm_scale, max_seqlen_qk
    sink_start, sink_end = 8, 4
    out = _attn_scores_kvpress_middle(
        q,
        k,
        v,
        cu_seqlens_qk,
        sink_start,
        sink_end,
        chunk_size,
        do_zscore=normalize,
    )

    if accum_scores is not None:
        w = 0.5 if accum_blending is None else float(accum_blending)
        out = out + w * accum_scores.to(device=out.device, dtype=out.dtype)

    if protected_first_tokens is not None and protected_last_tokens is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first_tokens, protected_last_tokens, context_lens
        ):
            Lc = int(Lc)
            # Without clamping, ``last > Lc`` makes ``end - last`` fall before
            # ``start`` (or go negative and wrap), and ``first > Lc`` spills
            # into the next sequence.
            first = max(0, min(int(first), Lc))
            last = max(0, min(int(last), Lc))
            end = start + Lc
            out[start : start + first].fill_(torch.inf)
            out[end - last : end].fill_(torch.inf)
            start = end
    return out
+
+
def kvpress_compactor_post_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    pre_rope_scores: torch.Tensor,
    compression_ctx,
    max_seqlen_q: int,
    chunk_size: int,
    blending: float,
) -> torch.Tensor:
    """Blend pre-RoPE scores with post-RoPE non-causal attention scores.

    On the middle span of every sequence the result is
    ``blending * pre_rope_scores + attention_scores``. The sink tokens at both
    ends are filled with the global maximum of the blended scores (or ``1.0``
    when that maximum is zero or not finite). Finally, when
    ``compression_ctx`` provides ``protected_first_tokens``,
    ``protected_last_tokens`` and ``context_lens``, the protected rows are set
    to ``inf``; the protected counts are clamped to the sequence length.

    Args:
        q, k, v: packed post-RoPE tensors, forwarded to
            ``_attn_scores_kvpress_middle``.
        cu_seqlens: cumulative sequence boundaries, ``B + 1`` entries.
        pre_rope_scores: scores of the same shape as the result.
        compression_ctx: object optionally exposing ``sink_size_start``
            (default 8), ``sink_size_end`` (default 4), ``context_lens``,
            ``protected_first_tokens`` and ``protected_last_tokens``.
        max_seqlen_q: unused.
        chunk_size: chunk length for the non-causal attention.
        blending: weight applied to ``pre_rope_scores``.

    Returns:
        float32 tensor shaped like ``pre_rope_scores``.
    """
    del max_seqlen_q
    device = q.device

    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    context_lens: Optional[List[int]] = getattr(
        compression_ctx, "context_lens", None
    )
    protected_first: Optional[List[int]] = getattr(
        compression_ctx, "protected_first_tokens", None
    )
    protected_last: Optional[List[int]] = getattr(
        compression_ctx, "protected_last_tokens", None
    )

    attn_out = _attn_scores_kvpress_middle(
        q, k, v, cu_seqlens, sink_start, sink_end, chunk_size
    )
    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
    blended = torch.zeros_like(lev)
    if blended.numel() == 0:
        # max() below is undefined for an empty tensor.
        return blended

    # Read the boundaries once (single host transfer) and compute every
    # sequence's sink/middle layout a single time for both passes below.
    bounds = [int(b) for b in cu_seqlens.reshape(-1).tolist()]
    layout = []  # (k_beg, k_end, left_keep, right_keep) per non-empty sequence
    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
        L = k_end - k_beg
        if L == 0:
            continue
        left_keep = min(sink_start, L)
        right_keep = min(sink_end, max(0, L - left_keep))
        layout.append((k_beg, k_end, left_keep, right_keep))

    # Pass 1: blended score on the middle spans.
    for k_beg, k_end, left_keep, right_keep in layout:
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            continue
        blended[mid_start:mid_end, :] = (
            blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
        )

    # Pass 2: sinks receive the largest blended score so they rank at the top.
    pad_val = blended.max()
    if not torch.isfinite(pad_val) or pad_val == 0:
        pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)
    for k_beg, k_end, left_keep, right_keep in layout:
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if left_keep > 0:
            blended[k_beg:mid_start, :] = pad_val
        if right_keep > 0:
            blended[mid_end:k_end, :] = pad_val

    if protected_first is not None and protected_last is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first, protected_last, context_lens
        ):
            Lc = int(Lc)
            # Clamp so a protected count larger than the sequence cannot touch
            # rows of a neighbouring sequence or yield a wrapped slice.
            first = max(0, min(int(first), Lc))
            last = max(0, min(int(last), Lc))
            end = start + Lc
            blended[start : start + first].fill_(torch.inf)
            blended[end - last : end].fill_(torch.inf)
            start = end

    return blended
+
diff --git a/vllm/kvprune/compression/compactor_origin.py b/vllm/kvprune/compression/compactor_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a8871bfdcdf24171af87e46bea17c5af2d7f947
--- /dev/null
+++ b/vllm/kvprune/compression/compactor_origin.py
@@ -0,0 +1,600 @@
+import logging
+import math
+from typing import List, Optional
+
+import torch
+import triton
+from tqdm.contrib.logging import logging_redirect_tqdm
+from triton import language as tl
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+logger = logging.getLogger(__name__)
+
+
class CompactorCompression(BaseCompressionMethod):
    """Compactor scoring: leverage scores before RoPE, non-causal attention after.

    Both hooks hand their work to ``maybe_execute_in_stream``.
    """

    # Chunk length used by the post-RoPE non-causal attention.
    chunk_size: int = 128

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Approximate, normalized leverage scores of the keys ``k``.

        ``q`` and ``v`` are not used. The context lengths, the projection
        ``PHI`` and the chunk size come from ``context.compression_context``.
        """
        cc = context.compression_context
        leverage = maybe_execute_in_stream(
            approximate_leverage_scores,
            k,
            cc.context_lens,
            cc.PHI,
            normalize=True,
            chunk_size=cc.compression_chunk_size,
            STORE_STREAM=context.STORE_STREAM,
        )
        return leverage

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Non-causal attention scores accumulated onto ``pre_rope_scores``.

        The pre-RoPE scores are passed as ``accum_scores`` with a blending
        weight of 0.5; protected-token settings come from
        ``context.compression_context``.
        """
        positional = (q, k, v, context.cu_seqlens_q, context.max_seqlen_q)
        cc = context.compression_context
        options = dict(
            chunk_size=CompactorCompression.chunk_size,
            sm_scale=1.0,
            normalize=True,
            accum_scores=pre_rope_scores,
            context_lens=cc.context_lens,
            protected_first_tokens=cc.protected_first_tokens,
            protected_last_tokens=cc.protected_last_tokens,
            accum_blending=0.5,
        )
        return maybe_execute_in_stream(non_causal_attn_scores, *positional, **options)
+
+
def split_into_chunks(xs, chunk_size):
    """
    Decompose per-sequence lengths into fixed-size chunks plus optional tails.

    Every length ``n`` in ``xs`` is split into ``n // chunk_size`` full chunks
    (the "prologue") and, if ``n % chunk_size`` is non-zero, one shorter tail
    (the "epilogue"). Two parallel descriptions are produced:

    * ``chunks`` lists every chunk length individually.
    * ``coalesced_chunks`` lists the same segments after merging the
      consecutive full chunks of one sequence into a single entry, so each
      sequence contributes at most one prologue entry and one epilogue entry.

    Example:
        xs = [257, 127], chunk_size = 128
        coalesced_chunks = [256, 1, 127]
        chunks = [128, 128, 1, 127]

    Args:
        :param xs:
            Iterable of non-negative integers
        :param chunk_size:
            Target chunk size

    Returns:
        :return Tuple[List[int], List[int]]:
            ``(coalesced_chunks, chunks)`` as described above.
    """
    coalesced_chunks = []
    chunks = []
    for length in xs:
        full_count, tail = divmod(length, chunk_size)
        head = length - tail  # tokens covered by the full chunks
        if head > 0:
            coalesced_chunks.append(head)
            chunks += [chunk_size] * full_count
        if tail > 0:
            coalesced_chunks.append(tail)
            chunks.append(tail)
    return coalesced_chunks, chunks
+
+
def approximate_leverage_scores(
    key_states: torch.Tensor,  # [N, H, D]
    context_lens: List[int],  # [B]
    PHI: torch.Tensor,  # [D, k]
    regularizer: float = 5e-3,
    normalize: bool = False,
    chunk_size: int = 512,
) -> torch.Tensor:  # returns [N, H]
    """
    Approximate leverage scores for keys via randomized sketching.

    This implements a randomized approximation to per-token leverage scores for
    the key matrix, as described in Compactor: Calibrated Query-Agnostic KV Cache
    Compression with Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).

    The keys are sketched with ``PHI``, every chunk of the sketch is
    mean-centered in place, and the chunk's regularized Gram matrix is
    decomposed with a batched SVD (``gesvda`` driver). If the SVD fails twice,
    a QR-based fallback is used.

    Args:
        :param key_states:
            Tensor of shape ``[N, H, D]`` containing pre-RoPE key states for
            all tokens across the batch, packed along the sequence dimension.
            ``N = sum(context_lens)``.
        :param context_lens:
            List of per-sequence context lengths, length ``B``.
        :param PHI:
            Random projection matrix of shape ``[D, k]`` used to sketch the
            keys into a lower-dimensional subspace (k < D).
        :param regularizer:
            Small positive scalar added to the diagonal of each Gram matrix
            before SVD to improve numerical stability. Defaults to ``5e-3``.
            If the first SVD attempt fails, a further ``10 * regularizer`` is
            added before retrying.
        :param normalize:
            If True, z-score the scores in place. The statistics are taken
            over all heads and tokens of one chunk (one kernel program per
            entry of the per-chunk length list), not over a whole sequence,
            unless ``chunk_size <= 0``.
        :param chunk_size:
            Target chunk size along the sequence dimension. If > 0, the
            concatenated sequence is split into chunks of at most this size
            before forming Gram matrices and SVD. If ≤ 0, the entire sequence
            for each context is treated as a single chunk.
    Returns:
        :return torch.Tensor:
            Approximate leverage scores of shape ``[N, H]``, where each row
            corresponds to a token and each column to a head.
    """
    if chunk_size > 0:
        coalesced_chunk_lens, chunks_lens = split_into_chunks(context_lens, chunk_size)
    else:
        coalesced_chunk_lens, chunks_lens = context_lens, context_lens
    # The leading 0 turns the later cumsum into prefix offsets. The tensor is
    # placed on the keys' device: a bare ``.cuda()`` would target the *current*
    # CUDA device, which can differ from ``key_states.device`` on multi-GPU hosts.
    chunk_lens_cuda = torch.tensor([0, *chunks_lens]).to(
        device=key_states.device, non_blocking=True
    )
    # Sketch: [H, N, D] @ [D, k] -> [H, N, k]
    X = torch.matmul(key_states.transpose(0, 1).contiguous(), PHI.contiguous())
    H, N, k = X.shape
    chunks = torch.split(X, coalesced_chunk_lens, dim=-2)
    gram_matrices = []
    for i, L in enumerate(coalesced_chunk_lens):
        chunk = chunks[i]
        if chunk_size <= 0 or L % chunk_size != 0:
            # Tail (or whole-sequence) segment: a single Gram matrix.
            # ``sub_`` centers the view in place, i.e. it modifies ``X``.
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2).contiguous(), chunk.contiguous())
            g = g.unsqueeze(1)
        else:
            # Run of full chunks: one Gram matrix per chunk.
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, num_chunks, chunk_size, k]
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2).contiguous(), chunk.contiguous())
        gram_matrices.append(g)
    G = torch.cat(gram_matrices, dim=1).to(torch.float32)
    diag = G.diagonal(dim1=-2, dim2=-1)
    diag.add_(regularizer)
    try:
        V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
    except RuntimeError:
        try:
            # Retry with a stronger diagonal regularization.
            diag = G.diagonal(dim1=-2, dim2=-1)
            diag.add_(regularizer * 10)
            V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
        except RuntimeError:
            with logging_redirect_tqdm():
                logger.warning(
                    "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
                    "Try increasing chunk_size if this issue persists."
                )
            # this is over 50 times slower than using GESVDA
            return _approximate_leverage_scores_qr_fallback(
                X=X,
                chunks_lens=chunks_lens,
                chunk_lens_cuda=chunk_lens_cuda,
                normalize=normalize,
                chunk_size=chunk_size,
            )
    # Scale the singular vectors by S^(-1/2); the squared row norm of
    # ``chunk @ SV`` is the score.
    SV = (V * S.rsqrt().unsqueeze(-2)).to(X.dtype)
    start = 0
    all_scores = []
    for i, L in enumerate(coalesced_chunk_lens):
        chunk = chunks[i]
        if chunk_size <= 0 or L % chunk_size != 0:
            num_chunks = 1
            sv = SV[:, start]
        else:
            num_chunks = L // chunk_size
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, NC, CS, k]
            sv = SV[:, start : start + num_chunks]
        U = torch.matmul(chunk.contiguous(), sv.contiguous())
        scores = (U * U).sum(dim=-1).clamp_min_(0.0).view(H, -1)
        all_scores.append(scores.transpose(-1, -2))
        start += num_chunks

    scores = torch.cat(all_scores, dim=0)
    if normalize:
        # One program per chunk; cu_k holds the chunk start/end offsets.
        grid = (len(chunks_lens),)
        cu_k = chunk_lens_cuda.cumsum(dim=0)
        _zscore_per_batch_epilogue_no_window[grid](
            scores, cu_k, scores.stride(0), scores.stride(1), H
        )
    return scores
+
+
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue_no_window(
    OUT,  # [Nk, Hk], float32
    cu_k,  # [B+1] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    """In-place z-score of ``OUT`` rows ``cu_k[b] .. cu_k[b+1]`` over all heads.

    One program per segment ``b``. Mean and (population) variance are taken
    over every token of the segment and every head, then the segment is
    rewritten as ``(x - mean) / std``. The standard deviation is floored at
    ``1e-6`` so a segment whose values are all equal normalizes to 0 instead
    of producing inf/NaN from a division by zero.
    """
    b = tl.program_id(0)

    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    if k_end <= k_beg:
        # Empty segment: nothing to normalize.
        return

    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_end - k_beg) * HK).to(tl.float32)

    # Pass 1: accumulate sum and sum of squares.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)

    mean = sumv / count
    # E[x^2] - E[x]^2 can dip below zero through rounding; clamp it.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    # Floor the std: with var == 0 the reciprocal would be inf and
    # (x - mean) * inf evaluates to NaN for every element of the segment.
    invstd = 1.0 / tl.maximum(tl.sqrt(var), 1e-6)

    # Pass 2: rewrite the segment in place.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
+
+
+def _approximate_leverage_scores_qr_fallback(
+ X: torch.Tensor, # [H, N, k], already sketched (KΦ) and centered in-place
+ chunks_lens: List[int], # [num_chunks]
+ chunk_lens_cuda: torch.Tensor, # [num_chunks + 1] (prefix base)
+ normalize: bool,
+ chunk_size: int,
+) -> torch.Tensor:
+ H, N, k = X.shape
+ device, dtype = X.device, X.dtype
+ offsets: List[int] = []
+ offset = 0
+ for L in chunks_lens:
+ offsets.append(offset)
+ offset += L
+ if offset != N:
+ raise RuntimeError(
+ f"QR fallback: sum(chunks_lens)={offset} does not match N={N}"
+ )
+
+ blocks = torch.split(X, chunks_lens, dim=-2)
+ scores = torch.empty(N, H, device=device, dtype=dtype)
+ if chunk_size > 0:
+ full_indices = [i for i, L in enumerate(chunks_lens) if L == chunk_size]
+ epi_indices = [i for i, L in enumerate(chunks_lens) if L != chunk_size]
+
+ if full_indices:
+ # stack full chunks
+ full_blocks = torch.stack(
+ [blocks[i] for i in full_indices], dim=0
+ ) # [M, H, CS, k]
+ M, Hf, Lf, kf = full_blocks.shape
+ assert Lf == chunk_size
+
+ # merge (M, H) into a single batch dim for torch.linalg.q
+ full_blocks_2d = full_blocks.view(M * Hf, Lf, kf).to(torch.float32)
+
+ U_full, _ = torch.linalg.qr(full_blocks_2d, mode="reduced")
+ U_full = U_full.to(dtype)
+ scores_full = (U_full * U_full).sum(dim=-1).clamp_min(0.0) # [M * Hf, Lf]
+ scores_full = scores_full.view(M, Hf, Lf).transpose(-1, -2) # [M, H, CS]
+ for m, chunk_idx in enumerate(full_indices):
+ start = offsets[chunk_idx]
+ Lc = chunks_lens[chunk_idx]
+ scores[start : start + Lc].copy_(scores_full[m])
+ else:
+ epi_indices = list(range(len(chunks_lens)))
+
+ for chunk_idx in epi_indices:
+ block = blocks[chunk_idx]
+ _, Lc, _ = block.shape
+ if Lc == 0:
+ continue
+ U_epi, _ = torch.linalg.qr(block.to(torch.float32), mode="reduced")
+ scores_epi = (U_epi * U_epi).sum(dim=-1).to(dtype) # [H, Lc]
+ start = offsets[chunk_idx]
+ scores[start : start + Lc] = scores_epi.transpose(0, 1) # [Lc, H]
+
+ if normalize:
+ grid = (len(chunks_lens),)
+ cu_k = chunk_lens_cuda.cumsum(dim=0)
+ _zscore_per_batch_epilogue_no_window[grid](
+ scores, cu_k, scores.stride(0), scores.stride(1), H
+ )
+ return scores
+
+
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s
        )
        for BM in [64]
        for BK in [64]
        for w in [4]
        for s in [2]
    ],
    key=[
        "QUERY_GROUP_SIZE",
        "D",
        "CHUNK_SIZE",
    ],
    cache_results=True,
)
@triton.jit
def _non_causal_attn_kernel(
    Q,
    K,
    V,
    accum_scores,
    cu_seqlens_qk,
    #
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_Q_D,
    STRIDE_K_G,
    STRIDE_K_N,
    STRIDE_K_D,
    STRIDE_V_G,
    STRIDE_V_N,
    STRIDE_V_D,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    CHUNK_SIZE: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    D: tl.constexpr,
    WARPSPEC: tl.constexpr,
):
    """Chunk-local non-causal attention mass received by every key.

    One program per (KV head, sequence, chunk). Inside the chunk every query
    of the KV head's query group attends to every key of the same chunk; the
    softmax probabilities are summed over queries and query-group members and
    added to ``accum_scores[key, kv_head]``.

    The softmax is computed in two passes over the keys: pass 1 finds the
    per-row maximum and denominator, pass 2 recomputes the logits and
    accumulates the normalized probabilities.

    ``V``, the ``STRIDE_V_*`` arguments and ``WARPSPEC`` are not read in the
    body.
    """
    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE
    pid_g = tl.program_id(0)  # KV head in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # chunk id within batch

    # Token range of this sequence in the packed layout.
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)

    chunk_start = off_b + pid_m * CHUNK_SIZE
    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
    M = chunk_end - chunk_start
    if M <= 0:
        # This chunk index lies beyond the end of the sequence.
        return

    offs_d = tl.arange(0, D)
    offs_k = tl.arange(0, BLOCK_K)

    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
    row_m = offs_q % BLOCK_M  # token offset in this tile
    row_h = offs_q // BLOCK_M  # query-group index

    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
    NEG_INF = -1.0e9

    # Iterate over query tiles within this chunk
    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
        # Global query indices for rows in this tile
        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
        q_mask = q_idx < chunk_end  # mask for valid rows in this tile

        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
        q_ptrs = (
            Q
            + pid_g * STRIDE_Q_G
            + q_idx[:, None] * STRIDE_Q_N
            + row_h[:, None] * STRIDE_Q_H
            + offs_d[None, :] * STRIDE_Q_D
        )
        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)

        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)

        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k  # [BLOCK_K]
            k_mask = k_idx < chunk_end  # which keys are valid in this tile

            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]

            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            cur_max = tl.max(qk, 1)
            new_max = tl.maximum(row_max, cur_max)

            # rescale previous sum to new_max (base 2)
            rescale = tl.math.exp2(row_max - new_max)
            p = tl.math.exp2(qk - new_max[:, None])

            row_sum = row_sum * rescale + tl.sum(p, 1)
            row_max = new_max

        # Avoid division by zero for inactive rows
        denom = tl.where(q_mask, row_sum, 1.0)

        # ---- Pass 2: normalized probabilities, summed over the query rows ----
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k
            k_mask = k_idx < chunk_end

            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)

            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
            # Rows past the end of the chunk are NOT zeroed: each contributes a
            # uniform 1 / CHUNK_SIZE to every key column of the tile.
            p = tl.where(
                q_mask[:, None], p, INVERSE_CHUNK
            )  # preserve attention mass in shorter chunks

            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups

            # Plain load/add/store (not atomic); invalid key columns are
            # excluded by k_mask.
            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
            new = old + contrib.to(old.dtype)
            tl.store(out_ptrs, new, mask=k_mask)
+
+
def non_causal_attn_scores(
    q: torch.Tensor,  # [N, HQ, D]
    k: torch.Tensor,  # [N, HKV, D]
    v: torch.Tensor,  # [N, HKV, D]
    cu_seqlens_qk: torch.Tensor,  # [B + 1]
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: float = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: torch.Tensor = None,  # [N, HKV] (float32)
    accum_blending: float = None,
) -> torch.Tensor:
    """
    Chunked non-causal attention scores per key and KV head.

    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
    :param v: Tensor of shape ``[N, HKV, D]`` containing values (only its strides are
        forwarded; the kernel does not read the values)
    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
    :param max_seqlen_qk: int containing the maximum sequence length (sizes the launch grid)
    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
    :param normalize: bool specifying whether to z-score normalize final attention scores
    :param context_lens: List[int] specifying the context lengths. CPU version of
        ``cu_seqlens_qk.diff()``; derived from ``cu_seqlens_qk`` when omitted and needed
    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
        start of each sequence (treated as 0 when omitted)
    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
        end of each sequence (treated as 0 when omitted)
    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that should be accumulated into
    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding the new
        non-causal attention scores. Final output is equivalent to return out + accum_blending * accum_scores
    :return: float32 tensor ``[N, HKV]``; protected rows are set to ``inf``. Protected counts are
        clamped to the sequence length so they never reach into a neighbouring sequence.
    """
    assert q.ndim == 3 and k.ndim == 3
    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
    N, HQ, D = q.shape
    HKV = k.shape[1]
    assert HQ % HKV == 0, "Number of query heads must be a multiple of the number of KV heads"
    assert (D & (D - 1)) == 0, "D must be a power of two"

    B = cu_seqlens_qk.numel() - 1
    H_g = HQ // HKV  # query-group size per KV head

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)
    # KV-head-major views: q -> [HKV, N, H_g, D], k -> [HKV, N, D]
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    k = k.view(N, HKV, D).permute(1, 0, 2)
    # v = v.view(N, HKV, D).permute(1, 0, 2)

    if cu_seqlens_qk.device != q.device:
        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)

    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
    # v is not permuted above, so these are its original strides.
    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()

    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"

    def grid(_):
        # One program per (KV head, sequence, chunk of the longest sequence).
        return (
            HKV,
            B,
            triton.cdiv(max_seqlen_qk, chunk_size),
        )

    _non_causal_attn_kernel[grid](
        q,
        k,
        v,
        out,
        cu_seqlens_qk,
        STRIDE_Q_G,
        STRIDE_Q_N,
        STRIDE_Q_H,
        STRIDE_Q_D,
        STRIDE_K_G,
        STRIDE_K_N,
        STRIDE_K_D,
        STRIDE_V_G,
        STRIDE_V_N,
        STRIDE_V_D,
        STRIDE_OUT_N,
        STRIDE_OUT_H,
        sm_scale,
        CHUNK_SIZE=chunk_size,
        QUERY_GROUP_SIZE=H_g,
        D=D,
    )
    if normalize:
        zscore_grid = (B,)
        _zscore_per_batch_epilogue_no_window[zscore_grid](
            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
        )
    if accum_scores is not None:
        if accum_blending is not None:
            out += accum_scores * accum_blending
        else:
            out += accum_scores
    if protected_first_tokens is not None or protected_last_tokens is not None:
        # Any of the three lists may be missing; zipping a None would raise.
        if context_lens is None:
            context_lens = cu_seqlens_qk.diff().tolist()
        num_seqs = len(context_lens)
        firsts = protected_first_tokens if protected_first_tokens is not None else [0] * num_seqs
        lasts = protected_last_tokens if protected_last_tokens is not None else [0] * num_seqs
        start = 0
        for first, last, L in zip(firsts, lasts, context_lens):
            L = int(L)
            # Clamp: ``last > L`` would otherwise start the slice inside the
            # previous sequence (or wrap around when negative), and
            # ``first > L`` would run into the next sequence.
            first = max(0, min(int(first), L))
            last = max(0, min(int(last), L))
            end = start + L
            out[start : start + first].fill_(torch.inf)
            out[end - last : end].fill_(torch.inf)
            start = end
    return out
diff --git a/vllm/kvprune/compression/compression_config.py b/vllm/kvprune/compression/compression_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e861e663644b0ff6e9d0d2641e6940e6514bf6f3
--- /dev/null
+++ b/vllm/kvprune/compression/compression_config.py
@@ -0,0 +1,45 @@
+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
+
+logger = logging.getLogger(__name__)
+
+
class CompressionMethod(Enum):
    """Identifiers of the available KV-cache compression methods."""

    # Member values come from ``auto()`` and therefore follow declaration order.
    CRITICALADAKV = auto()
    COMPACTOR = auto()
    SNAPKV = auto()
    # Presumably selects "no compression" — confirm against the engine code.
    NONE = auto()
+
+
+# class CachingPolicy(Enum):
+# CACHE_PROMPT = auto()
+# DONT_CACHE = auto()
+
+
+# class CompressionType(Enum):
+# QUERY_AWARE = auto()
+# QUERY_AGNOSTIC = auto()
+
+
@dataclass
class SequenceCompressionParams:
    """Per-sequence compression settings."""

    # NOTE(review): whether this is the retained or the removed fraction is not
    # visible here — confirm against the code that consumes it.
    compression_ratio: float = 1.0
    # Number of tokens at the start of the sequence to protect
    # (presumably the scoring functions' ``protected_first_tokens`` — confirm).
    protected_first_tokens: int = 16
    # Number of tokens at the end of the sequence to protect
    # (presumably the scoring functions' ``protected_last_tokens`` — confirm).
    protected_last_tokens: int = 64
+
+
@dataclass
class BatchCompressionParams:
    """Batch-level compression settings.

    ``CompressionMethod.SNAPKV`` cannot be combined with chunked compression:
    if both are requested, chunked compression is switched off during
    initialization and a warning is logged. No warning is emitted when chunked
    compression was already disabled by the caller.
    """

    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
    compression_method: CompressionMethod = CompressionMethod.COMPACTOR

    do_chunked_compression: bool = True
    chunk_size: int = 512

    def __post_init__(self):
        if (
            self.compression_method == CompressionMethod.SNAPKV
            and self.do_chunked_compression
        ):
            # Only override (and warn) when the setting actually changes.
            self.do_chunked_compression = False
            logger.warning(
                "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
            )
diff --git a/vllm/kvprune/compression/criticalkv-cursor.py b/vllm/kvprune/compression/criticalkv-cursor.py
new file mode 100644
index 0000000000000000000000000000000000000000..128632e098d7b61a11be85322dc06518c1afbe05
--- /dev/null
+++ b/vllm/kvprune/compression/criticalkv-cursor.py
@@ -0,0 +1,459 @@
+"""
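+CriticalAdaKV: built on top of Compactor (pre-RoPE leverage scores + post-RoPE non-causal
+attention fusion). Stage 2 re-weights the scores with the L1 norm of Value projected through the
+output projection Wo; Stage 1 applies within-budget top-k protection on the Compactor base scores.
+
+The budget matches the vllm.kvprune engine: it uses ``compression_context.batch_tokens_to_retain``
+(the number of flattened (token, head) pairs) together with the lengths of the protected
+first/last segments.
+
+Note: ``vllm.kvprune.utils.context`` must not be loaded at import time (it imports
+``CompressionMethod`` in turn, which forms a cycle with ``compression/__init__.py`` importing this
+module). At runtime only a duck-typed object with the same fields as ``CompressionContext`` is used.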
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.compression.compactor import (
+ CompactorCompression,
+ non_causal_attn_scores,
+)
+from vllm.kvprune.compression.snapkv import SnapKVCompression
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+
+
+# ============================================================================
+# Triton kernel 1: compute ||Wo @ V||₁ (L1 norm)
+# ============================================================================
+# For every key row and KV head, accumulates sum(|V_row @ WO[hq]|) over the
+# QUERY_GROUP_SIZE query heads mapped to that KV head and over all HIDDEN
+# output columns, divides by QUERY_GROUP_SIZE and writes the result to OUT.
+# Launch grid: axis 0 = sequence index into cu_k, axis 1 = KV head,
+# axis 2 = BLOCK_K-sized tile of key rows inside the sequence.
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+ for bk in [32, 64, 128]
+ for bd in [32, 64]
+ for nw in [4, 8]
+ for ns in [3, 4]
+ ],
+ key=["Hk", "D", "HIDDEN"],
+ cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+ V,
+ WO,
+ cu_k,
+ OUT,
+ STRIDE_V_NK,
+ STRIDE_V_HK,
+ STRIDE_V_D,
+ STRIDE_WO_HQ,
+ STRIDE_WO_D,
+ STRIDE_WO_HID,
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ Hk: tl.constexpr,
+ Hq: tl.constexpr,
+ D: tl.constexpr,
+ HIDDEN: tl.constexpr,
+ QUERY_GROUP_SIZE: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ # NOTE(review): Hq is accepted but never read in the body.
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+ ks = tl.program_id(2)
+
+ # Row range [k_beg, k_end) of sequence b in the packed key/value layout.
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+
+ # Rows handled by this program; tiles past the end of the sequence are masked out.
+ nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
+ nk = k_beg + nk_off
+ k_mask = nk < k_end
+
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+ for g in range(QUERY_GROUP_SIZE):
+ # Query head g of the group served by KV head hk.
+ hq = hk * QUERY_GROUP_SIZE + g
+
+ # [BLOCK_K, D] tile of V for KV head hk, upcast to fp32.
+ # NOTE(review): this load does not depend on g; it could be hoisted above the loop.
+ # NOTE(review): tl.arange(0, D) - Triton requires a power-of-two extent; confirm D.
+ v_ptrs = (
+ V
+ + nk[:, None] * STRIDE_V_NK
+ + hk * STRIDE_V_HK
+ + tl.arange(0, D)[None, :] * STRIDE_V_D
+ )
+ v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+
+ # Walk the hidden dimension in BLOCK_D-wide column tiles of WO[hq].
+ for hid_off in range(0, HIDDEN, BLOCK_D):
+ hid_idx = hid_off + tl.arange(0, BLOCK_D)
+ hid_mask = hid_idx < HIDDEN
+
+ wo_ptrs = (
+ WO
+ + hq * STRIDE_WO_HQ
+ + tl.arange(0, D)[:, None] * STRIDE_WO_D
+ + hid_idx[None, :] * STRIDE_WO_HID
+ )
+ wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+ # [BLOCK_K, BLOCK_D] partial product; masked columns are zero and add nothing.
+ wov_tile = tl.dot(v_blk, wo_tile)
+ l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
+
+ # Mean over the query heads of the group.
+ l1_sum = l1_sum / QUERY_GROUP_SIZE
+ tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+# ============================================================================
+# Triton kernel 2: Stage 1 protection + Stage 2 weighted fusion
+# ============================================================================
+# Writes OUT = (BASE_SCORES + EPSILON) * WO_V_NORM, except that entries whose
+# STAGE1_MASK value equals 1 are written as +inf.
+# Launch grid: axis 0 = sequence index into cu_k, axis 1 = KV head; each
+# program walks its whole sequence in BLOCK_K-sized steps.
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
+ key=["Hk"],
+ cache_results=True,
+)
+@triton.jit
+def _critical_ada_fuse_kernel(
+ BASE_SCORES,
+ WO_V_NORM,
+ STAGE1_MASK,
+ cu_k,
+ OUT,
+ EPSILON: tl.constexpr,
+ STRIDE_BS_NK,
+ STRIDE_BS_HK,
+ STRIDE_WN_NK,
+ STRIDE_WN_HK,
+ STRIDE_S1_NK,
+ STRIDE_S1_HK,
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ Hk: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ # NOTE(review): EPSILON sits between OUT and the stride parameters, so a caller that
+ # passes the strides positionally must also pass EPSILON positionally (before them).
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+
+ # Row range [k_beg, k_end) of sequence b in the packed layout.
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+
+ for ks in tl.range(k_beg, k_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_end
+
+ bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
+ wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
+ s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
+
+ # Out-of-range lanes get neutral fill values; they are never stored.
+ base = tl.load(bs_ptrs, mask=kmask, other=0.0)
+ wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
+ stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
+
+ # Stage 2 re-weighting, then force Stage 1 protected entries to +inf.
+ fused = (base + EPSILON) * wnorm
+ fused = tl.where(stage1_protect == 1, float("inf"), fused)
+
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(out_ptrs, fused, mask=kmask)
+
+
+def critical_ada_key_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    wo_weight: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    base_scores: torch.Tensor,
+    compression_ctx: Any,
+    *,
+    store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+    """Re-score keys with the two-stage CriticalAda scheme on top of ``base_scores``.
+
+    For every packed sequence ``b`` (rows ``cu_seqlens[b]:cu_seqlens[b + 1]``), and
+    restricted to the rows left after removing the protected first/last tokens:
+
+    1. Safeguard: per KV head, protect the top ``int(n_kept * alpha_safeguard)``
+       base scores, where ``n_kept = min(max(1, keep_pairs // Hk), region_len)``.
+    2. Head budgets: take the top ``n_kept * Hk`` flattened (token, head) scores
+       (safeguarded entries raised to ``inf``) and count the hits per head.
+    3. Stage 1: per head, protect the top ``int(budget * first_stage_ratio)``
+       base scores.
+    4. Fusion (Triton): ``(base + epsilon) * ||Wo @ V||_1``; protected entries
+       become ``inf``.
+    5. Stage 2: per head, raise the top ``budget`` fused scores to ``inf``.
+    6. The ``total_pairs - keep_pairs`` lowest-scoring (token, head) pairs of the
+       whole sequence are reported as pruned.
+
+    Args:
+        q: Queries; last dim must be contiguous. Only shape and device are read.
+        k: Keys, shape ``(N_k, Hk, D)``; last dim must be contiguous. Only the
+            shape is read.
+        v: Values consumed by the ``||Wo @ V||_1`` kernel; last dim contiguous.
+        wo_weight: Output projection, either 2-D ``(hidden, Hq * D)`` or already
+            reshaped to ``(Hq, D, hidden)``.
+        cu_seqlens: Cumulative sequence offsets (``B + 1`` entries). Must be the
+            same offsets that were used to produce ``base_scores``.
+        base_scores: 2-D base scores indexed ``[row, kv_head]``.
+        compression_ctx: Duck-typed context. ``batch_tokens_to_retain``,
+            ``protected_first_tokens``, ``protected_last_tokens``,
+            ``critical_ada_epsilon`` and ``critical_ada_first_stage_ratio`` are
+            read unconditionally; ``critical_ada_alpha_safeguard`` defaults to 0.2.
+        store_stream: When given, ``final_scores`` is registered on that stream
+            via ``record_stream``.
+
+    Returns:
+        ``(final_scores, masked_key_indices)``. ``final_scores`` is a float32
+        ``(N_k, Hk)`` tensor. ``masked_key_indices`` is ``None`` when nothing is
+        pruned, otherwise ``(batch_idx, head_idx, seq_idx)`` where ``seq_idx`` is
+        the absolute row index in the packed layout.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk and Hq % Hk == 0
+
+    # Same cu as non_causal_attn_scores (context.cu_seqlens_q during prefill) so
+    # that base_scores rows line up with the Triton segments; do not mix with
+    # cu_seqlens_k.
+    B = cu_seqlens.numel() - 1
+    G = Hq // Hk
+
+    btr = compression_ctx.batch_tokens_to_retain
+    assert btr is not None and btr.numel() == B
+
+    # Copy the offsets and budgets to the host once instead of issuing several
+    # `.item()` device synchronisations per sequence.
+    seq_bounds = cu_seqlens.tolist()
+    keep_pairs_per_seq = btr.to(dtype=torch.int32).tolist()
+    max_k_len = max(
+        (seq_bounds[b + 1] - seq_bounds[b] for b in range(B)), default=0
+    )
+
+    prot_first = compression_ctx.protected_first_tokens or [0] * B
+    prot_last = compression_ctx.protected_last_tokens or [0] * B
+    epsilon = compression_ctx.critical_ada_epsilon
+    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+    alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2))
+    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
+
+    if wo_weight.dim() == 2:
+        hidden_size, _ = wo_weight.shape
+        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+    else:
+        wo = wo_weight.contiguous()
+        hidden_size = wo.size(-1)
+
+    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    # An empty batch (or all-empty sequences) would make `max()` fail or produce
+    # a zero-sized launch grid, so both kernels are skipped in that case.
+    has_rows = max_k_len > 0
+
+    if has_rows:
+
+        def grid_wo(META):
+            return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+        _compute_wo_v_l1_kernel[grid_wo](
+            v,
+            wo,
+            cu_seqlens,
+            wo_v_norm,
+            *v.stride(),
+            *wo.stride(),
+            *wo_v_norm.stride(),
+            Hk=Hk,
+            Hq=Hq,
+            D=D,
+            HIDDEN=hidden_size,
+            QUERY_GROUP_SIZE=G,
+        )
+
+    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
+    # (lo, hi, per-head budgets) for every sequence that has a compressible
+    # region; consumed by Stage 2 after the fusion kernel.
+    stage2_regions = []
+
+    for b in range(B):
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        if k_end - k_beg == 0:
+            continue
+        s = int(prot_first[b]) if b < len(prot_first) else 0
+        e = int(prot_last[b]) if b < len(prot_last) else 0
+        lo, hi = k_beg + s, k_end - e
+        compressible = max(0, hi - lo)
+        if compressible <= 0:
+            continue
+        keep_pairs = keep_pairs_per_seq[b]
+        # Per-head token budget (kvpress `n_kept`).
+        n_kept_tokens = min(max(1, keep_pairs // Hk), compressible)
+        # Safeguard budget: every head keeps at least n_safe tokens.
+        n_safe = int(n_kept_tokens * alpha_safeguard)
+        if n_safe > 0:
+            tk_safe = min(n_safe, compressible)
+            for hk in range(Hk):
+                safe_idx = torch.topk(base_scores[lo:hi, hk], tk_safe, sorted=False).indices
+                stage1_mask[lo + safe_idx, hk] = 1
+
+        # Adaptive allocation: take the top n_kept_tokens * Hk entries in the
+        # flattened (token, head) space and count how many land on each head.
+        budget_scores = base_scores[lo:hi, :].clone()
+        if n_safe > 0:
+            budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf")
+        top_pairs = min(n_kept_tokens * Hk, budget_scores.numel())
+        if top_pairs <= 0:
+            continue
+        top_idx_flat = torch.topk(
+            budget_scores.reshape(-1), top_pairs, sorted=False
+        ).indices
+        # Row-major flattening of [tokens, Hk]: the head is the remainder.
+        top_head_idx = top_idx_flat % Hk
+        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32).tolist()
+        stage2_regions.append((lo, hi, head_budgets))
+
+        # Stage 1: per head, protect first_stage_ratio of that head's budget.
+        for hk, head_budget in enumerate(head_budgets):
+            phase1_budget = int(head_budget * first_stage_ratio)
+            if phase1_budget <= 0:
+                continue
+            tk = min(phase1_budget, compressible)
+            top_idx = torch.topk(base_scores[lo:hi, hk], tk, sorted=False).indices
+            stage1_mask[lo + top_idx, hk] = 1
+
+    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    if has_rows:
+
+        def grid_fuse(_META):
+            return (B, Hk)
+
+        # EPSILON is the 6th positional parameter of the kernel, directly before
+        # the strides. It must therefore be passed positionally here: passing it
+        # by keyword while unpacking the strides positionally would bind the
+        # first stride to EPSILON as well and leave STRIDE_OUT_HK unset.
+        _critical_ada_fuse_kernel[grid_fuse](
+            base_scores,
+            wo_v_norm,
+            stage1_mask,
+            cu_seqlens,
+            final_scores,
+            epsilon,
+            *base_scores.stride(),
+            *wo_v_norm.stride(),
+            *stage1_mask.stride(),
+            *final_scores.stride(),
+            Hk=Hk,
+        )
+
+    # Stage 2 (kvpress semantics): after fusion, protect the top `budget`
+    # entries of every head once more.
+    for lo, hi, head_budgets in stage2_regions:
+        region_len = hi - lo
+        for hk, budget in enumerate(head_budgets):
+            if budget <= 0:
+                continue
+            tk = min(budget, region_len)
+            idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices
+            final_scores[lo + idx, hk] = float("inf")
+
+    # Collect the pruned pairs per sequence and concatenate once at the end.
+    pruned_batch, pruned_head, pruned_seq = [], [], []
+    for b in range(B):
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        k_len = k_end - k_beg
+        if k_len == 0:
+            continue
+        keep_pairs = keep_pairs_per_seq[b]
+        total_pairs = k_len * Hk
+        if keep_pairs >= total_pairs:
+            continue
+        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+        if n_prune_pairs <= 0:
+            continue
+
+        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
+        # Lowest scores = largest negated scores.
+        prune_idx = torch.topk(
+            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+        ).indices
+        pruned_batch.append(torch.full_like(prune_idx, b, dtype=torch.int64))
+        pruned_head.append(prune_idx % Hk)
+        pruned_seq.append(prune_idx // Hk + k_beg)
+
+    masked_key_indices = None
+    if pruned_batch:
+        masked_key_indices = (
+            torch.cat(pruned_batch),
+            torch.cat(pruned_head),
+            torch.cat(pruned_seq),
+        )
+
+    if store_stream is not None:
+        final_scores.record_stream(store_stream)
+
+    return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """CriticalAda re-weighting layered on a Compactor or SnapKV base scorer.
+
+    The base scorer is chosen from ``compression_context.critical_ada_base_scorer``:
+    ``"snapkv"`` (case-insensitive) selects SnapKV, anything else Compactor.
+    The CriticalAda stage runs only when ``compression_context.wo_weight`` is set;
+    otherwise the base scores are returned unchanged.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Delegate pre-RoPE scoring to the configured base scorer."""
+        cc = context.compression_context
+        scorer_name = "compactor"
+        if cc is not None:
+            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
+        if str(scorer_name).lower() != "snapkv":
+            return CompactorCompression.pre_rope_scoring(q, k, v, context)
+        return SnapKVCompression.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Compute base post-RoPE scores, then apply CriticalAda if Wo is available."""
+        cc = context.compression_context
+        assert cc is not None
+        scorer_name = str(getattr(cc, "critical_ada_base_scorer", "compactor")).lower()
+
+        if scorer_name != "snapkv":
+            # Mirrors CompactorCompression.post_rope_scoring in compactor.py call
+            # for call (maybe_execute_in_stream(non_causal_attn_scores, ...)); any
+            # other wrapper would make the scores diverge from plain COMPACTOR.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            base_scores = maybe_execute_in_stream(
+                non_causal_attn_scores,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                sm_scale=1.0,
+                normalize=True,
+                accum_scores=pre_rope_scores,
+                context_lens=cc.context_lens,
+                protected_first_tokens=cc.protected_first_tokens,
+                protected_last_tokens=cc.protected_last_tokens,
+                accum_blending=0.5,
+            )
+        else:
+            base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
+
+        wo_weight = cc.wo_weight
+        if wo_weight is not None:
+            rescored, _pruned = maybe_execute_in_stream(
+                critical_ada_key_scores,
+                q,
+                k,
+                v,
+                wo_weight,
+                context.cu_seqlens_q,
+                base_scores,
+                cc,
+                STORE_STREAM=context.STORE_STREAM,
+                store_stream=context.STORE_STREAM,
+            )
+            return rescored
+        # Without an injected Wo the base scores are the final answer.
+        return base_scores
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """Optionally pre-compute and cache Wo on ``module``.
+
+        Stores a float32 ``(num_heads, head_dim, hidden)`` view of
+        ``module.o_proj.weight`` as ``module._critical_ada_wo_weight``. Does
+        nothing when the module lacks ``o_proj`` (or its weight), ``num_heads``
+        or ``head_dim``. ``dtype`` is accepted but not used. Inference itself
+        relies on the ``cc.wo_weight`` injected in ``Attention.forward``.
+        """
+        projection_ready = hasattr(module, "o_proj") and module.o_proj.weight is not None
+        if not projection_ready:
+            return
+        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
+            return
+        weight = module.o_proj.weight.data
+        hidden_size, _ = weight.shape
+        per_head = weight.transpose(0, 1).view(module.num_heads, module.head_dim, hidden_size)
+        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
+
diff --git a/vllm/kvprune/compression/criticalkv.py b/vllm/kvprune/compression/criticalkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..d04b7bb1a9824f551ecf0ee2f2fdd07765a47db6
--- /dev/null
+++ b/vllm/kvprune/compression/criticalkv.py
@@ -0,0 +1,451 @@
+"""
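+CriticalAdaKV: built on top of Compactor (pre-RoPE leverage scores + post-RoPE non-causal
+attention fusion). Stage 2 re-weights the scores with the L1 norm of Value projected through the
+output projection Wo; Stage 1 applies within-budget top-k protection on the Compactor base scores.
+
+The budget matches the vllm.kvprune engine: it uses ``compression_context.batch_tokens_to_retain``
+(the number of flattened (token, head) pairs). The main CriticalAda pipeline runs in **PyTorch**
+and is aligned with kvpress ``CriticalAdaKVPress.compress``; ``||Wo@V||_1`` still uses the Triton
+``_compute_wo_v_l1_kernel`` by default (same formula as ``CriticalKVPress.vwl1norm``).
+Set ``_USE_WO_L1_REFERENCE_BACKEND`` to ``True`` to use ``_vwl1_norm_kvpress_reference`` instead.
+
+Note: ``vllm.kvprune.utils.context`` must not be loaded at import time (it imports
+``CompressionMethod`` in turn, which forms a cycle with ``compression/__init__.py`` importing this
+module). At runtime only a duck-typed object with the same fields as ``CompressionContext`` is used.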
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.compression.compactor import (
+ CompactorCompression,
+ kvpress_compactor_post_rope,
+ resolve_kvpress_compactor_blending,
+)
+from vllm.kvprune.compression.snapkv import SnapKVCompression
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+# Backend for the L1 norm of Wo@V: False = Triton (default), True = PyTorch reference (debugging/alignment)
+_USE_WO_L1_REFERENCE_BACKEND = False
+
+
+def _vwl1_norm_kvpress_reference(
+    values_seg: torch.Tensor,
+    wo: torch.Tensor,
+    num_kv_heads: int,
+    num_query_groups: int,
+) -> torch.Tensor:
+    """PyTorch reference for the ``||V @ Wo||_1`` weights (verification only).
+
+    Selected when ``_USE_WO_L1_REFERENCE_BACKEND`` is ``True``; Triton is the
+    default path. Algorithm: repeat every KV head for its query group, take the
+    L1 norm of ``V @ Wo_h`` per query head, then average over the group.
+
+    Args:
+        values_seg: Values of one sequence, shape ``(k_len, Hk, D)``.
+        wo: Per-query-head output projection, shape ``(Hq, D, hidden)``.
+        num_kv_heads: Must equal ``Hk``.
+        num_query_groups: Query heads per KV head; ``Hq == Hk * num_query_groups``.
+
+    Returns:
+        Contiguous ``(k_len, Hk)`` tensor in the dtype of ``wo``.
+    """
+    k_len, Hk, D = values_seg.shape
+    Hq, D_wo, hidden = wo.shape
+    assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
+    # [1, Hk, k_len, D], then each KV head repeated consecutively for its query
+    # group: the same layout HF `repeat_kv` produces.
+    kv_major = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
+    per_query_head = kv_major.repeat_interleave(num_query_groups, dim=1)
+    # Wo is injected as float32 while V is often bf16/fp16; align dtypes before
+    # the matmul.
+    l1_per_head = [
+        torch.norm(
+            per_query_head[0, head, :, :].to(dtype=wo.dtype).matmul(wo[head, :, :]),
+            p=1,
+            dim=-1,
+        )
+        for head in range(Hq)
+    ]
+    grouped = torch.stack(l1_per_head, dim=0).view(Hk, num_query_groups, k_len)
+    return grouped.mean(dim=1).transpose(0, 1).contiguous()
+
+
+# ============================================================================
+# Triton: ||Wo @ V||₁ as defined by kvpress (per-query-head L1 over the GQA group, then the mean)
+# ============================================================================
+# Launch grid: axis 0 = sequence index into cu_k, axis 1 = KV head,
+# axis 2 = BLOCK_K-sized tile of key rows inside the sequence.
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+ for bk in [32, 64, 128]
+ for bd in [32, 64]
+ for nw in [4, 8]
+ for ns in [3, 4]
+ ],
+ key=["Hk", "D", "HIDDEN"],
+ cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+ V,
+ WO,
+ cu_k,
+ OUT,
+ STRIDE_V_NK,
+ STRIDE_V_HK,
+ STRIDE_V_D,
+ STRIDE_WO_HQ,
+ STRIDE_WO_D,
+ STRIDE_WO_HID,
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ Hk: tl.constexpr,
+ Hq: tl.constexpr,
+ D: tl.constexpr,
+ HIDDEN: tl.constexpr,
+ QUERY_GROUP_SIZE: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ """For each KV head: compute ``sum(|V @ Wo|)`` for each of the G query heads, then divide by G (matches the kvpress mean)."""
+ # NOTE(review): Hq is accepted but never read in the body.
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+ ks = tl.program_id(2)
+
+ # Row range [k_beg, k_end) of sequence b in the packed key/value layout.
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+
+ # Rows handled by this program; tiles past the end of the sequence are masked out.
+ nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
+ nk = k_beg + nk_off
+ k_mask = nk < k_end
+
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+ for g in range(QUERY_GROUP_SIZE):
+ # Query head g of the group served by KV head hk.
+ hq = hk * QUERY_GROUP_SIZE + g
+
+ # [BLOCK_K, D] tile of V for KV head hk, upcast to fp32.
+ # NOTE(review): this load does not depend on g; it could be hoisted above the loop.
+ # NOTE(review): tl.arange(0, D) - Triton requires a power-of-two extent; confirm D.
+ v_ptrs = (
+ V
+ + nk[:, None] * STRIDE_V_NK
+ + hk * STRIDE_V_HK
+ + tl.arange(0, D)[None, :] * STRIDE_V_D
+ )
+ v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+
+ # Walk the hidden dimension in BLOCK_D-wide column tiles of WO[hq].
+ for hid_off in range(0, HIDDEN, BLOCK_D):
+ hid_idx = hid_off + tl.arange(0, BLOCK_D)
+ hid_mask = hid_idx < HIDDEN
+
+ wo_ptrs = (
+ WO
+ + hq * STRIDE_WO_HQ
+ + tl.arange(0, D)[:, None] * STRIDE_WO_D
+ + hid_idx[None, :] * STRIDE_WO_HID
+ )
+ wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+ # [BLOCK_K, BLOCK_D] partial product; masked columns are zero and add nothing.
+ wov_tile = tl.dot(v_blk, wo_tile)
+ l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
+
+ # Mean over the query heads of the group.
+ l1_sum = l1_sum / QUERY_GROUP_SIZE
+ tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+def critical_ada_key_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    wo_weight: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    base_scores: torch.Tensor,
+    compression_ctx: Any,
+    *,
+    store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+    """Apply the CriticalAda re-scoring to ``base_scores`` in kvpress order.
+
+    Per packed sequence ``b`` (rows ``cu_seqlens[b]:cu_seqlens[b + 1]``), following
+    kvpress ``CriticalAdaKVPress.compress``:
+
+    1. Safeguard: per head, raise the top ``int(n_kept * alpha_safeguard)`` scores
+       to ``finfo(float32).max``, with
+       ``n_kept = min(max(1, keep_pairs // Hk), k_len)``.
+    2. Head budgets: top ``n_kept * Hk`` entries of the head-major flattened
+       scores, counted per head.
+    3. Stage 1: per head, raise the top ``int(budget * first_stage_ratio)`` of the
+       already safeguarded scores.
+    4. Re-weight: ``(scores + epsilon) * ||Wo @ V||_1``.
+    5. Stage 2: per head, raise the top ``budget`` re-weighted scores.
+    6. The ``total_pairs - keep_pairs`` lowest entries of the head-major flattened
+       result are reported as pruned.
+
+    ``||Wo @ V||_1`` comes from the Triton kernel unless
+    ``_USE_WO_L1_REFERENCE_BACKEND`` is set; only the base scores come from
+    Compactor/SnapKV.
+
+    Args:
+        q: Queries; last dim must be contiguous. Only shape and device are read.
+        k: Keys, shape ``(N_k, Hk, D)``; last dim must be contiguous. Only the
+            shape is read.
+        v: Values used for the ``||Wo @ V||_1`` weights; last dim contiguous.
+        wo_weight: Output projection, either 2-D ``(hidden, Hq * D)`` or already
+            reshaped to ``(Hq, D, hidden)``.
+        cu_seqlens: Cumulative sequence offsets (``B + 1`` entries). Must be the
+            same offsets that were used to produce ``base_scores``.
+        base_scores: 2-D base scores indexed ``[row, kv_head]``.
+        compression_ctx: Duck-typed context; ``batch_tokens_to_retain``,
+            ``critical_ada_epsilon``, ``critical_ada_first_stage_ratio`` and
+            ``critical_ada_alpha_safeguard`` are all read unconditionally.
+        store_stream: When given, ``final_scores`` is registered on that stream
+            via ``record_stream``.
+
+    Returns:
+        ``(final_scores, masked_key_indices)``. ``final_scores`` is a float32
+        ``(N_k, Hk)`` tensor. ``masked_key_indices`` is ``None`` when nothing is
+        pruned, otherwise ``(batch_idx, head_idx, seq_idx)`` where ``seq_idx`` is
+        the absolute row index in the packed layout.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk and Hq % Hk == 0
+
+    # Same cu as non_causal_attn_scores (context.cu_seqlens_q during prefill) so
+    # that base_scores rows line up with the Triton segments; do not mix with
+    # cu_seqlens_k.
+    B = cu_seqlens.numel() - 1
+    G = Hq // Hk
+
+    btr = compression_ctx.batch_tokens_to_retain
+    assert btr is not None and btr.numel() == B
+
+    # Copy the offsets and budgets to the host once instead of issuing several
+    # `.item()` device synchronisations per sequence.
+    seq_bounds = cu_seqlens.tolist()
+    keep_pairs_per_seq = btr.to(dtype=torch.int32).tolist()
+    max_k_len = max(
+        (seq_bounds[b + 1] - seq_bounds[b] for b in range(B)), default=0
+    )
+
+    epsilon = compression_ctx.critical_ada_epsilon
+    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+    alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
+    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
+
+    if wo_weight.dim() == 2:
+        hidden_size, _ = wo_weight.shape
+        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+    else:
+        wo = wo_weight.contiguous()
+        hidden_size = wo.size(-1)
+
+    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+    if max_k_len > 0:
+        if _USE_WO_L1_REFERENCE_BACKEND:
+            for b in range(B):
+                k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+                if k_end <= k_beg:
+                    continue
+                v_seg = v[k_beg:k_end, :, :].contiguous()
+                wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
+                    v_seg, wo, Hk, G
+                )
+        else:
+
+            def grid_wo(META):
+                return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+            _compute_wo_v_l1_kernel[grid_wo](
+                v,
+                wo,
+                cu_seqlens,
+                wo_v_norm,
+                *v.stride(),
+                *wo.stride(),
+                *wo_v_norm.stride(),
+                Hk=Hk,
+                Hq=Hq,
+                D=D,
+                HIDDEN=hidden_size,
+                QUERY_GROUP_SIZE=G,
+            )
+
+    # kvpress raises scores to finfo.max; topk behaves the same when mixed with inf.
+    _score_max = float(torch.finfo(torch.float32).max)
+
+    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    for b in range(B):
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        k_len = k_end - k_beg
+        if k_len == 0:
+            continue
+
+        scores_seg = base_scores[k_beg:k_end, :].float()
+        wn = wo_v_norm[k_beg:k_end, :]
+        keep_pairs = keep_pairs_per_seq[b]
+        n_kept_tokens = min(max(1, keep_pairs // Hk), k_len)
+
+        # Layout [k_len, Hk]; top-k along dim 0 corresponds to kvpress' top-k
+        # along the token axis of [bsz=1, H, k_len].
+        scores_work = scores_seg.clone()
+
+        # --- Alpha safeguard ---
+        n_safe = int(n_kept_tokens * alpha_safeguard)
+        nk = min(n_safe, k_len) if n_safe > 0 else 0
+        if nk > 0:
+            for hk in range(Hk):
+                top_idx = torch.topk(scores_work[:, hk], nk, dim=0, largest=True).indices
+                scores_work[top_idx, hk] = _score_max
+
+        # --- Head budgets: head-major flattening (flat = h * k_len + t) ---
+        top_pairs = min(n_kept_tokens * Hk, k_len * Hk)
+        if top_pairs <= 0:
+            # Nothing to allocate: fall back to the plain re-weighted base scores.
+            final_scores[k_beg:k_end, :] = (scores_seg + epsilon) * wn
+            continue
+
+        budget_flat = scores_work.permute(1, 0).contiguous().reshape(-1)
+        top_idx_flat = torch.topk(
+            budget_flat, top_pairs, largest=True, sorted=False
+        ).indices
+        top_head_idx = top_idx_flat // k_len
+        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int64)
+        # One host copy serves both the Stage 2 per-head loop and its maximum.
+        head_budget_list = head_budgets.tolist()
+
+        # --- Stage 1: top-k along tokens on the already safeguarded scores ---
+        head_selection_budget_1st = (
+            (head_budgets.to(torch.float32) * float(first_stage_ratio))
+            .to(torch.int64)
+            .tolist()
+        )
+        M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
+        mk = min(M1, k_len) if M1 > 0 else 0
+        if mk > 0:
+            top_k_index = torch.topk(scores_work, mk, dim=0, largest=True, sorted=True).indices
+            for hk in range(Hk):
+                phase1_budget = int(head_selection_budget_1st[hk])
+                if phase1_budget <= 0:
+                    continue
+                take = min(phase1_budget, mk)
+                scores_work[top_k_index[:take, hk], hk] = _score_max
+
+        # --- Stage 2 re-weighting ---
+        scores_fused = (scores_work + epsilon) * wn
+
+        # --- Stage 2 scatter ---
+        M2 = max(head_budget_list, default=0)
+        mk2 = min(M2, k_len) if M2 > 0 else 0
+        if mk2 > 0:
+            top_k_index2 = torch.topk(
+                scores_fused, mk2, dim=0, largest=True, sorted=True
+            ).indices
+            for hk, budget in enumerate(head_budget_list):
+                if budget <= 0:
+                    continue
+                take = min(budget, mk2)
+                scores_fused[top_k_index2[:take, hk], hk] = _score_max
+
+        final_scores[k_beg:k_end, :] = scores_fused
+
+    # Collect the pruned pairs per sequence and concatenate once at the end.
+    pruned_batch, pruned_head, pruned_seq = [], [], []
+    for b in range(B):
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        k_len = k_end - k_beg
+        if k_len == 0:
+            continue
+        keep_pairs = keep_pairs_per_seq[b]
+        total_pairs = k_len * Hk
+        if keep_pairs >= total_pairs:
+            continue
+        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+        if n_prune_pairs <= 0:
+            continue
+
+        # Head-major flattening as in kvpress `scores.reshape(bsz, -1)`:
+        # flat = h * k_len + t.
+        flat_scores = (
+            final_scores[k_beg:k_end, :].permute(1, 0).contiguous().reshape(-1)
+        )
+        # Lowest scores = largest negated scores.
+        prune_idx = torch.topk(
+            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+        ).indices
+        pruned_batch.append(torch.full_like(prune_idx, b, dtype=torch.int64))
+        pruned_head.append(prune_idx // k_len)
+        pruned_seq.append(prune_idx % k_len + k_beg)
+
+    masked_key_indices = None
+    if pruned_batch:
+        masked_key_indices = (
+            torch.cat(pruned_batch),
+            torch.cat(pruned_head),
+            torch.cat(pruned_seq),
+        )
+
+    if store_stream is not None:
+        final_scores.record_stream(store_stream)
+
+    return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """CriticalAda re-weighting layered on a Compactor or SnapKV base scorer.
+
+    Only ``critical_ada_base_scorer == "compactor"`` takes the
+    ``kvpress_compactor_post_rope`` path (kvpress ``CompactorPress.score``:
+    ``blending * l_scores + attn_scores``); any other value uses the SnapKV
+    scorer. CriticalAda is then stacked on top when
+    ``compression_context.wo_weight`` has been injected by Attention before the
+    post-RoPE step; otherwise the base scores are returned unchanged.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Delegate pre-RoPE scoring to the configured base scorer."""
+        cc = context.compression_context
+        scorer_name = "compactor"
+        if cc is not None:
+            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
+        if str(scorer_name).lower() != "compactor":
+            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
+        return CompactorCompression.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Compute base post-RoPE scores, then apply CriticalAda if Wo is available."""
+        cc = context.compression_context
+        assert cc is not None
+        scorer_name = str(getattr(cc, "critical_ada_base_scorer", "compactor")).lower()
+
+        if scorer_name != "compactor":
+            base_scores = SnapKVCompression.post_rope_scoring(
+                q, k, v, pre_rope_scores, context
+            )
+        else:
+            # Special case: identical to CompactorPress.score /
+            # CompactorCompression.post_rope_scoring.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            blending = resolve_kvpress_compactor_blending(cc)
+            base_scores = maybe_execute_in_stream(
+                kvpress_compactor_post_rope,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                pre_rope_scores,
+                cc,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                blending=float(blending),
+                STORE_STREAM=context.STORE_STREAM,
+            )
+
+        wo_weight = cc.wo_weight
+        if wo_weight is not None:
+            rescored, _pruned = maybe_execute_in_stream(
+                critical_ada_key_scores,
+                q,
+                k,
+                v,
+                wo_weight,
+                context.cu_seqlens_q,
+                base_scores,
+                cc,
+                STORE_STREAM=context.STORE_STREAM,
+                store_stream=context.STORE_STREAM,
+            )
+            return rescored
+        # Without an injected Wo the base scores are the final answer.
+        return base_scores
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """Optionally pre-compute and cache Wo on ``module``.
+
+        Stores a float32 ``(num_heads, head_dim, hidden)`` view of
+        ``module.o_proj.weight`` as ``module._critical_ada_wo_weight``. Does
+        nothing when the module lacks ``o_proj`` (or its weight), ``num_heads``
+        or ``head_dim``. ``dtype`` is accepted but not used. Inference itself
+        relies on the ``cc.wo_weight`` injected in ``Attention.forward``.
+        """
+        projection_ready = hasattr(module, "o_proj") and module.o_proj.weight is not None
+        if not projection_ready:
+            return
+        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
+            return
+        weight = module.o_proj.weight.data
+        hidden_size, _ = weight.shape
+        per_head = weight.transpose(0, 1).view(module.num_heads, module.head_dim, hidden_size)
+        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
+
+
diff --git a/vllm/kvprune/compression/criticalkv_origin.py b/vllm/kvprune/compression/criticalkv_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..8534dee35621a9c9171c959cf87bcda52098fc34
--- /dev/null
+++ b/vllm/kvprune/compression/criticalkv_origin.py
@@ -0,0 +1,502 @@
+"""
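+CriticalAdaKV: built on top of Compactor (pre-RoPE leverage scores fused with post-RoPE
+non-causal attention scores). Stage 2 re-weights the scores by the L1 norm of the values
+projected through the output projection Wo; Stage 1 applies in-budget top-k protection on
+the Compactor base scores.
+
+The budget matches the vllm.kvprune engine: it uses ``compression_context.batch_tokens_to_retain``
+(the number of flattened (token, head) pairs). Stages 1/2 follow the kvpress paper/implementation;
+``||Wo@V||_1`` is **algorithmically** the same as ``CriticalKVPress.vwl1norm`` (on GQA, the
+per-query-head L1 norm is averaged over each group). **Triton is used by default**
+(``_compute_wo_v_l1_kernel``); to compare line by line against PyTorch, set the module-level
+``_USE_WO_L1_REFERENCE_BACKEND`` to ``True`` so that ``_vwl1_norm_kvpress_reference`` is used.
+
+Note: ``vllm.kvprune.utils.context`` must not be loaded at import time (it imports
+``CompressionMethod`` in turn, which forms a cycle with ``compression/__init__.py`` importing
+this module). At runtime only duck-typed objects with the same fields as ``CompressionContext``
+are used.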
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.compression.compactor import (
+ CompactorCompression,
+ non_causal_attn_scores,
+)
+from vllm.kvprune.compression.snapkv import SnapKVCompression
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+# L1 norm of Wo@V: False = Triton (default), True = PyTorch reference (debugging / alignment checks)
+_USE_WO_L1_REFERENCE_BACKEND = False
+
+
+def _vwl1_norm_kvpress_reference(
+    values_seg: torch.Tensor,
+    wo: torch.Tensor,
+    num_kv_heads: int,
+    num_query_groups: int,
+) -> torch.Tensor:
+    """Optional PyTorch reference for ``||V @ Wo||_1`` (kvpress ``CriticalKVPress.vwl1norm``).
+
+    Only selected when ``_USE_WO_L1_REFERENCE_BACKEND`` is ``True``; the Triton
+    kernel is the default path. The recipe is: ``repeat_kv`` -> per query head
+    ``|V @ Wo_h|_1`` -> mean over each GQA group.
+
+    Args:
+        values_seg: Values of one sequence, unpacked as ``[k_len, Hk, D]``.
+        wo: Per-query-head output projection, unpacked as ``[Hq, D, hidden]``.
+        num_kv_heads: Expected ``Hk``.
+        num_query_groups: Query heads per KV head (``Hq == Hk * num_query_groups``).
+
+    Returns:
+        A contiguous ``[k_len, Hk]`` tensor of group-averaged L1 norms.
+    """
+    seq_len, kv_heads, head_dim = values_seg.shape
+    q_heads, wo_head_dim, _hidden = wo.shape
+    assert (
+        head_dim == wo_head_dim
+        and kv_heads == num_kv_heads
+        and q_heads == kv_heads * num_query_groups
+    )
+
+    # [1, Hk, k_len, D]: the layout HF ``repeat_kv`` works on.
+    values_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
+    values_rep = repeat_kv(values_4d, num_query_groups)  # [1, Hq, k_len, D]
+
+    # V is cast to Wo's dtype before the matmul (the original note says Wo is
+    # injected as float32 while V is usually bf16/fp16).
+    per_head_l1 = [
+        torch.norm(values_rep[0, h].to(dtype=wo.dtype).matmul(wo[h]), p=1, dim=-1)
+        for h in range(q_heads)
+    ]
+    l1_by_q_head = torch.stack(per_head_l1, dim=0)  # [Hq, k_len]
+    l1_by_kv_head = l1_by_q_head.view(kv_heads, num_query_groups, seq_len).mean(dim=1)
+    return l1_by_kv_head.transpose(0, 1).contiguous()
+
+
+# ============================================================================
+# Triton: ||Wo @ V||₁ as defined by kvpress (per-query-head L1 on GQA, then mean over the group)
+# ============================================================================
+@triton_autotune(
+    configs=[
+        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+        for bk in [32, 64, 128]
+        for bd in [32, 64]
+        for nw in [4, 8]
+        for ns in [3, 4]
+    ],
+    key=["Hk", "D", "HIDDEN"],
+    cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+    V,
+    WO,
+    cu_k,
+    OUT,
+    STRIDE_V_NK,
+    STRIDE_V_HK,
+    STRIDE_V_D,
+    STRIDE_WO_HQ,
+    STRIDE_WO_D,
+    STRIDE_WO_HID,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    Hk: tl.constexpr,
+    Hq: tl.constexpr,
+    D: tl.constexpr,
+    HIDDEN: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """Per KV head: sum ``|V @ Wo_h|`` over the hidden dim for each of the G query heads, then divide by G.
+
+    Grid is ``(sequence, KV head, K tile)``. ``cu_k`` holds the cumulative
+    sequence offsets; each program handles ``BLOCK_K`` rows of one sequence and
+    writes one float32 value per (row, KV head) into ``OUT``.
+    """
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+    ks = tl.program_id(2)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    # The grid is sized for the longest sequence, so tiles that start past the
+    # end of this (shorter) sequence have nothing to compute or store.
+    if k_beg + ks * BLOCK_K >= k_end:
+        return
+
+    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
+    nk = k_beg + nk_off
+    k_mask = nk < k_end
+
+    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+    # The V tile depends only on the KV head, not on the query head inside the
+    # group, so it is loaded once and reused for every ``g``.
+    v_ptrs = (
+        V
+        + nk[:, None] * STRIDE_V_NK
+        + hk * STRIDE_V_HK
+        + tl.arange(0, D)[None, :] * STRIDE_V_D
+    )
+    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+
+    for g in range(QUERY_GROUP_SIZE):
+        hq = hk * QUERY_GROUP_SIZE + g
+
+        # Walk the hidden dimension in BLOCK_D-wide tiles; masked-out columns
+        # load as 0 and therefore add nothing to the L1 sum.
+        for hid_off in range(0, HIDDEN, BLOCK_D):
+            hid_idx = hid_off + tl.arange(0, BLOCK_D)
+            hid_mask = hid_idx < HIDDEN
+
+            wo_ptrs = (
+                WO
+                + hq * STRIDE_WO_HQ
+                + tl.arange(0, D)[:, None] * STRIDE_WO_D
+                + hid_idx[None, :] * STRIDE_WO_HID
+            )
+            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+            wov_tile = tl.dot(v_blk, wo_tile)
+            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
+
+    # Mean over the query heads of the group.
+    l1_sum = l1_sum / QUERY_GROUP_SIZE
+    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+# ============================================================================
+# Triton: Stage 1 protection + Stage 2 weighted fusion (element-wise)
+# ============================================================================
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
+    key=["Hk"],
+    cache_results=True,
+)
+@triton.jit
+def _critical_ada_fuse_kernel(
+    BASE_SCORES,
+    WO_V_NORM,
+    STAGE1_MASK,
+    cu_k,
+    OUT,
+    STRIDE_BS_NK,
+    STRIDE_BS_HK,
+    STRIDE_WN_NK,
+    STRIDE_WN_HK,
+    STRIDE_S1_NK,
+    STRIDE_S1_HK,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    EPSILON: tl.constexpr,
+    Hk: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """Element-wise fusion: ``OUT = (base + EPSILON) * wo_v_norm``, or ``+inf`` where the stage-1 mask is 1.
+
+    Grid is ``(sequence, KV head)``; each program walks its sequence (bounds
+    read from ``cu_k``) in ``BLOCK_K``-row tiles.
+    """
+    seq = tl.program_id(0)
+    head = tl.program_id(1)
+
+    seg_start = tl.load(cu_k + seq)
+    seg_stop = tl.load(cu_k + seq + 1)
+
+    for blk_start in tl.range(seg_start, seg_stop, BLOCK_K):
+        rows = blk_start + tl.arange(0, BLOCK_K)
+        in_seg = rows < seg_stop
+
+        base_val = tl.load(
+            BASE_SCORES + rows * STRIDE_BS_NK + head * STRIDE_BS_HK,
+            mask=in_seg,
+            other=0.0,
+        )
+        wo_norm = tl.load(
+            WO_V_NORM + rows * STRIDE_WN_NK + head * STRIDE_WN_HK,
+            mask=in_seg,
+            other=1.0,
+        )
+        protect = tl.load(
+            STAGE1_MASK + rows * STRIDE_S1_NK + head * STRIDE_S1_HK,
+            mask=in_seg,
+            other=0,
+        ).to(tl.int32)
+
+        # Stage 2 weighting first, then stage-1 protected entries are forced to +inf.
+        weighted = (base_val + EPSILON) * wo_norm
+        result = tl.where(protect == 1, float("inf"), weighted)
+
+        tl.store(OUT + rows * STRIDE_OUT_NK + head * STRIDE_OUT_HK, result, mask=in_seg)
+
+
+def critical_ada_key_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    wo_weight: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    base_scores: torch.Tensor,
+    compression_ctx: Any,
+    *,
+    store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+    """Apply the two-stage CriticalAda re-weighting to per-(token, KV head) base scores.
+
+    The retention budget is the engine's ``batch_tokens_to_retain`` (number of
+    (token, head) pairs per sequence). Per sequence the procedure follows the
+    kvpress ``CriticalAdaKVPress.compress`` recipe on the full ``k_len``; only
+    the base scores come from vllm.kvprune (Compactor / SnapKV).
+
+    Args:
+        q: Queries, unpacked as ``[N_q, Hq, D]`` (only the shape is used).
+        k: Keys, unpacked as ``[N_k, Hk, D]`` (only the shape is used).
+        v: Values, indexed as ``[N_k, Hk, D]``.
+        wo_weight: Output projection, either 2-D ``[hidden, Hq * D]`` or
+            already reshaped to ``[Hq, D, hidden]``.
+        cu_seqlens: Cumulative sequence offsets (``B + 1`` entries). It must be
+            the same offsets the base scores were computed with (in prefill,
+            ``context.cu_seqlens_q``); do not mix with ``cu_seqlens_k``.
+        base_scores: ``[N_k, Hk]`` base scores.
+        compression_ctx: Duck-typed ``CompressionContext``. The attributes
+            ``batch_tokens_to_retain``, ``critical_ada_epsilon``,
+            ``critical_ada_first_stage_ratio`` and
+            ``critical_ada_alpha_safeguard`` are all read unconditionally.
+        store_stream: If given, the result is registered on this stream via
+            ``record_stream``.
+
+    Returns:
+        ``(final_scores, masked_key_indices)``. ``final_scores`` is a float32
+        ``[N_k, Hk]`` tensor in which protected entries are ``+inf``.
+        ``masked_key_indices`` is ``None`` when nothing has to be pruned,
+        otherwise a ``(batch_idx, head_idx, seq_idx)`` tuple of index tensors
+        for the lowest-scoring pairs of every over-budget sequence.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk and Hq % Hk == 0
+
+    B = cu_seqlens.numel() - 1
+    G = Hq // Hk
+
+    btr = compression_ctx.batch_tokens_to_retain
+    assert btr is not None and btr.numel() == B
+
+    # Read the segment bounds and budgets back to the host ONCE; the loops
+    # below previously issued a blocking ``.item()`` per sequence and per head.
+    seq_bounds = [int(x) for x in cu_seqlens.tolist()]
+    keep_pairs_by_seq = [int(x) for x in btr.to(dtype=torch.int32).tolist()]
+    seq_lens = [end - beg for beg, end in zip(seq_bounds, seq_bounds[1:])]
+    max_k_len = max(seq_lens, default=0)
+
+    epsilon = compression_ctx.critical_ada_epsilon
+    first_stage_ratio = float(compression_ctx.critical_ada_first_stage_ratio)
+    # Clamp the safeguard fraction into [0, 1].
+    alpha_safeguard = max(0.0, min(1.0, float(compression_ctx.critical_ada_alpha_safeguard)))
+
+    # Normalise Wo to [Hq, D, hidden].
+    if wo_weight.dim() == 2:
+        hidden_size, _ = wo_weight.shape
+        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+    else:
+        wo = wo_weight.contiguous()
+        hidden_size = wo.size(-1)
+
+    # ---- ||V @ Wo||_1 per (token, KV head) ---------------------------------
+    # NOTE(review): rows outside [cu_seqlens[0], cu_seqlens[-1]) are never
+    # written and stay uninitialised — presumably N_k == cu_seqlens[-1]; confirm.
+    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+    if max_k_len > 0:
+        if _USE_WO_L1_REFERENCE_BACKEND:
+            for k_beg, k_end in zip(seq_bounds, seq_bounds[1:]):
+                if k_end <= k_beg:
+                    continue
+                wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
+                    v[k_beg:k_end, :, :].contiguous(), wo, Hk, G
+                )
+        else:
+
+            def grid_wo(META):
+                # ``max_k_len`` is captured from the host copy, so re-evaluating
+                # the grid (e.g. once per autotune config) costs no device sync.
+                return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+            _compute_wo_v_l1_kernel[grid_wo](
+                v,
+                wo,
+                cu_seqlens,
+                wo_v_norm,
+                *v.stride(),
+                *wo.stride(),
+                *wo_v_norm.stride(),
+                Hk=Hk,
+                Hq=Hq,
+                D=D,
+                HIDDEN=hidden_size,
+                QUERY_GROUP_SIZE=G,
+            )
+
+    # ---- Head budgets + Stage 1 protection mask ----------------------------
+    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
+    # Per sequence: host-side list of per-head budgets, or None when skipped.
+    head_budgets_by_batch: list[Optional[list[int]]] = []
+
+    for b in range(B):
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        k_len = k_end - k_beg
+        if k_len == 0:
+            head_budgets_by_batch.append(None)
+            continue
+        keep_pairs = keep_pairs_by_seq[b]
+        scores_seg = base_scores[k_beg:k_end, :]
+        # Same as kvpress ``n_kept``: tokens kept per head, at least 1, at most k_len.
+        n_kept_tokens = min(max(1, keep_pairs // Hk), k_len)
+
+        # kvpress: top-k indices are taken on the UNMODIFIED scores; the +inf
+        # scatter only touches a copy that is used to derive the head budgets.
+        working = scores_seg.clone()
+        n_safe = int(n_kept_tokens * alpha_safeguard)
+        if n_safe > 0:
+            nk = min(n_safe, k_len)
+            for hk in range(Hk):
+                top_idx = torch.topk(scores_seg[:, hk], nk, sorted=True).indices
+                working[:, hk].scatter_(0, top_idx, float("inf"))
+
+        top_pairs = min(n_kept_tokens * Hk, working.numel())
+        if top_pairs <= 0:
+            head_budgets_by_batch.append(None)
+            continue
+        # ``working`` is [k_len, Hk] flattened row-major, so ``index % Hk`` is the head.
+        top_idx_flat = torch.topk(working.reshape(-1), top_pairs, sorted=False).indices
+        head_budgets = torch.bincount(top_idx_flat % Hk, minlength=Hk).to(torch.int32)
+        head_budgets_by_batch.append(head_budgets.tolist())
+
+        # Stage 1 (as in kvpress): per head, protect the first
+        # ``floor(budget * first_stage_ratio)`` entries of the sorted top-M1 list.
+        # The product is formed in float32 on the tensor to keep the rounding.
+        phase1_budgets = (
+            (head_budgets.to(torch.float32) * first_stage_ratio).to(torch.int64).tolist()
+        )
+        M1 = max(phase1_budgets) if phase1_budgets else 0
+        if M1 > 0:
+            mk = min(M1, k_len)
+            for hk, phase1_budget in enumerate(phase1_budgets):
+                if phase1_budget <= 0:
+                    continue
+                full_idx = torch.topk(scores_seg[:, hk], mk, sorted=True).indices
+                stage1_mask[k_beg + full_idx[: min(phase1_budget, mk)], hk] = 1
+
+    # ---- Fuse: (base + eps) * ||V Wo||_1, stage-1 entries forced to +inf ----
+    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    def grid_fuse(_META):
+        return (B, Hk)
+
+    _critical_ada_fuse_kernel[grid_fuse](
+        base_scores,
+        wo_v_norm,
+        stage1_mask,
+        cu_seqlens,
+        final_scores,
+        *base_scores.stride(),
+        *wo_v_norm.stride(),
+        *stage1_mask.stride(),
+        *final_scores.stride(),
+        Hk=Hk,
+        EPSILON=float(epsilon),
+    )
+
+    # ---- Stage 2 (kvpress): per head, top-``budget`` fused scores -> +inf ---
+    for b, budgets in enumerate(head_budgets_by_batch):
+        if budgets is None:
+            continue
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        k_len = k_end - k_beg
+        if k_len <= 0:
+            continue
+        M2 = max(budgets)
+        if M2 <= 0:
+            continue
+        mk = min(M2, k_len)
+        fused_seg = final_scores[k_beg:k_end, :]
+        for hk, budget in enumerate(budgets):
+            if budget <= 0:
+                continue
+            full_idx = torch.topk(fused_seg[:, hk], mk, sorted=True).indices
+            final_scores[k_beg + full_idx[: min(budget, mk)], hk] = float("inf")
+
+    # ---- Indices of the pairs to prune (lowest final scores) ---------------
+    # Collected per sequence and concatenated once at the end.
+    batch_parts: list[torch.Tensor] = []
+    head_parts: list[torch.Tensor] = []
+    seq_parts: list[torch.Tensor] = []
+    for b in range(B):
+        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
+        k_len = k_end - k_beg
+        if k_len == 0:
+            continue
+        keep_pairs = keep_pairs_by_seq[b]
+        total_pairs = k_len * Hk
+        if keep_pairs >= total_pairs:
+            # Budget covers the whole sequence: nothing to prune.
+            continue
+        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+
+        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
+        prune_idx = torch.topk(
+            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+        ).indices
+        batch_parts.append(torch.full_like(prune_idx, b, dtype=torch.int64))
+        head_parts.append(prune_idx % Hk)
+        seq_parts.append(prune_idx // Hk + k_beg)
+
+    masked_key_indices = None
+    if batch_parts:
+        masked_key_indices = (
+            torch.cat(batch_parts),
+            torch.cat(head_parts),
+            torch.cat(seq_parts),
+        )
+
+    if store_stream is not None:
+        final_scores.record_stream(store_stream)
+
+    return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """CriticalAda two-stage re-weighting on top of a Compactor or SnapKV base score.
+
+    The base scorer is selected by ``compression_context.critical_ada_base_scorer``
+    and defaults to ``"compactor"`` (pre-RoPE leverage + post-RoPE non-causal
+    fusion) in BOTH scoring stages. ``compression_context.wo_weight`` must be
+    injected by the attention layer before ``post_rope_scoring``; when it is
+    ``None`` the base scores are returned unchanged.
+    """
+
+    @staticmethod
+    def _base_scorer_name(compression_context) -> str:
+        """Return the lower-cased base scorer name, ``"compactor"`` when unset.
+
+        Both scoring stages resolve the name here so that they can never
+        disagree on the default.
+        """
+        if compression_context is None:
+            return "compactor"
+        return str(
+            getattr(compression_context, "critical_ada_base_scorer", "compactor")
+        ).lower()
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Delegate the pre-RoPE stage to the configured base scorer."""
+        base = CriticalAdaKVCompression._base_scorer_name(context.compression_context)
+        if base == "snapkv":
+            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
+        return CompactorCompression.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Compute base scores, then apply the CriticalAda re-weighting if ``wo_weight`` is set.
+
+        Returns the base scores untouched when ``compression_context.wo_weight``
+        is ``None``; otherwise the first element returned by
+        ``critical_ada_key_scores`` (the masked-index tuple is discarded).
+        """
+        compression_context = context.compression_context
+        assert compression_context is not None
+        base = CriticalAdaKVCompression._base_scorer_name(compression_context)
+
+        if base == "snapkv":
+            base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
+        else:
+            # Per the original author's note this call is kept verbatim
+            # identical to ``CompactorCompression.post_rope_scoring`` in
+            # compactor.py; swapping in another wrapper would make the scores
+            # differ from running COMPACTOR on its own.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            base_scores = maybe_execute_in_stream(
+                non_causal_attn_scores,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                sm_scale=1.0,
+                normalize=True,
+                accum_scores=pre_rope_scores,
+                context_lens=compression_context.context_lens,
+                protected_first_tokens=compression_context.protected_first_tokens,
+                protected_last_tokens=compression_context.protected_last_tokens,
+                accum_blending=0.5,
+            )
+
+        wo_weight = compression_context.wo_weight
+        if wo_weight is None:
+            return base_scores
+
+        scores, _masked = maybe_execute_in_stream(
+            critical_ada_key_scores,
+            q,
+            k,
+            v,
+            wo_weight,
+            context.cu_seqlens_q,
+            base_scores,
+            compression_context,
+            STORE_STREAM=context.STORE_STREAM,
+            store_stream=context.STORE_STREAM,
+        )
+        return scores
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """Optionally precompute and cache ``Wo`` reshaped to ``[Hq, head_dim, hidden]``.
+
+        The float32 result is stored as ``module._critical_ada_wo_weight``.
+        Inference relies on the ``cc.wo_weight`` injected in
+        ``Attention.forward``; this cache is only a convenience. The call is a
+        no-op when ``o_proj.weight``, ``num_heads`` or ``head_dim`` is missing.
+        ``dtype`` is accepted for interface compatibility and not used.
+        """
+        if not hasattr(module, "o_proj") or module.o_proj.weight is None:
+            return
+        if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
+            return
+        wo_raw = module.o_proj.weight.data
+        hidden_size, _ = wo_raw.shape
+        Hq = module.num_heads
+        head_dim = module.head_dim
+        wo = (
+            wo_raw.transpose(0, 1)
+            .view(Hq, head_dim, hidden_size)
+            .to(device=device, dtype=torch.float32)
+        )
+        module._critical_ada_wo_weight = wo
+
diff --git a/vllm/kvprune/compression/snapkv.py b/vllm/kvprune/compression/snapkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..69aa4c6392c9ea90c8b3c4b8230ed5bdc169ffc2
--- /dev/null
+++ b/vllm/kvprune/compression/snapkv.py
@@ -0,0 +1,546 @@
+import math
+from typing import Optional
+
+import torch
+import triton
+from triton import language as tl
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
+DEFAULT_SNAPKV_WINDOW_SIZE = 64
+DEFAULT_SNAPKV_KERNEL_SIZE = 5
+
+
+class SnapKVCompression(BaseCompressionMethod):
+    """SnapKV scoring: no pre-RoPE stage; post-RoPE scores come from ``query_aware_key_scores``."""
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """SnapKV contributes no pre-RoPE score; always returns ``None``."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Score keys with the SnapKV window kernel using the module-level defaults.
+
+        ``v`` and ``pre_rope_scores`` are accepted for interface compatibility
+        and are not forwarded to the scorer.
+        """
+        scoring_options = dict(
+            w=DEFAULT_SNAPKV_WINDOW_SIZE,
+            kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
+            STORE_STREAM=context.STORE_STREAM,
+        )
+        return maybe_execute_in_stream(
+            query_aware_key_scores,
+            q,
+            k,
+            context.cu_seqlens_q,
+            context.cu_seqlens_k,
+            **scoring_options,
+        )
+
+
+@triton_autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
+        )
+        for bq in [32, 64]
+        for bk in [32, 64]
+        for num_warps in [4, 8]
+        for num_stages in [3, 4]
+    ],
+    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
+    cache_results=True,
+)
+@triton.jit
+def _lse_and_store_logits_kernel(
+    Q,
+    K,
+    cu_q,
+    cu_k,
+    w_b,  # int32 pointers
+    out_m,
+    out_S,  # [B, Hk, ROWS_MAX] float32
+    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
+    sm_scale,  # float
+    QUERY_GROUP_SIZE: tl.constexpr,
+    D: tl.constexpr,
+    STRIDE_Q_NQ,
+    STRIDE_Q_HQ,
+    STRIDE_K_NK,
+    STRIDE_K_HK,
+    STRIDE_M_B,
+    STRIDE_M_H,
+    STRIDE_M_R,
+    STRIDE_S_B,
+    STRIDE_S_H,
+    STRIDE_S_R,
+    STRIDE_LG_NK,
+    STRIDE_LG_HK,
+    STRIDE_LG_R,
+    BLOCK_Q: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    ROWS_MAX,
+):
+    """Window-query logits and streaming log-sum-exp for SnapKV.
+
+    Grid is ``(sequence, KV head, row tile)``. A "row" is one (window query,
+    query head in the GQA group) pair; the window is the last ``win`` queries
+    of the sequence. For every row the kernel
+
+    * stores the base-2 scaled, causally masked logits against the prefix keys
+      (keys before the window) into ``LOGITS``, and
+    * accumulates the running max ``m`` and sum ``S`` of ``exp2`` over ALL keys
+      of the sequence, written to ``out_m`` / ``out_S``.
+    """
+    # program ids
+    b = tl.program_id(0)
+    hk = tl.program_id(1)
+    rid = tl.program_id(2)  # row-tile id
+    # batch segment bounds
+    q_end = tl.load(cu_q + b + 1)
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+    win = tl.load(w_b + b)
+
+    q_win_beg = q_end - win
+    k_eff_end = k_end - win
+    # Nothing to do without a window or without any prefix key before it.
+    if (win <= 0) or (k_eff_end <= k_beg):
+        return
+
+    # rows for this (b,hk)
+    rows_b = win * QUERY_GROUP_SIZE
+    row0 = rid * BLOCK_Q
+    if row0 >= rows_b:
+        return
+
+    # exp(x) = exp2(x * 1/ln2)
+    qk_scale = sm_scale * 1.4426950408889634
+
+    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
+    row_mask = offs_qrow < rows_b
+
+    # map row -> (q_idx, hq_local)
+    hq_local = offs_qrow % QUERY_GROUP_SIZE
+    q_off = offs_qrow // QUERY_GROUP_SIZE
+    q_idx = q_win_beg + q_off
+    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
+
+    offs_d = tl.arange(0, D)
+
+    q_ptrs = (
+        Q
+        + q_idx[:, None] * STRIDE_Q_NQ
+        + hq_glob[:, None] * STRIDE_Q_HQ
+        + offs_d[None, :]
+    )
+    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
+    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
+
+    # Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
+    for ks in tl.range(k_beg, k_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_end
+
+        k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
+        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
+
+        s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]
+        s = tl.where(kmask[None, :], s, -float("inf"))
+        # Causal mask, aligned at the END of the sequence: key j is visible to
+        # query i iff (j - k_end) <= (i - q_end). ``nk`` indexes the packed K
+        # buffer and ``q_idx`` the packed Q buffer, so comparing them directly
+        # is only valid when cu_q == cu_k; this form gives the same result in
+        # that case and stays correct when the two offsets differ.
+        causal_ok = (nk[None, :] - k_end) <= (q_idx[:, None] - q_end)
+        s = tl.where(causal_ok, s, -float("inf"))
+
+        # store prefix logits only (for marginal probs on prefix keys)
+        log_ptrs = (
+            LOGITS
+            + nk[:, None] * STRIDE_LG_NK
+            + hk * STRIDE_LG_HK
+            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+        )
+        store_mask = kmask & (nk < k_eff_end)
+        tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
+
+        # log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
+        cur_max = tl.max(s, 1)  # [BQ]
+        n_m = tl.maximum(m, cur_max)
+        rescale = tl.math.exp2(m - n_m)
+        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
+        m = n_m
+
+    # store m,S for these rows
+    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
+    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
+
+
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
+ for bq in [16, 32, 64]
+ for bk in [32, 64, 128]
+ ],
+ key=["HK", "HQ"],
+ cache_results=True,
+)
+@triton.jit
+def _prefix_probs_kernel(
+ cu_k,
+ w_b,
+ in_m,
+ in_S, # [B, Hk, ROWS_MAX] f32
+ LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
+ PROBS, # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
+ #
+ QUERY_GROUP_SIZE: tl.constexpr,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ STRIDE_PB_NK,
+ STRIDE_PB_HK,
+ STRIDE_PB_R,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ # Turns the stored base-2 logits of the prefix keys into probabilities,
+ # PROBS = exp2(logit - m) / S, using the per-row running max ``m`` and sum ``S``.
+ # One program per (sequence, KV head); it loops over all prefix-key tiles
+ # and, inside, over all row tiles.
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ # Prefix keys are [k_beg, k_end - win); skip sequences with no window or no prefix.
+ k_eff_end = k_end - win
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ # Number of valid rows for this sequence: window queries x query heads per KV head.
+ rows_b = win * QUERY_GROUP_SIZE
+
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ for row0 in tl.range(0, rows_b, BLOCK_Q):
+ r_idx = row0 + tl.arange(0, BLOCK_Q)
+ rmask = r_idx < rows_b
+
+ m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ m = tl.load(
+ m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
+ mask=rmask,
+ other=-float("inf"),
+ )
+ S = tl.load(
+ S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
+ )
+
+ # Rows with S <= 0 (including masked-out rows, which load S = 0) get
+ # neutral m / S here so the division below is well defined; their
+ # probabilities are forced to 0 afterwards.
+ valid_row = S > 0
+ m = tl.where(valid_row, m, 0.0)
+ S = tl.where(valid_row, S, 1.0)
+
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ s_T = tl.load(
+ log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
+ ) # [BK, BQ]
+
+ # Masked-out entries load as -inf, i.e. exp2(...) == 0.
+ probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
+ probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
+
+ prob_ptrs = (
+ PROBS
+ + nk[:, None] * STRIDE_PB_NK
+ + hk * STRIDE_PB_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
+ )
+ tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
+
+
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+    key=["HK"],
+    cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue(
+    OUT,  # [Nk, Hk], float32
+    cu_k,  # [B+1] int32
+    w_b,  # [B] int32
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    HK: tl.constexpr,  # Hk
+    EPS: tl.constexpr,  # e.g., 1e-12
+    BLOCK_K: tl.constexpr,  # e.g., 128
+):
+    """In-place z-score of ``OUT`` over the prefix keys of one sequence, across all heads.
+
+    One program per sequence. Pass 1 accumulates the sum and sum of squares
+    over ``[k_beg, k_end - win) x HK``; pass 2 rewrites each value as
+    ``(x - mean) / sqrt(var + EPS)``.
+    """
+    seq = tl.program_id(0)
+
+    seg_start = tl.load(cu_k + seq)
+    seg_stop = tl.load(cu_k + seq + 1) - tl.load(w_b + seq)
+    if seg_stop <= seg_start:
+        return
+
+    n_elems = ((seg_stop - seg_start) * HK).to(tl.float32)
+    total = tl.zeros([], dtype=tl.float32)
+    total_sq = tl.zeros([], dtype=tl.float32)
+
+    # Pass 1: first and second moments.
+    for blk in tl.range(seg_start, seg_stop, BLOCK_K):
+        rows = blk + tl.arange(0, BLOCK_K)
+        live = rows < seg_stop
+        for head in tl.range(0, HK):
+            x = tl.load(
+                OUT + rows * STRIDE_OUT_NK + head * STRIDE_OUT_HK, mask=live, other=0.0
+            ).to(tl.float32)
+            total += tl.sum(x, 0)
+            total_sq += tl.sum(x * x, 0)
+
+    mu = total / n_elems
+    # Clamp at 0 so rounding cannot produce a negative variance.
+    variance = tl.maximum(total_sq / n_elems - mu * mu, 0.0)
+    inv_sigma = 1.0 / tl.sqrt(variance + EPS)
+
+    # Pass 2: normalise in place.
+    for blk in tl.range(seg_start, seg_stop, BLOCK_K):
+        rows = blk + tl.arange(0, BLOCK_K)
+        live = rows < seg_stop
+        for head in tl.range(0, HK):
+            cell = OUT + rows * STRIDE_OUT_NK + head * STRIDE_OUT_HK
+            x = tl.load(cell, mask=live, other=0.0).to(tl.float32)
+            tl.store(cell, (x - mu) * inv_sigma, mask=live)
+
+
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
+    key=["KERNEL_SIZE"],
+    cache_results=True,
+)
+@triton.jit
+def _snapkv_avg_pool1d_kernel(
+    IN,
+    OUT,
+    Lp,
+    STRIDE_IN_C,
+    STRIDE_IN_L,
+    STRIDE_OUT_C,
+    STRIDE_OUT_L,
+    KERNEL_SIZE: tl.constexpr,
+    PAD: tl.constexpr,
+    BLOCK_T: tl.constexpr,
+):
+    """
+    Stride-1 sliding-window average along the last dimension of a ``[C, Lp]`` input.
+
+    Taps that fall outside ``[0, Lp)`` contribute 0 and the sum is always
+    divided by ``KERNEL_SIZE``. Per the original note this is meant to match
+    `F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)`
+    (PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
+    """
+    chan = tl.program_id(0)
+    pos = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
+    in_range = pos < Lp
+    in_row = IN + chan * STRIDE_IN_C
+
+    window_sum = tl.zeros([BLOCK_T], dtype=tl.float32)
+    for tap in tl.static_range(KERNEL_SIZE):
+        src = pos - PAD + tap
+        readable = (src >= 0) & (src < Lp) & in_range
+        window_sum += tl.load(in_row + src * STRIDE_IN_L, mask=readable, other=0.0).to(
+            tl.float32
+        )
+    window_mean = window_sum / tl.cast(KERNEL_SIZE, tl.float32)
+
+    tl.store(OUT + chan * STRIDE_OUT_C + pos * STRIDE_OUT_L, window_mean, mask=in_range)
+
+
+def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
+    """
+    Smooth ``x`` (float32, shape [Hk, G, Lp]) along ``Lp`` with the Triton average-pool kernel.
+
+    The input is flattened to ``[Hk * G, Lp]``, pooled with padding
+    ``kernel_size // 2``, and returned in the original ``[Hk, G, Lp]`` shape.
+    An input with ``Lp == 0`` is returned as is.
+    """
+    assert x.dtype == torch.float32
+    num_kv_heads, group, length = x.shape
+    if length == 0:
+        return x
+
+    channels = num_kv_heads * group
+    flat_in = x.reshape(channels, length).contiguous()
+    flat_out = torch.empty_like(flat_in)
+    in_stride_c, in_stride_l = flat_in.stride()
+    out_stride_c, out_stride_l = flat_out.stride()
+
+    def launch_grid(meta):
+        # One program per (channel, BLOCK_T-wide tile of positions).
+        return (channels, triton.cdiv(length, meta["BLOCK_T"]))
+
+    _snapkv_avg_pool1d_kernel[launch_grid](
+        flat_in,
+        flat_out,
+        length,
+        in_stride_c,
+        in_stride_l,
+        out_stride_c,
+        out_stride_l,
+        KERNEL_SIZE=kernel_size,
+        PAD=kernel_size // 2,
+    )
+    return flat_out.view(num_kv_heads, group, length)
+
+
+def _snapkv_kvpress_epilogue(
+    probs_buf: torch.Tensor,
+    out: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    w: torch.Tensor,
+    G: int,
+    Hk: int,
+    kernel_size: int,
+) -> None:
+    """
+    Reduce per-row prefix probabilities to per-(key, KV head) SnapKV scores, in place in ``out``.
+
+    Follows the kvpress SnapKV order: mean over window queries -> symmetric
+    avg_pool1d -> mean over GQA groups -> fill the window tail with the global
+    max of the prefix scores.
+
+    Args:
+        probs_buf: Indexed as ``[Nk, Hk, rows]``; the first ``win * G`` rows of
+            a sequence are laid out in (query offset, group member) order.
+        out: ``[Nk, Hk]`` destination, written per sequence.
+        cu_seqlens_k: Cumulative key offsets (``B + 1`` entries).
+        w: Per-sequence window sizes (``B`` entries).
+        G: Query heads per KV head.
+        Hk: Number of KV heads.
+        kernel_size: Average-pool kernel size.
+    """
+    # Two bulk read-backs instead of three blocking ``.item()`` calls per sequence.
+    bounds = [int(x) for x in cu_seqlens_k.tolist()]
+    windows = [int(x) for x in w.tolist()]
+
+    for b in range(len(bounds) - 1):
+        k_beg, k_end = bounds[b], bounds[b + 1]
+        win = windows[b]
+        k_eff_end = k_end - win
+        # Skip sequences with no window or no prefix key before the window.
+        if win <= 0 or k_eff_end <= k_beg:
+            continue
+        Lp = k_eff_end - k_beg
+        rows_b = win * G
+        p = probs_buf[k_beg:k_eff_end, :, :rows_b]
+        # [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
+        x = p.view(Lp, Hk, win, G).mean(dim=2)
+        x = x.permute(1, 2, 0).contiguous()  # [Hk, G, Lp]
+        x = _snapkv_avg_pool1d_triton(x, kernel_size)
+        x = x.mean(dim=1)  # [Hk, Lp]
+        seg = x.permute(1, 0).contiguous()  # [Lp, Hk]
+        out[k_beg:k_eff_end, :] = seg
+        # Window keys get the largest prefix score of this sequence (all heads).
+        out[k_eff_end:k_end, :] = seg.max()
+
+
+def query_aware_key_scores(
+    q: torch.Tensor,  # [N_q, Hq, D]
+    k: torch.Tensor,  # [N_k, Hk, D]
+    cu_seqlens_q: torch.Tensor,  # [B+1], int32
+    cu_seqlens_k: torch.Tensor,  # [B+1], int32
+    w: torch.Tensor | int,  # [B], int32
+    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
+    *,
+    kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
+    accum_scores: Optional[torch.Tensor] = None,
+    accum_blending: Optional[float] = None,
+    normalize: bool = False,
+) -> Optional[torch.Tensor]:
+    """SnapKV key scores: attention mass the last ``w`` queries put on each prefix key.
+
+    Pipeline: (1) Triton kernel stores window-query logits and a streaming
+    log-sum-exp, (2) Triton kernel converts them to probabilities on the prefix
+    keys, (3) ``_snapkv_kvpress_epilogue`` reduces them to ``[N_k, Hk]`` scores,
+    (4) optional per-sequence z-score, (5) optional blend into ``accum_scores``.
+
+    Args:
+        q: Packed queries ``[N_q, Hq, D]``, last dim contiguous.
+        k: Packed keys ``[N_k, Hk, D]``, last dim contiguous; ``Hq % Hk == 0``.
+        cu_seqlens_q: Cumulative query offsets, ``B + 1`` entries.
+        cu_seqlens_k: Cumulative key offsets, ``B + 1`` entries.
+        w: Observation window, either one int for all sequences or a tensor
+            with ``B`` entries.
+        sm_scale: Softmax scale; ``1 / sqrt(D)`` when ``None``.
+        kernel_size: Average-pool kernel size used by the epilogue.
+        accum_scores: If given, it is updated IN PLACE
+            (``accum_scores * accum_blending + scores``; the multiply is skipped
+            when ``accum_blending`` is ``None``) and returned.
+        accum_blending: Optional factor applied to ``accum_scores`` first.
+        normalize: Apply the per-sequence z-score epilogue.
+
+    Returns:
+        ``accum_scores`` when provided, otherwise a new float32 ``[N_k, Hk]``
+        tensor. When the largest window is 0, a zero tensor is returned and
+        ``accum_scores`` is left untouched.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk, "q and k must share the head dimension"
+    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    B = cu_seqlens_q.numel() - 1
+    assert B == cu_seqlens_k.numel() - 1
+
+    G = Hq // Hk
+    if isinstance(w, int):
+        max_w = w
+        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
+    else:
+        # Validate the size first: ``max()`` of an empty tensor raises, which
+        # used to turn an empty batch into a crash.
+        assert w.numel() == B
+        max_w = int(w.max().item()) if B > 0 else 0
+    ROWS_MAX = max_w * G
+    if ROWS_MAX == 0:
+        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+
+    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+    m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    probs_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+
+    # strides
+    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
+    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
+    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
+    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
+    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
+    STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
+    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
+
+    # Stride keyword arguments shared by both kernels.
+    scratch_strides = dict(
+        STRIDE_M_B=STRIDE_M_B,
+        STRIDE_M_H=STRIDE_M_H,
+        STRIDE_M_R=STRIDE_M_R,
+        STRIDE_S_B=STRIDE_S_B,
+        STRIDE_S_H=STRIDE_S_H,
+        STRIDE_S_R=STRIDE_S_R,
+        STRIDE_LG_NK=STRIDE_LG_NK,
+        STRIDE_LG_HK=STRIDE_LG_HK,
+        STRIDE_LG_R=STRIDE_LG_R,
+    )
+
+    def grid(META):
+        return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
+
+    # (1) logits for the prefix keys + running max / sum over all keys.
+    _lse_and_store_logits_kernel[grid](
+        q,
+        k,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        w,
+        m_scratch,
+        S_scratch,
+        logits_buf,
+        sm_scale,
+        QUERY_GROUP_SIZE=G,
+        D=D,
+        STRIDE_Q_NQ=STRIDE_Q_NQ,
+        STRIDE_Q_HQ=STRIDE_Q_HQ,
+        STRIDE_K_NK=STRIDE_K_NK,
+        STRIDE_K_HK=STRIDE_K_HK,
+        ROWS_MAX=ROWS_MAX,
+        **scratch_strides,
+    )
+
+    # (2) probabilities on the prefix keys.
+    _prefix_probs_kernel[(B, Hk)](
+        cu_seqlens_k,
+        w,
+        m_scratch,
+        S_scratch,
+        logits_buf,
+        probs_buf,
+        QUERY_GROUP_SIZE=G,
+        STRIDE_PB_NK=STRIDE_PB_NK,
+        STRIDE_PB_HK=STRIDE_PB_HK,
+        STRIDE_PB_R=STRIDE_PB_R,
+        **scratch_strides,
+    )
+
+    # (3) reduce to [N_k, Hk] scores.
+    _snapkv_kvpress_epilogue(
+        probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
+    )
+
+    # (4) optional z-score per sequence.
+    if normalize:
+        _zscore_per_batch_epilogue[(B,)](
+            out,
+            cu_seqlens_k,
+            w,
+            STRIDE_OUT_NK,
+            STRIDE_OUT_HK,
+            HK=Hk,
+            EPS=1e-12,
+        )
+
+    # (5) optional in-place blend into the caller's accumulator.
+    if accum_scores is None:
+        return out
+    if accum_blending is not None:
+        accum_scores.mul_(accum_blending)
+    accum_scores.add_(out)
+    return accum_scores
+
diff --git a/vllm/kvprune/compression/snapkv_origin.py b/vllm/kvprune/compression/snapkv_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a60bc21ba417a128e8e6547b7f97935342f4e24
--- /dev/null
+++ b/vllm/kvprune/compression/snapkv_origin.py
@@ -0,0 +1,449 @@
+import math
+from typing import Optional
+
+import torch
+import triton
+from triton import language as tl
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+
+class SnapKVCompression(BaseCompressionMethod):
+    """Compression method that scores keys with :func:`query_aware_key_scores`.
+
+    Scoring happens only after RoPE has been applied; the pre-RoPE hook
+    contributes nothing.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Return ``None``: this method produces no pre-RoPE scores."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Score keys from the post-RoPE ``q`` / ``k`` with a fixed window of 32.
+
+        ``v`` and ``pre_rope_scores`` are accepted for interface compatibility
+        and are not used. The sequence boundaries and the stream handle are read
+        from ``context``; the call itself is routed through
+        ``maybe_execute_in_stream``.
+        """
+        # Positional arguments are evaluated first, in the original order.
+        positional = (q, k, context.cu_seqlens_q, context.cu_seqlens_k)
+        return maybe_execute_in_stream(
+            query_aware_key_scores,
+            *positional,
+            w=32,
+            STORE_STREAM=context.STORE_STREAM,
+        )
+
+
+# Pass 1 of the query-aware key scoring. One program per (batch b, kv head hk,
+# row tile rid). A "row" is one (query position, query head) pair taken from the
+# last `win` query positions of sequence b and the QUERY_GROUP_SIZE query heads
+# that map to kv head hk. For every row the kernel computes base-2 scaled logits
+# against the keys in [k_beg, k_end - win), stores them transposed in LOGITS, and
+# stores the row's streaming log-sum-exp state (running max m, running sum S) in
+# out_m / out_S for the second pass.
+# NOTE(review): nothing here checks that win does not exceed the number of
+# queries of sequence b; if it did, q_win_beg would point before the start of
+# the sequence — confirm that callers guarantee this.
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
+ )
+ for bq in [32, 64]
+ for bk in [32, 64]
+ for num_warps in [4, 8]
+ for num_stages in [3, 4]
+ ],
+ key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
+ cache_results=True,
+)
+@triton.jit
+def _lse_and_store_logits_kernel(
+ Q,
+ K,
+ cu_q,
+ cu_k,
+ w_b, # int32 pointers
+ out_m,
+ out_S, # [B, Hk, ROWS_MAX] float32
+ LOGITS, # [Nk, Hk, ROWS_MAX] float32
+ sm_scale, # float
+ QUERY_GROUP_SIZE: tl.constexpr,
+ D: tl.constexpr,
+ STRIDE_Q_NQ,
+ STRIDE_Q_HQ,
+ STRIDE_K_NK,
+ STRIDE_K_HK,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ ROWS_MAX,
+):
+ # program ids
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+ rid = tl.program_id(2) # row-tile id
+ # batch segment bounds
+ q_end = tl.load(cu_q + b + 1)
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ # The window is the last `win` queries; only keys before the window are scored.
+ q_win_beg = q_end - win
+ k_eff_end = k_end - win
+ # Nothing to do without a window or without any key in front of it.
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ # rows for this (b,hk)
+ rows_b = win * QUERY_GROUP_SIZE
+ row0 = rid * BLOCK_Q
+ # The grid is sized for ROWS_MAX; tiles past this sequence's rows exit early.
+ if row0 >= rows_b:
+ return
+
+ # exp(x) = exp2(x * 1/ln2)
+ qk_scale = sm_scale * 1.4426950408889634
+
+ offs_qrow = row0 + tl.arange(0, BLOCK_Q)
+ row_mask = offs_qrow < rows_b
+
+ # map row -> (q_idx, hq_local)
+ hq_local = offs_qrow % QUERY_GROUP_SIZE
+ q_off = offs_qrow // QUERY_GROUP_SIZE
+ q_idx = q_win_beg + q_off
+ hq_glob = hk * QUERY_GROUP_SIZE + hq_local
+
+ offs_d = tl.arange(0, D)
+
+ # The last dimension is addressed with unit stride (offs_d is added directly).
+ q_ptrs = (
+ Q
+ + q_idx[:, None] * STRIDE_Q_NQ
+ + hq_glob[:, None] * STRIDE_Q_HQ
+ + offs_d[None, :]
+ )
+ q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+ # Streaming log-sum-exp state per row: running max m and running sum S.
+ m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
+ S = tl.zeros([BLOCK_Q], dtype=tl.float32)
+
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
+ k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
+
+ s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
+ # Keys past the scored range get -inf so they add exp2(-inf) = 0 to S.
+ s = tl.where(kmask[None, :], s, -float("inf"))
+
+ # store into LOGITS[nk, hk, row] -> [BK, BQ]
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
+
+ # log2 streaming LSE update
+ cur_max = tl.max(s, 1) # [BQ]
+ n_m = tl.maximum(m, cur_max)
+ # Rescale the sum accumulated so far to the new running maximum.
+ rescale = tl.math.exp2(m - n_m)
+ S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
+ m = n_m
+
+ # store m,S for these rows
+ m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
+ tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
+
+
+# Pass 2 of the query-aware key scoring. One program per (batch b, kv head hk).
+# For every key in [k_beg, k_end - win) it turns the stored base-2 logits back
+# into softmax probabilities with the per-row (m, S) statistics, sums them over
+# all rows, optionally smooths the sums with a KPOOL-wide average, and writes
+# the result to OUT[key, hk]. Keys in the window [k_end - win, k_end) get +inf.
+# Sequences with no window or no key before the window are left untouched.
+# NOTE(review): the kernel has no parameters named "HK" or "HQ", so the autotune
+# key below does not match any argument — confirm how triton_compat.autotune
+# treats unknown key names.
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
+ for bq in [16, 32, 64]
+ for bk in [32, 64, 128]
+ ],
+ key=["HK", "HQ"],
+ cache_results=True,
+)
+@triton.jit
+def _scores_from_logits_kernel(
+ cu_k,
+ w_b,
+ in_m,
+ in_S, # [B, Hk, ROWS_MAX] f32
+ LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits
+ OUT, # [Nk, Hk] f32
+ #
+ QUERY_GROUP_SIZE: tl.constexpr,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ #
+ DO_POOL: tl.constexpr, # set True to enable in-place avg pool
+ KPOOL: tl.constexpr, # kernel size for avg pool (stride=1)
+):
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ # Early exit writes nothing at all to OUT for this (b, hk).
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ rows_b = win * QUERY_GROUP_SIZE
+
+ # === scores over computed region ===
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ scores = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+ for row0 in tl.range(0, rows_b, BLOCK_Q):
+ r_idx = row0 + tl.arange(0, BLOCK_Q)
+ rmask = r_idx < rows_b
+
+ # load m, S for rows
+ m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ m = tl.load(
+ m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
+ mask=rmask,
+ other=-float("inf"),
+ )
+ S = tl.load(
+ S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
+ )
+
+ # Rows with S == 0 (including masked-out rows, loaded as 0) are replaced
+ # by m = 0, S = 1 so the division below is defined; their probabilities
+ # are zeroed again after the division.
+ valid_row = S > 0
+ m = tl.where(valid_row, m, 0.0)
+ S = tl.where(valid_row, S, 1.0)
+
+ # load stored logits^T: [BK, BQ]
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ s_T = tl.load(
+ log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
+ ) # [BK, BQ]
+
+ # probs^T = exp2(s_T - m) / S, sum over rows
+ probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
+ probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
+
+ scores += tl.sum(probs_T, 1) # [BK]
+
+ if DO_POOL and (KPOOL > 1):
+ # band[i, j] selects the KPOOL keys j = i-KPOOL+1 .. i of this tile.
+ # NOTE(review): the window is trailing (not centered) and is built from
+ # tile-local indices, so it never reaches into the previous tile; the
+ # first KPOOL-1 keys of every tile average fewer neighbours and the
+ # pooled result depends on the autotuned BLOCK_K — confirm this is intended.
+ i = tl.arange(0, BLOCK_K)[:, None]
+ j = tl.arange(0, BLOCK_K)[None, :]
+ band = (j <= i) & ((i - j) < KPOOL)
+ band = band & kmask[None, :]
+ # sum within band
+ sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1) # [BK]
+ denom = tl.sum(band, 1).to(tl.float32) # [BK]
+ denom = tl.where(denom > 0, denom, 1.0)
+ scores = sums / denom
+
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(out_ptrs, scores, mask=kmask)
+
+ # Keys inside the window itself were not scored; mark them with +inf.
+ # NOTE(review): presumably so that a later top-k selection always keeps them —
+ # confirm against the consumer of OUT.
+ pad_beg = k_eff_end
+ pad_end = k_end
+ if pad_end > pad_beg:
+ for ks in tl.range(pad_beg, pad_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < pad_end
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(
+ out_ptrs, tl.full([BLOCK_K], float("inf"), dtype=tl.float32), mask=kmask
+ )
+
+
+# In-place z-score normalization of OUT, one program per batch entry b. The mean
+# and variance are taken jointly over all HK heads and all keys in
+# [k_beg, k_end - win); entries in the window [k_end - win, k_end) are neither
+# read nor modified. Sequences with no key before the window are skipped.
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+ key=["HK"],
+ cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue(
+ OUT, # [Nk, Hk], float32
+ cu_k,
+ w_b, # [B+1], [B] int32
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ HK: tl.constexpr, # Hk
+ EPS: tl.constexpr, # e.g., 1e-12
+ BLOCK_K: tl.constexpr, # e.g., 128
+):
+ b = tl.program_id(0)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ if k_eff_end <= k_beg:
+ return
+
+ # First sweep: accumulate sum and sum of squares over every (key, head) entry.
+ sumv = tl.zeros([], dtype=tl.float32)
+ sumsq = tl.zeros([], dtype=tl.float32)
+ count = ((k_eff_end - k_beg) * HK).to(tl.float32)
+
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ sumv += tl.sum(vals, 0)
+ sumsq += tl.sum(vals * vals, 0)
+
+ mean = sumv / count
+ # Variance as E[x^2] - mean^2, clamped at 0 to absorb rounding error.
+ var = tl.maximum(sumsq / count - mean * mean, 0.0)
+ invstd = 1.0 / tl.sqrt(var + EPS)
+
+ # Second sweep: rewrite every entry as (x - mean) / std in place.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ vals = (vals - mean) * invstd
+ tl.store(ptrs, vals, mask=kmask)
+
+
+def query_aware_key_scores(
+    q: torch.Tensor,  # [N_q, Hq, D]
+    k: torch.Tensor,  # [N_k, Hk, D]
+    cu_seqlens_q: torch.Tensor,  # [B+1], int32
+    cu_seqlens_k: torch.Tensor,  # [B+1], int32
+    w: torch.Tensor | int,  # [B], int32
+    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
+    *,
+    accum_scores: Optional[torch.Tensor] = None,
+    accum_blending: Optional[float] = None,
+    normalize: bool = False,
+) -> Optional[torch.Tensor]:
+    """Score every key by the attention it receives from a trailing query window.
+
+    For each packed sequence ``b`` the last ``w[b]`` query positions (all query
+    heads of a kv-head group) attend to the keys that precede the window. The
+    softmax probabilities are summed over those rows, smoothed with a 5-wide
+    average, and written per ``(key, kv head)``.
+
+    Args:
+        q: Packed queries ``[N_q, Hq, D]``; the last dimension must be contiguous.
+        k: Packed keys ``[N_k, Hk, D]``; the last dimension must be contiguous.
+        cu_seqlens_q: Cumulative query lengths, ``B + 1`` entries.
+        cu_seqlens_k: Cumulative key lengths, ``B + 1`` entries.
+        w: Window length, either one ``int`` for all sequences or a tensor with
+            ``B`` entries. For a tensor the maximum is read back to the host.
+        sm_scale: Logit scale; ``1 / sqrt(D)`` when ``None``.
+        accum_scores: Optional running score tensor, updated in place.
+        accum_blending: Optional factor ``accum_scores`` is multiplied by before
+            the new scores are added.
+        normalize: When true, z-score the scored region of each sequence.
+
+    Returns:
+        ``accum_scores`` when it was given, otherwise a new float32 tensor of
+        shape ``[N_k, Hk]``. Keys inside the window hold ``+inf``; rows the
+        kernels skip (no window, or no key before the window) hold ``0.0``.
+
+    NOTE(review): when no sequence has a window the function returns a fresh
+    zero tensor even if ``accum_scores`` was passed; this pre-existing behavior
+    is kept unchanged.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    # The kernel addresses K with offsets built from D, so the dims must agree.
+    assert D == Dk, "q and k must have the same head dimension"
+    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    B = cu_seqlens_q.numel() - 1
+    assert B == cu_seqlens_k.numel() - 1
+
+    G = Hq // Hk
+    if isinstance(w, int):
+        max_w = w
+        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
+    else:
+        assert w.numel() == B
+        max_w = int(w.max().item())
+    ROWS_MAX = max_w * G
+    if ROWS_MAX <= 0:
+        # No sequence has a window: nothing can be scored.
+        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+
+    # Zero-initialized on purpose: _scores_from_logits_kernel returns before any
+    # store for sequences without a window or without keys before it, so an
+    # uninitialized buffer would leak arbitrary memory into the returned scores.
+    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+    m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+    logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+
+    # strides
+    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
+    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
+    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
+    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
+    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
+    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
+
+    def grid(META):
+        # One program per (sequence, kv head, tile of BLOCK_Q rows).
+        return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
+
+    # Pass 1: logits and per-row log-sum-exp statistics.
+    _lse_and_store_logits_kernel[grid](
+        q,
+        k,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        w,
+        m_scratch,
+        S_scratch,
+        logits_buf,
+        sm_scale,
+        QUERY_GROUP_SIZE=G,
+        D=D,
+        STRIDE_Q_NQ=STRIDE_Q_NQ,
+        STRIDE_Q_HQ=STRIDE_Q_HQ,
+        STRIDE_K_NK=STRIDE_K_NK,
+        STRIDE_K_HK=STRIDE_K_HK,
+        STRIDE_M_B=STRIDE_M_B,
+        STRIDE_M_H=STRIDE_M_H,
+        STRIDE_M_R=STRIDE_M_R,
+        STRIDE_S_B=STRIDE_S_B,
+        STRIDE_S_H=STRIDE_S_H,
+        STRIDE_S_R=STRIDE_S_R,
+        STRIDE_LG_NK=STRIDE_LG_NK,
+        STRIDE_LG_HK=STRIDE_LG_HK,
+        STRIDE_LG_R=STRIDE_LG_R,
+        ROWS_MAX=ROWS_MAX,
+    )
+
+    # Pass 2: probabilities summed over rows, pooled, written to `out`.
+    _scores_from_logits_kernel[(B, Hk)](
+        cu_seqlens_k,
+        w,
+        m_scratch,
+        S_scratch,
+        logits_buf,
+        out,
+        QUERY_GROUP_SIZE=G,
+        STRIDE_M_B=STRIDE_M_B,
+        STRIDE_M_H=STRIDE_M_H,
+        STRIDE_M_R=STRIDE_M_R,
+        STRIDE_S_B=STRIDE_S_B,
+        STRIDE_S_H=STRIDE_S_H,
+        STRIDE_S_R=STRIDE_S_R,
+        STRIDE_LG_NK=STRIDE_LG_NK,
+        STRIDE_LG_HK=STRIDE_LG_HK,
+        STRIDE_LG_R=STRIDE_LG_R,
+        STRIDE_OUT_NK=STRIDE_OUT_NK,
+        STRIDE_OUT_HK=STRIDE_OUT_HK,
+        DO_POOL=True,
+        KPOOL=5,
+    )
+    if normalize:
+        _zscore_per_batch_epilogue[(B,)](
+            out,
+            cu_seqlens_k,
+            w,
+            STRIDE_OUT_NK,
+            STRIDE_OUT_HK,
+            HK=Hk,
+            EPS=1e-12,
+        )
+    if accum_scores is None:
+        return out
+    if accum_blending is not None:
+        accum_scores.mul_(accum_blending)
+    accum_scores.add_(out)
+    return accum_scores
diff --git a/vllm/kvprune/config/__init__.py b/vllm/kvprune/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..717459650025a6551cdc91bb5136c450984eaca6
--- /dev/null
+++ b/vllm/kvprune/config/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Engine / sampling / kernel constants (compactor-compatible)."""
+
+from vllm.kvprune.config.constants import RESERVED_BATCH, TRITON_RESERVED_BATCH
+
+__all__ = ["RESERVED_BATCH", "TRITON_RESERVED_BATCH"]
diff --git a/vllm/kvprune/config/constants.py b/vllm/kvprune/config/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac943e40f261c61398ad82cb7eb4f714e0590aad
--- /dev/null
+++ b/vllm/kvprune/config/constants.py
@@ -0,0 +1,5 @@
+# NOTE(review): what this index designates is not visible in this file —
+# presumably batch slot 0 is reserved for internal use; confirm at the use sites.
+RESERVED_BATCH = 0
+# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
+# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
+# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
+TRITON_RESERVED_BATCH = RESERVED_BATCH
diff --git a/vllm/kvprune/config/engine_config.py b/vllm/kvprune/config/engine_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab165ffb92b93a21643f6270501bf3c61fe280dd
--- /dev/null
+++ b/vllm/kvprune/config/engine_config.py
@@ -0,0 +1,129 @@
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import List, Optional
+
+from transformers import AutoConfig
+
+
+class AttentionBackend(Enum):
+ """Legacy coarse backend toggle (prefer :class:`KvpruneAttentionSchedule`)."""
+
+ # LLMConfig.__post_init__ maps this member to FA_PREFILL_TRITON_DECODE.
+ FLASH_ATTENTION = auto()
+ # LLMConfig.__post_init__ maps this member to TRITON_PREFILL_TRITON_DECODE.
+ COMPACTOR_TRITON = auto()
+
+
+class KvpruneAttentionSchedule(Enum):
+ """FlashAttention vs Triton split for prefill / decode (KV **writes** stay Triton)."""
+
+ # This member is the default value of LLMConfig.attention_schedule.
+ # Default: FA varlen prefill; decode uses ``head_sparse_decode_attention`` (Triton).
+ FA_PREFILL_TRITON_DECODE = auto()
+ # Prefill attention uses ``causal_sparse_varlen_with_cache`` (Triton); decode Triton.
+ TRITON_PREFILL_TRITON_DECODE = auto()
+ # "PDFA": FA prefill + FA decode; paged KV **storage** (incl. pruned top-k) unchanged.
+ PDFA = auto()
+
+
+@dataclass
+class LLMConfig:
+    """Configuration for the :class:`LLM` engine.
+
+    Parameters
+    ----------
+    model : str
+        Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
+        a local model name that can be resolved by
+        :func:`transformers.AutoConfig.from_pretrained`.
+    path : str, optional
+        Local directory containing the model weights. If ``None``, the engine
+        will attempt to resolve a local snapshot for ``model`` using
+        :func:`huggingface_hub.snapshot_download`.
+    nccl_port : int, optional, default 1218
+        NOTE(review): presumably the port used for the NCCL / distributed
+        rendezvous between tensor-parallel workers; it is not read in this
+        file — confirm at the use site.
+    max_num_seqs : int, default 256
+        Upper bound on the number of concurrent batches that the scheduler and
+        KV-cache manager are allowed to handle. This affects the size of the
+        page table and some internal buffers.
+    max_model_len : int, default 40960
+        Maximum context length (in tokens) that the engine will allocate KV cache
+        and CUDA graphs for. During initialization this value is clamped to
+        ``hf_config.max_position_embeddings`` for the chosen model.
+    gpu_memory_utilization : float, default 0.9
+        Fraction of the total GPU memory that may be used for KV cache and model
+        activations. Values should be in ``(0, 1]``. If this budget is too small,
+        the KV-cache manager may raise an error at warmup time due
+        to insufficient memory.
+    tensor_parallel_size : int, default 1
+        Number of tensor-parallel workers to shard the model
+        across. Must be between 1 and 8, and must evenly divide the model's
+        number of key/value heads.
+    enforce_eager : bool, default False
+        If ``True``, disable CUDA graph capture and always run the model in
+        eager mode during decoding. This reduces throughput. When ``False``,
+        the engine will capture and reuse CUDA graphs for supported
+        batch sizes and sequence lengths.
+    hf_config : transformers.AutoConfig, optional
+        Pre-loaded Hugging Face configuration for the model. If ``None``,
+        it is loaded from ``model`` during initialization.
+    eos : int, default -1
+        Primary stop token id (warmup / single-id paths). If ``-1``, the
+        :class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
+        tokenizer.
+    eos_token_ids : list of int, optional
+        All token ids that terminate generation (e.g. HF tokenizers may expose
+        ``eos_token_id`` as a list for chat models). If ``None``, inferred in
+        :class:`LLM` from the tokenizer and model type.
+    kvcache_page_size : int, default 128
+        Number of tokens stored in a single KV-cache page. Smaller pages improve
+        allocation flexibility but increase page-table overhead; larger pages
+        reduce overhead but have coarser granularity.
+    leverage_sketch_size : int, default 48
+        Sketch dimension used by the Compactor leverage-score estimator.
+    attention_schedule : KvpruneAttentionSchedule, default FA_PREFILL_TRITON_DECODE
+        Which **attention** implementation runs on prefill vs decode. KV **writes**
+        (``prefill_store_*``, ``decode_store_kv``, pruned top-k) always use the
+        existing Triton store kernels. Env ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` uses
+        short names: ``fa_triton`` (default), ``pdtriton``, ``pdfa``. Enum values:
+        ``FA_PREFILL_TRITON_DECODE`` — FA prefill, Triton decode;
+        ``TRITON_PREFILL_TRITON_DECODE`` — Triton prefill + decode;
+        ``PDFA`` — FA prefill + FA decode (still Triton KV I/O).
+    attention_backend : AttentionBackend, optional
+        Deprecated. When not ``None`` it overrides ``attention_schedule`` during
+        initialization: ``FLASH_ATTENTION`` selects ``FA_PREFILL_TRITON_DECODE``
+        and any other value selects ``TRITON_PREFILL_TRITON_DECODE``.
+    show_progress_bar : bool, default True
+        NOTE(review): presumably toggles a progress bar during generation; it is
+        not read in this file — confirm at the use site.
+
+    Raises
+    ------
+    NotADirectoryError
+        If ``path`` is given but is not an existing directory.
+    ValueError
+        If ``tensor_parallel_size`` is outside ``1..8``.
+    """
+
+    model: str
+    path: Optional[str] = None
+    nccl_port: Optional[int] = 1218
+    max_num_seqs: int = 256
+    max_model_len: int = 40960
+    gpu_memory_utilization: float = 0.9
+    tensor_parallel_size: int = 1
+    enforce_eager: bool = False
+    hf_config: AutoConfig | None = None
+    eos: int = -1
+    eos_token_ids: Optional[List[int]] = None
+    kvcache_page_size: int = 128
+    leverage_sketch_size: int = 48
+    attention_schedule: KvpruneAttentionSchedule = (
+        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+    )
+    attention_backend: AttentionBackend | None = None
+    show_progress_bar: bool = True
+
+    def __post_init__(self):
+        """Validate the fields and resolve the values derived from them."""
+        if self.attention_backend is not None:
+            # Deprecated toggle: when supplied it replaces attention_schedule.
+            if self.attention_backend == AttentionBackend.FLASH_ATTENTION:
+                self.attention_schedule = KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+            else:
+                self.attention_schedule = (
+                    KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+                )
+        if self.path is not None and not os.path.isdir(self.path):
+            raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
+        # Raise the documented ValueError directly; a preceding assert would
+        # always fire first and mask it with an AssertionError.
+        if self.tensor_parallel_size <= 0 or self.tensor_parallel_size > 8:
+            raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
+        if self.hf_config is None:
+            self.hf_config = AutoConfig.from_pretrained(self.model)
+        # Never allocate for more positions than the model itself supports.
+        self.max_model_len = min(
+            self.max_model_len, self.hf_config.max_position_embeddings
+        )
+
diff --git a/vllm/kvprune/config/sampling_params.py b/vllm/kvprune/config/sampling_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..8202ad67d07ed082822eedcc926e3fa85cf40234
--- /dev/null
+++ b/vllm/kvprune/config/sampling_params.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class SamplingParams:
+    """Per-request sampling settings.
+
+    Raises:
+        ValueError: if ``temperature`` is negative.
+    """
+
+    temperature: float = 1.0
+    max_new_tokens: int = 256
+
+    def __post_init__(self):
+        """Reject a negative temperature right after construction."""
+        is_negative = self.temperature < 0
+        if is_negative:
+            raise ValueError("Temperature cannot be negative")
diff --git a/vllm/kvprune/core/__init__.py b/vllm/kvprune/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87e5cfb1bb9b8d50785df223fe4c50b133246937
--- /dev/null
+++ b/vllm/kvprune/core/__init__.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Core: compactor ``LLMEngine`` stack (``llm_engine``, ``model_runner``, ``scheduler``, …).
+
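+The v1 integration path imports submodules explicitly (e.g.
+``from vllm.kvprune.core.llm_engine import LLMEngine``); this package is not required
+to re-export the optional hooks that were removed (``runtime`` / ``flash_integration`` / ``block_budget``).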
+"""
+
+from vllm.kvprune.core.compression_bridge import (
+ VALID_ALIASES_FOR_SAMPLING,
+ compression_method_id_to_enum,
+ compression_method_str_to_id,
+)
+
+__all__ = [
+ "VALID_ALIASES_FOR_SAMPLING",
+ "compression_method_id_to_enum",
+ "compression_method_str_to_id",
+]
diff --git a/vllm/kvprune/core/compression_bridge.py b/vllm/kvprune/core/compression_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..43046dccf82b7233d7169d81bc299616ff3314f8
--- /dev/null
+++ b/vllm/kvprune/core/compression_bridge.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Map compression method strings (e.g. from :class:`~vllm.kvprune.integration.CompressionParams`) to kvprune GPU / enum IDs."""
+
+from __future__ import annotations
+
+from vllm.kvprune.compression.compression_config import CompressionMethod
+
+# IDs stored on device [num_reqs_padded] (int32). Order is stable for kernels.
+COMPRESSION_METHOD_ID_NONE = 0
+COMPRESSION_METHOD_ID_CRITICALADAKV = 1
+COMPRESSION_METHOD_ID_COMPACTOR = 2
+COMPRESSION_METHOD_ID_SNAPKV = 3
+
+# Aliases accepted for method strings (case-insensitive after strip).
+# Keep in sync with the keys of _STR_TO_ID: lookups go through _STR_TO_ID, while
+# the error message of compression_method_str_to_id prints this set.
+VALID_ALIASES_FOR_SAMPLING: frozenset[str] = frozenset(
+ {"none", "criticaladakv", "compactor", "snapkv"}
+)
+
+# Normalized (stripped, lower-cased) method string -> integer id.
+_STR_TO_ID: dict[str, int] = {
+ "none": COMPRESSION_METHOD_ID_NONE,
+ "criticaladakv": COMPRESSION_METHOD_ID_CRITICALADAKV,
+ "compactor": COMPRESSION_METHOD_ID_COMPACTOR,
+ "snapkv": COMPRESSION_METHOD_ID_SNAPKV,
+}
+
+# Integer id -> CompressionMethod enum member.
+_ID_TO_COMPRESSION_METHOD: dict[int, CompressionMethod] = {
+ COMPRESSION_METHOD_ID_NONE: CompressionMethod.NONE,
+ COMPRESSION_METHOD_ID_CRITICALADAKV: CompressionMethod.CRITICALADAKV,
+ COMPRESSION_METHOD_ID_COMPACTOR: CompressionMethod.COMPACTOR,
+ COMPRESSION_METHOD_ID_SNAPKV: CompressionMethod.SNAPKV,
+}
+
+
+def compression_method_str_to_id(s: str) -> int:
+    """Translate a user-supplied method name into its stable integer id (0..3).
+
+    The name is stripped and lower-cased first; a falsy value (``None`` or an
+    empty string) is treated as ``"none"``.
+
+    Raises:
+        ValueError: if the normalized name is not a known method.
+    """
+    normalized = (s or "none").strip().lower()
+    method_id = _STR_TO_ID.get(normalized)
+    if method_id is None:
+        raise ValueError(
+            f"Unknown compression_method {s!r}; expected one of "
+            f"{sorted(VALID_ALIASES_FOR_SAMPLING)}"
+        )
+    return method_id
+
+
+def compression_method_id_to_enum(method_id: int) -> CompressionMethod:
+    """Return the enum member for ``method_id``; unknown ids map to ``NONE``."""
+    return _ID_TO_COMPRESSION_METHOD.get(method_id, CompressionMethod.NONE)
+
+
+__all__ = [
+ "COMPRESSION_METHOD_ID_NONE",
+ "COMPRESSION_METHOD_ID_CRITICALADAKV",
+ "COMPRESSION_METHOD_ID_COMPACTOR",
+ "COMPRESSION_METHOD_ID_SNAPKV",
+ "VALID_ALIASES_FOR_SAMPLING",
+ "compression_method_id_to_enum",
+ "compression_method_str_to_id",
+]
diff --git a/vllm/kvprune/core/llm_engine.py b/vllm/kvprune/core/llm_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..0813cbcabc2a36f92be188a13c01c39672c36c2d
--- /dev/null
+++ b/vllm/kvprune/core/llm_engine.py
@@ -0,0 +1,441 @@
+from __future__ import annotations
+
+import atexit
+import inspect
+import logging
+from pathlib import Path
+from typing import Any, List, Optional, Union
+
+import torch.nn as nn
+import torch.multiprocessing as mp
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.config.sampling_params import SamplingParams
+from vllm.kvprune.core.model_runner import ModelRunner
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.kvprune.utils.sequence import Sequence
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+PromptLike = Union[str, List[int]]
+
+
+def _infer_stop_token_ids(tokenizer, hf_config) -> list[int]:
+ """
+ Build the list of token ids that should end generation.
+
+ Collection starts from ``tokenizer.eos_token_id``, which may be a single id,
+ a list/tuple of ids (newer HF chat tokenizers), or ``None``. The engine must
+ not compare generated ids to such a list as a single ``int``; see
+ :attr:`LLMConfig.eos_token_ids` and decode-time ``torch.isin``.
+
+ Qwen chat uses the ``im_end`` special token as the assistant turn boundary,
+ so tokens whose text contains ``im_end`` are added when they are present in
+ ``added_tokens_encoder`` / ``additional_special_tokens``. Matching on
+ ``im_end`` avoids loose substring matches such as ``end`` that could tag
+ unrelated tokens.
+ NOTE(review): confirm that the ``additional_special_tokens`` scan is limited
+ to the Qwen model types as well.
+
+ Candidates that are not non-negative ``int`` values, that equal the
+ tokenizer's ``unk_token_id``, or that were already collected are skipped.
+
+ Raises ``ValueError`` when no stop token id could be collected.
+ """
+ raw = tokenizer.eos_token_id
+ ids: list[int] = []
+ if isinstance(raw, (list, tuple)):
+ ids.extend(int(x) for x in raw)
+ elif raw is not None:
+ ids.append(int(raw))
+ unk_id = getattr(tokenizer, "unk_token_id", None)
+
+ # Append ``tid`` to ``ids`` unless it is invalid, the unknown token, or a duplicate.
+ def _maybe_add_tid(tid: int) -> None:
+ if not isinstance(tid, int) or tid < 0:
+ return
+ if unk_id is not None and tid == unk_id:
+ return
+ if tid not in ids:
+ ids.append(tid)
+
+ model_type = getattr(hf_config, "model_type", None)
+ if model_type in ("qwen2", "qwen3", "qwen2_moe", "qwen3_moe"):
+ enc = getattr(tokenizer, "added_tokens_encoder", None)
+ if isinstance(enc, dict):
+ for key, tid in enc.items():
+ if isinstance(key, str) and "im_end" in key:
+ _maybe_add_tid(int(tid))
+ for extra in getattr(tokenizer, "additional_special_tokens", []) or []:
+ if not isinstance(extra, str) or "im_end" not in extra:
+ continue
+ # A tokenizer that cannot resolve the token is skipped, not treated as fatal.
+ try:
+ tid = tokenizer.convert_tokens_to_ids(extra)
+ except (TypeError, ValueError, KeyError):
+ continue
+ _maybe_add_tid(tid)
+
+ if not ids:
+ raise ValueError(
+ "Could not infer stop token ids from the tokenizer; set "
+ "LLMConfig(eos_token_ids=[...]) explicitly."
+ )
+ return ids
+
+
+def _merge_apply_chat_template_kwargs(
+    tokenizer,
+    user_kwargs: Optional[dict[str, Any]],
+) -> dict[str, Any]:
+    """
+    Return a copy of ``user_kwargs`` completed with chat-template defaults.
+
+    A default is added only when ``tokenizer.apply_chat_template`` declares the
+    parameter and the caller did not set it:
+
+    * ``add_generation_prompt=True`` — Qwen3 (and similar) instruct models expect
+      it so that the first generated token continues the assistant turn.
+    * ``enable_thinking=False`` — avoids the Qwen3 reasoning channel.
+
+    If the signature cannot be inspected, the user kwargs are returned as they are.
+    """
+    merged = dict(user_kwargs or {})
+    try:
+        accepted = inspect.signature(tokenizer.apply_chat_template).parameters
+    except (TypeError, ValueError):
+        return merged
+    defaults = (("add_generation_prompt", True), ("enable_thinking", False))
+    for name, value in defaults:
+        if name in accepted and name not in merged:
+            merged[name] = value
+    return merged
+
+
+def _runner_entry(config: LLMConfig, rank: int, evt):
+    """Process entry point of a worker: build a ``ModelRunner`` and run its loop.
+
+    Any exception is logged with its traceback instead of being propagated, and
+    the runner, if it was constructed, is always shut down through ``exit()``.
+    """
+    runner = None
+    try:
+        runner = ModelRunner(config, rank, evt)
+        runner.loop()
+    except Exception as e:
+        # Use the module logger (as the rest of this file does) with lazy
+        # %-formatting; the rendered message is the same as before.
+        logger.exception("Rank %s: %r", rank, e)
+    finally:
+        if runner is not None:
+            runner.exit()
+
+
+class LLMEngine:
+ """High-level engine coordinating model runners and scheduling"""
+
+    def __init__(self, config: LLMConfig, external_model: nn.Module | None = None):
+        """Resolve the model directory and tokenizer, then start the runners.
+
+        Steps, in order:
+
+        1. Reject model types that are not in ``MODEL_REGISTRY``.
+        2. Resolve ``config.path``: use ``config.model`` when it is a local
+           directory containing ``config.json``, otherwise download / locate the
+           Hugging Face snapshot.
+        3. Load the tokenizer from the resolved directory.
+        4. Normalize ``config.eos`` / ``config.eos_token_ids`` (explicit ids win;
+           otherwise they are inferred from the tokenizer).
+        5. Spawn one worker process per rank ``1..tensor_parallel_size-1`` and
+           create the rank-0 ``ModelRunner`` in this process.
+
+        Args:
+            config: Engine configuration; several fields are filled in place.
+            external_model: Optional pre-built module handed to the rank-0
+                runner; only allowed with ``tensor_parallel_size == 1``.
+
+        Raises:
+            ValueError: for an unknown model type, or when ``external_model`` is
+                combined with ``tensor_parallel_size != 1``.
+        """
+        self.config = config
+        if self.config.hf_config.model_type not in MODEL_REGISTRY:
+            raise ValueError(f"Unknown model {self.config.model}")
+        if config.path is None:
+            # Local directory: use it directly (no Hub round-trip).
+            # The local must not be called ``mp``: that would shadow the
+            # ``torch.multiprocessing`` module alias for the whole function and
+            # break ``mp.get_context`` below.
+            try:
+                model_dir = Path(config.model)
+                if model_dir.is_dir() and (model_dir / "config.json").is_file():
+                    self.config.path = str(model_dir.resolve())
+                    logger.info(
+                        "Using local model directory for tokenizer: %s", self.config.path
+                    )
+            except OSError:
+                pass
+        if config.path is None:
+            from huggingface_hub import snapshot_download
+
+            # Hub repo id: allow downloading missing shards/tokenizer files when cache
+            # is incomplete (local_files_only=False). Local dirs are handled above.
+            self.config.path = snapshot_download(
+                repo_id=config.model,
+                local_files_only=False,
+            )
+            logger.info(
+                "Resolved Hugging Face snapshot for %s @ %s",
+                self.config.model,
+                self.config.path,
+            )
+        assert self.config.path is not None
+        _trust = bool(getattr(self.config.hf_config, "trust_remote_code", False))
+        # Always load tokenizer from the resolved on-disk tree so we do not re-hit
+        # the Hub with the repo id (can re-download tokenizer / LFS shards).
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.config.path,
+            use_fast=True,
+            trust_remote_code=_trust,
+        )
+        if self.config.eos_token_ids is None:
+            if self.config.eos != -1:
+                self.config.eos_token_ids = [int(self.config.eos)]
+            else:
+                self.config.eos_token_ids = _infer_stop_token_ids(
+                    self.tokenizer, self.config.hf_config
+                )
+        else:
+            self.config.eos_token_ids = [int(x) for x in self.config.eos_token_ids]
+        self.config.eos_token_ids = sorted(set(self.config.eos_token_ids))
+        if self.config.eos == -1:
+            self.config.eos = int(self.config.eos_token_ids[0])
+        else:
+            self.config.eos = int(self.config.eos)
+            # An explicitly configured primary id must also be a stop id.
+            if self.config.eos not in self.config.eos_token_ids:
+                self.config.eos_token_ids = sorted(
+                    self.config.eos_token_ids + [self.config.eos]
+                )
+
+        if external_model is not None and int(self.config.tensor_parallel_size) != 1:
+            raise ValueError(
+                "external_model (shared-weight compactor path) only supports "
+                "tensor_parallel_size=1"
+            )
+
+        self.ps = []
+        world_size = int(self.config.tensor_parallel_size)
+        self.events = []
+        if world_size > 1:
+            ctx = mp.get_context("spawn")
+            # Rank 0 runs in this process; ranks 1.. get their own process.
+            for r in range(1, world_size):
+                event = ctx.Event()
+                p = ctx.Process(
+                    target=_runner_entry,
+                    args=(self.config, r, event),
+                    daemon=True,
+                )
+                p.start()
+                self.ps.append(p)
+                self.events.append(event)
+
+        self.master_model_runner = ModelRunner(
+            self.config,
+            rank=0,
+            peer_events=self.events,
+            external_model=external_model,
+        )
+        atexit.register(self.exit)
+
+    def exit(self):
+        """Shut the engine down; a second call does nothing.
+
+        Exits the rank-0 runner (a failure there is logged, not raised),
+        terminates worker processes that are still alive, waits up to one second
+        for each worker, and clears the list of peer events.
+        """
+        already_exited = getattr(self, "_exited", False)
+        if already_exited:
+            return
+        self._exited = True
+
+        master = getattr(self, "master_model_runner", None)
+        if master is not None:
+            try:
+                master.exit()
+            except Exception:
+                logger.exception("Failed to exit master ModelRunner cleanly")
+
+        for worker in self.ps:
+            if worker.is_alive():
+                worker.terminate()
+            worker.join(timeout=1.0)
+
+        if hasattr(self, "events"):
+            self.events.clear()
+
+    def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
+        """
+        Return the token ids for ``prompt``.
+
+        A string is run through the tokenizer (``tokenizer_kwargs`` are forwarded)
+        and its ``"input_ids"`` entry is returned; anything else is taken to be
+        pre-tokenized and is copied into a new list.
+        """
+        if not isinstance(prompt, str):
+            return list(prompt)
+        encoded = self.tokenizer(prompt, **tokenizer_kwargs)
+        return encoded["input_ids"]
+
+    def detokenize_prompt(
+        self, sequences: List[Sequence], **detokenizer_kwargs
+    ) -> List[str]:
+        """
+        Decode the completion token ids of every sequence into text.
+
+        ``skip_special_tokens=True`` is the default; any keyword passed by the
+        caller overrides or extends it before ``batch_decode`` is called.
+        """
+        decode_kwargs: dict[str, Any] = {"skip_special_tokens": True}
+        decode_kwargs.update(detokenizer_kwargs)
+        completions = [seq.completion_token_ids for seq in sequences]
+        return self.tokenizer.batch_decode(completions, **decode_kwargs)
+
+    def _build_sequences(
+        self,
+        prompts: List[PromptLike] | PromptLike,
+        sampling_params: SamplingParams | List[SamplingParams],
+        per_sequence_compression_params: Optional[
+            SequenceCompressionParams | List[SequenceCompressionParams]
+        ] = None,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+    ) -> List[Sequence]:
+        """
+        Build one ``Sequence`` per prompt.
+
+        Args:
+            prompts: A single prompt or a list of prompts. Each prompt is a string
+                or a list of token ids; a flat, non-empty list of ints is treated
+                as ONE pre-tokenized prompt.
+            sampling_params: One ``SamplingParams`` shared by all prompts, or a
+                list with one entry per prompt.
+            per_sequence_compression_params: ``None`` (every sequence gets
+                ``SequenceCompressionParams(1.0)``), one instance shared by all
+                prompts, or a list with one entry per prompt.
+            tokenizer_kwargs: Extra keyword arguments for string prompts.
+
+        Returns:
+            The sequences, in prompt order. When a prompt is no longer than its
+            protected head and tail together, that sequence receives a private
+            copy of its compression params with ``compression_ratio = 1.0``; the
+            caller's objects are never modified.
+        """
+        import copy
+
+        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
+
+        if not isinstance(prompts, list):
+            prompts = [prompts]
+        elif prompts and all(isinstance(tok, int) for tok in prompts):
+            # A flat list of token ids is a single prompt, not a batch of ints.
+            prompts = [prompts]
+
+        if isinstance(sampling_params, SamplingParams):
+            sampling_params_list: List[SamplingParams] = [sampling_params] * len(
+                prompts
+            )
+        else:
+            sampling_params_list = sampling_params
+            assert len(sampling_params_list) == len(prompts), (
+                "sampling_params list must match prompts length"
+            )
+        if per_sequence_compression_params is None:
+            compression_params_list: List[SequenceCompressionParams] = [
+                SequenceCompressionParams(1.0) for _ in prompts
+            ]
+        elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
+            compression_params_list = [per_sequence_compression_params] * len(prompts)
+        else:
+            # list-like
+            assert len(per_sequence_compression_params) == len(prompts), (
+                "per_sequence_compression_params list must match prompts length"
+            )
+            compression_params_list = list(per_sequence_compression_params)
+
+        seqs: List[Sequence] = []
+        for prompt, sparams, cparams in zip(
+            prompts, sampling_params_list, compression_params_list
+        ):
+            token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
+            protected = cparams.protected_first_tokens + cparams.protected_last_tokens
+            if protected >= len(token_ids):
+                # Nothing is left to prune. Override on a copy: the same params
+                # object may be shared by every sequence of the batch and still
+                # belongs to the caller.
+                cparams = copy.copy(cparams)
+                cparams.compression_ratio = 1.0
+            seqs.append(
+                Sequence(
+                    prompt_token_ids=token_ids,
+                    sampling_params=sparams,
+                    compression_params=cparams,
+                )
+            )
+        return seqs
+
+    def generate(
+        self,
+        prompts: List[PromptLike] | PromptLike,
+        sampling_params: SamplingParams | List[SamplingParams],
+        batch_compression_params: BatchCompressionParams,
+        *,
+        per_sequence_compression_params: Optional[
+            Union[List[SequenceCompressionParams], SequenceCompressionParams]
+        ] = None,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        detokenizer_kwargs: Optional[dict[str, Any]] = None,
+        return_sequences: bool = False,
+    ) -> List[str] | tuple[List[str], List[Sequence]]:
+        """
+        Generate completions for a batch of prompts and return the decoded text.
+        Args:
+            :param prompts:
+                Single prompt or list of prompts, each either a raw text prompt,
+                or pre-tokenized input IDs.
+            :param sampling_params:
+                A single SamplingParams for all prompts in this batch or a list of
+                SamplingParams with the same length as ``prompts``.
+            :param batch_compression_params:
+                Compression settings for this batch.
+            :param per_sequence_compression_params:
+                Per-sequence compression parameters, including the compression
+                ratio to be applied and the size of the protected regions of the
+                sequence (how many start tokens and end tokens to keep uncompressed).
+                If a SequenceCompressionParams instance, the same params will be
+                applied to all sequences in this batch; if a list is provided,
+                each SequenceCompressionParams will be attached to the corresponding
+                prompt in the batch. If ``None`` (the default), ``_build_sequences``
+                creates ``SequenceCompressionParams(1.0)`` for every prompt.
+            :param tokenizer_kwargs:
+                Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
+                string prompts.
+            :param detokenizer_kwargs:
+                Passed through to `tokenizer.batch_decode`.
+            :param return_sequences:
+                Whether to also return the Sequence objects.
+        Returns:
+            :return List[str] or tuple[List[str], List[Sequence]]:
+                The output of ``detokenize_prompt`` for the generated sequences.
+                When ``return_sequences`` is True, a ``(strings, sequences)``
+                tuple is returned instead, where ``sequences`` are the Sequence
+                objects that were handed to the model runner.
+        """
+        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
+        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
+        seqs = self._build_sequences(
+            prompts,
+            sampling_params=sampling_params,
+            per_sequence_compression_params=per_sequence_compression_params,
+            tokenizer_kwargs=tokenizer_kwargs,
+        )
+        # The runner works on the Sequence objects in place; its return value
+        # is not used here.
+        self.master_model_runner.generate(seqs, batch_compression_params)
+        output_strings = self.detokenize_prompt(seqs, **detokenizer_kwargs)
+        if return_sequences:
+            return output_strings, seqs
+        return output_strings
+
+    def generate_chat(
+        self,
+        messages_batch: List[List[dict]],
+        sampling_params: SamplingParams | List[SamplingParams],
+        batch_compression_params: BatchCompressionParams,
+        per_sequence_compression_params: Optional[
+            Union[SequenceCompressionParams, List[SequenceCompressionParams]]
+        ] = None,
+        *,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        detokenizer_kwargs: Optional[dict[str, Any]] = None,
+        return_sequences: bool = False,
+    ) -> List[str] | tuple[List[str], List[Sequence]]:
+        """
+        Convenience API for chat-style prompts using HF `apply_chat_template`.
+
+        Each conversation is rendered and tokenized with the tokenizer's chat
+        template, and the resulting token-id lists are passed to ``generate``.
+        Args:
+            :param messages_batch:
+                List of conversations, where each conversation is a list of
+                message dicts like:
+                {"role": "system" | "user" | "assistant", "content": str}
+            :param sampling_params:
+                A single SamplingParams for all conversations in this batch or a
+                list of SamplingParams with the same length as ``messages_batch``.
+            :param batch_compression_params:
+                Batch Level compression settings. Can set compression_method.
+            :param per_sequence_compression_params:
+                Per-sequence compression parameters, including the compression
+                ratio to be applied and the size of the protected regions of the
+                sequence (how many start tokens and end tokens to keep uncompressed).
+                If a SequenceCompressionParams instance, the same params will be
+                applied to all sequences in this batch; if a list is provided,
+                each SequenceCompressionParams will be attached to the corresponding
+                conversation in the batch. ``None`` (the default) is forwarded
+                unchanged to ``generate``, matching that method's own default.
+            :param tokenizer_kwargs:
+                Passed through to `tokenizer.apply_chat_template`.
+            :param detokenizer_kwargs:
+                Passed through to `tokenizer.batch_decode`.
+            :param return_sequences:
+                Whether to return sequence objects or not
+        Returns:
+            :return List[str] or tuple[List[str], List[Sequence]]:
+                One string per conversation; with ``return_sequences=True`` a
+                ``(strings, sequences)`` tuple, exactly as ``generate`` returns it.
+        """
+        prompts_token_ids: List[List[int]] = []
+        tokenizer_kwargs = _merge_apply_chat_template_kwargs(
+            self.tokenizer, tokenizer_kwargs
+        )
+        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
+        for messages in messages_batch:
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                **tokenizer_kwargs,
+            )
+            # Tensor/array results are converted to plain Python lists.
+            if hasattr(input_ids, "tolist"):
+                input_ids = input_ids.tolist()
+            prompts_token_ids.append(input_ids)
+
+        # NOTE(review): the merged chat-template kwargs are also forwarded as
+        # ``tokenizer_kwargs``; the prompts are already token ids at this point,
+        # so presumably they are ignored downstream - confirm in tokenize_prompt.
+        return self.generate(
+            prompts_token_ids,
+            sampling_params=sampling_params,
+            batch_compression_params=batch_compression_params,
+            per_sequence_compression_params=per_sequence_compression_params,
+            tokenizer_kwargs=tokenizer_kwargs,
+            detokenizer_kwargs=detokenizer_kwargs,
+            return_sequences=return_sequences,
+        )
+
+    def generate_from_sequences(
+        self,
+        seqs: List[Sequence],
+        batch_compression_params: BatchCompressionParams,
+    ) -> List[Sequence]:
+        """
+        Run generation on already-built Sequence objects.
+        Args:
+            :param seqs:
+                Sequence instances to hand to the master model runner.
+            :param batch_compression_params:
+                Batch-level compression settings, forwarded unchanged.
+
+        Returns:
+            :return List[Sequence]:
+                The very same list object that was passed in; the runner works
+                on the sequences in place.
+        """
+        runner = self.master_model_runner
+        runner.generate(seqs, batch_compression_params)
+        return seqs
+
diff --git a/vllm/kvprune/core/memory_manager.py b/vllm/kvprune/core/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3ee2ce1abe60a857e488b53e4d18cef20e4663
--- /dev/null
+++ b/vllm/kvprune/core/memory_manager.py
@@ -0,0 +1,237 @@
+import logging
+import os
+from typing import Iterable, List, Optional
+
+import torch
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.kv_cache.page_table import KVAllocationStatus, PagedKVCache
+from vllm.kvprune.utils.tp_utils import kv_heads_shard_divisor
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class KVCacheManager:
+    """Sizes and owns the paged KV cache of one rank.
+
+    The manager keeps a ``seq_id -> batch slot`` mapping on top of
+    ``PagedKVCache``, derives how many pages fit into GPU memory, and wires the
+    per-layer cache tensors into the model's attention modules.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        config: LLMConfig,
+        *,
+        device: str | None = None,
+    ):
+        """Read sizing parameters from ``config``; the cache itself is created
+        later by :meth:`init_cache`.
+
+        :param rank: Rank of this process; used to pick ``cuda:{rank}`` when
+            ``device`` is not given.
+        :param config: Engine configuration (page size, limits, HF config).
+        :param device: Device string for the cache tensors.
+        """
+        super().__init__()
+        hf_config = config.hf_config
+        self.rank = rank
+        self.gpu_frac = config.gpu_memory_utilization
+        self.page_size = config.kvcache_page_size
+        self.world_size = config.tensor_parallel_size
+        self.max_num_batches = config.max_num_seqs
+        self.max_model_len = config.max_model_len
+        self.num_layers = hf_config.num_hidden_layers
+        self.model_dtype = hf_config.torch_dtype
+        # Not every HF config carries ``head_dim`` (or it may be None); without
+        # a value the page-size arithmetic in get_num_pages raises TypeError,
+        # so derive it from the hidden size in that case.
+        head_dim = getattr(hf_config, "head_dim", None)
+        if head_dim is None:
+            head_dim = hf_config.hidden_size // hf_config.num_attention_heads
+        self.head_dim = head_dim
+        # Ceil-division: logical pages needed for one max-length sequence.
+        self.max_pages_per_batch = (
+            self.max_model_len + self.page_size - 1
+        ) // self.page_size
+        _ws = kv_heads_shard_divisor()
+        # Validate before dividing so a bad setup is reported, not truncated.
+        assert hf_config.num_key_value_heads % _ws == 0, (
+            "tensor-parallel world size needs to divide num_kv_heads"
+        )
+        self.num_kv_heads = hf_config.num_key_value_heads // _ws
+        self._cache_device = device if device is not None else f"cuda:{self.rank}"
+
+        # Filled in by init_cache / estimate_max_batched_tokens.
+        self.num_pages = None
+        self.paged_cache: Optional[PagedKVCache] = None
+        self.max_batched_tokens = None
+
+        self.seq_id_to_batch = {}
+
+    def allocate_sequences(
+        self, seq_ids: List[int], max_positions: List[int]
+    ) -> tuple[bool, Optional[torch.Tensor]]:
+        """Reserve a batch slot and token capacity for every sequence.
+
+        Sequences that already own a slot keep it; only the token reservation
+        is repeated for them.
+
+        :param seq_ids: Sequence identifiers.
+        :param max_positions: Number of tokens to reserve per sequence, aligned
+            with ``seq_ids``.
+        :return: ``(True, batch_mapping)`` where ``batch_mapping`` is an int32
+            tensor of slot indices on the cache device, or ``(False, None)``
+            when a slot or the pages could not be obtained. Slots registered
+            before the failure stay registered.
+        """
+        batch_ids = []
+        for seq_id, len_to_alloc in zip(seq_ids, max_positions):
+            if seq_id not in self.seq_id_to_batch:
+                batch_id = self.paged_cache.new_batch()
+                if batch_id is None:
+                    logger.warning("Failed to allocate batch!")
+                    return False, None
+                self.seq_id_to_batch[seq_id] = int(batch_id)
+            batch_ids.append(self.seq_id_to_batch[seq_id])
+            alloc_status = self.paged_cache.reserve_tokens(
+                self.seq_id_to_batch[seq_id], len_to_alloc
+            )
+            if alloc_status != KVAllocationStatus.SUCCESS:
+                logger.warning("Failed to allocate pages (%s)!", alloc_status)
+                return False, None
+        # Keep the mapping on the same device as the cache tensors it indexes.
+        batch_mapping = torch.as_tensor(
+            batch_ids, dtype=torch.int32, device=self._cache_device
+        )
+        return True, batch_mapping
+
+    def free_sequences(self, seq_ids: Iterable[int]):
+        """Release the batch slots of the given sequences.
+
+        Ids without a registered slot are skipped.
+        """
+        for seq_id in seq_ids:
+            batch_id = self.seq_id_to_batch.pop(seq_id, None)
+            if batch_id is None:
+                # Never allocated or already freed: nothing to hand back.
+                continue
+            self.paged_cache.free_batch(batch_id)
+
+    def init_cache(self, model: nn.Module):
+        """Size the page pool, build the ``PagedKVCache`` and attach it to ``model``."""
+        self.num_pages = self.get_num_pages(self.gpu_frac, self.max_pages_per_batch)
+        self.paged_cache = PagedKVCache(
+            num_layers=self.num_layers,
+            H_kv=self.num_kv_heads,
+            head_dim=self.head_dim,
+            page_size=self.page_size,
+            num_pages=int(self.num_pages),
+            max_num_batches=self.max_num_batches,
+            device=self._cache_device,
+            dtype=self.model_dtype,
+            max_logical_pages_per_head=int(self.max_pages_per_batch),
+        )
+        self._assign_cache_to_layers(model)
+
+    def _assign_cache_to_layers(self, model) -> None:
+        """Point every layer's attention module at its slice of the paged cache."""
+        for layer_index, layer in enumerate(model.model.layers):
+            attn = layer.self_attn.attn
+            k, v, pt, bh = self.paged_cache.layer_slices(layer_index)
+            attn.k_cache = k
+            attn.v_cache = v
+            attn.page_table = pt
+            attn.bh_seq_lens = bh
+            attn.page_size = self.page_size
+
+    def get_num_pages(self, frac: float, n_logical_pages_max: int):
+        """Compute how many KV pages fit into the GPU memory budget.
+
+        :param frac: Fraction of total device memory the engine may use.
+        :param n_logical_pages_max: Logical pages per (batch, head) entry of
+            the page table.
+        :return: Number of pages (at least 1).
+        :raises RuntimeError: If no memory is left for the cache, or the
+            page-table metadata alone exceeds the budget.
+        """
+        free, total = torch.cuda.mem_get_info()
+        used = total - free
+        stats = torch.cuda.memory_stats()
+        peak = int(stats["allocated_bytes.all.peak"])
+        current = int(stats["allocated_bytes.all.current"])
+        # Budget = 90% of the requested share, minus what is already in use,
+        # minus the (peak - current) allocation gap recorded by the allocator.
+        bytes_for_kv_budget = int(total * frac * 0.9) - used - peak + current
+
+        if bytes_for_kv_budget <= 0:
+            # Standalone compactor: ``frac`` is a fraction of total VRAM. When a second
+            # engine shares the GPU with vLLM (shared weights), most VRAM is already
+            # committed; the formula above goes negative. Fall back to a slice of
+            # *currently free* memory for the compactor KV pool.
+            free_frac = float(
+                os.environ.get("VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC", "0.55")
+            )
+            free_frac = max(0.05, min(free_frac, 0.95))
+            bytes_for_kv_budget = int(free * free_frac)
+            logger.warning(
+                "KV cache budget from gpu_memory_utilization (%.2f) is exhausted "
+                "(%.2f MiB free on device); using %.0f%% of free memory (~%.2f MiB) "
+                "for compactor KV (set VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC to adjust).",
+                frac,
+                free / (1024**2),
+                free_frac * 100,
+                bytes_for_kv_budget / (1024**2),
+            )
+        if bytes_for_kv_budget <= 0:
+            raise RuntimeError(
+                "Insufficient memory for compactor KV cache: no free GPU memory left "
+                "after the primary vLLM engine. Lower vLLM gpu_memory_utilization or "
+                "max_model_len, shorten prompts, or run compactor-only / vLLM-only "
+                "sessions. Raising gpu_memory_utilization here does not help."
+            )
+        # page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
+        int32_sz = torch.empty((), dtype=torch.int32).element_size()  # 4
+        page_table_bytes_per_layer = (
+            self.max_num_batches
+            * self.num_kv_heads
+            * n_logical_pages_max
+            * int32_sz  # page_table
+            + self.max_num_batches * self.num_kv_heads * int32_sz
+        )
+        total_page_table_bytes = self.num_layers * page_table_bytes_per_layer
+        kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
+        if kv_bytes_net <= 0:
+            # Tight VRAM: metadata alone can exceed the first budget; reserve page
+            # tables plus a slice of remaining free for KV tensors.
+            bytes_for_kv_budget = min(
+                int(free * 0.95),
+                total_page_table_bytes + max(int(free * 0.25), 8 * 1024 * 1024),
+            )
+            kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
+        if kv_bytes_net <= 0:
+            raise RuntimeError(
+                "page-table footprint exceeds available GPU memory for compactor KV. "
+                f"Reduce vLLM max_num_seqs (compactor uses {self.max_num_batches}) "
+                f"or max_model_len ({self.max_model_len}), or free GPU memory."
+            )
+        dtype_sz = torch.empty((), dtype=self.model_dtype).element_size()
+        # The factor 2 presumably accounts for the K and the V tensor of a page.
+        bytes_per_page_across_layers = self.num_layers * (
+            2 * self.page_size * self.head_dim * dtype_sz
+        )
+        return max(1, kv_bytes_net // bytes_per_page_across_layers)
+
+    def estimate_max_batched_tokens(
+        self,
+        warmup_tokens: int,
+        bytes_used_before_warmup: int,
+        bytes_peak_after_warmup: int,
+    ) -> int:
+        """
+        Estimate the max total number of tokens that can be processed concurrently
+        without OOM.
+
+        The per-token activation cost is taken from the memory growth observed
+        during a warmup pass of ``warmup_tokens`` tokens. The result is stored
+        in ``self.max_batched_tokens`` and returned.
+        """
+        assert warmup_tokens > 0, "warmup_tokens must be > 0"
+        # activation bytes per token
+        warmup_delta = max(
+            0, int(bytes_peak_after_warmup) - int(bytes_used_before_warmup)
+        )
+        bytes_per_token = max(1, (warmup_delta + warmup_tokens - 1) // warmup_tokens)
+
+        free, total = torch.cuda.mem_get_info()
+        target = int(total * self.gpu_frac)
+        used_now = int(total - free)
+        # reserve headroom equal to the gap between peak and current allocations seen so far
+        stats = torch.cuda.memory_stats()
+        peak_cur = int(stats.get("allocated_bytes.all.peak", 0))
+        cur_now = int(stats.get("allocated_bytes.all.current", 0))
+        cushion = max(0, peak_cur - cur_now)
+
+        activation_budget = int(max(0, target - used_now - cushion) * 0.95)
+        max_tokens_per_batch = activation_budget // bytes_per_token
+        max_tokens_in_cache = (self.num_pages * self.page_size) // self.num_kv_heads
+        # round to lower multiple of page size
+        max_tokens_per_batch = (max_tokens_per_batch // self.page_size) * self.page_size
+        max_tokens_in_cache = (max_tokens_in_cache // self.page_size) * self.page_size
+
+        # When vLLM shares the same GPU, ``used_now`` often exceeds ``target`` (same
+        # situation as ``get_num_pages``), so activation_budget is ~0 and
+        # ``max_tokens_per_batch`` rounds to 0 or one page. The min(...) would then
+        # cap prefill at ~page_size tokens (e.g. 32) even though the compactor KV pool
+        # is large — no prompt longer than that can be scheduled. Prefer KV capacity
+        # (capped by max_model_len) whenever activation math yields only a token or two.
+        if (
+            max_tokens_in_cache > 0
+            and max_tokens_per_batch <= self.page_size
+            and max_tokens_in_cache > max_tokens_per_batch
+        ):
+            max_tokens_per_batch = min(max_tokens_in_cache, self.max_model_len)
+
+        self.max_batched_tokens = min(max_tokens_in_cache, max_tokens_per_batch)
+        # Last resort: allow at least one page when KV exists but min(...) is still 0.
+        if self.max_batched_tokens == 0 and self.num_pages > 0 and max_tokens_in_cache > 0:
+            self.max_batched_tokens = min(max_tokens_in_cache, self.page_size)
+        return self.max_batched_tokens
+
+    @property
+    def num_free_batches(self) -> int:
+        """Number of batch slots currently unused in the paged cache."""
+        return len(self.paged_cache.free_batches)
+
+    @property
+    def num_free_pages(self) -> int:
+        """Smallest free-page count over the cache's free-page pools (0 if none)."""
+        return min((len(fp) for fp in self.paged_cache.free_pages), default=0)
+
+    def reclaim_pages(
+        self,
+        seq_ids_to_reclaim: Iterable[int],
+        future_reserved_buffer: List[int] | torch.Tensor,
+    ) -> int:
+        """Ask the paged cache to reclaim pages for the given sequences.
+
+        :param seq_ids_to_reclaim: Sequences whose pages may be reclaimed; each
+            must currently own a batch slot.
+        :param future_reserved_buffer: Per-sequence value passed to
+            ``PagedKVCache.reclaim_pages``, indexed by position in
+            ``seq_ids_to_reclaim``.
+        :return: Sum of the values returned by the cache (approximate bytes freed).
+        """
+        approximate_bytes_freed = 0
+        for i, seq_id in enumerate(seq_ids_to_reclaim):
+            batch_idx = self.seq_id_to_batch[seq_id]
+            approximate_bytes_freed += self.paged_cache.reclaim_pages(
+                batch_idx, future_reserved_buffer[i]
+            )
+        return approximate_bytes_freed
diff --git a/vllm/kvprune/core/model_runner.py b/vllm/kvprune/core/model_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..044daca9544483df46bbd5a07707d9dec9c71d72
--- /dev/null
+++ b/vllm/kvprune/core/model_runner.py
@@ -0,0 +1,804 @@
+import atexit
+import logging
+import os
+import inspect
+from typing import Any, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from vllm.kvprune.attention.sparse_decode_kernel import num_splits_heuristic
+from vllm.kvprune.compression.compression_config import BatchCompressionParams
+from vllm.kvprune.config.constants import RESERVED_BATCH
+from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
+from vllm.kvprune.core.memory_manager import KVCacheManager
+from vllm.kvprune.core.scheduler import Scheduler
+from vllm.kvprune.layers.sampler import Sampler
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.kvprune.utils.arguments import (
+ DecodeBatchArguments,
+ DecodeBatchOutput,
+ PackedTensorArguments,
+ PrefillBatchArguments,
+)
+from vllm.kvprune.utils.context import CompressionContext, reset_context, set_context
+from vllm.kvprune.utils.kv_dist import barrier_sync, broadcast_from_tp_rank0
+from vllm.kvprune.utils.sequence import Sequence
+from torch.multiprocessing import Event
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+class ModelRunner:
+ """Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ rank: int,
+ batch_ready: Optional[Event] = None,
+ peer_events: List[Event] = None,
+ external_model: Optional[nn.Module] = None,
+ *,
+ embedded_in_vllm_worker: bool = False,
+ device: Optional[torch.device] = None,
+ ):
+ """Set up the process group, model, sampler, KV cache and CUDA graphs.
+
+ :param config: Engine configuration; ``eos_token_ids`` must be set.
+ :param rank: Rank of this runner. Ignored for ``self.rank`` when
+ ``embedded_in_vllm_worker`` is True (the vLLM TP rank is used instead).
+ :param batch_ready: Event stored as ``self.batch_ready``.
+ :param peer_events: Events stored as ``self.peer_events`` (empty list
+ when omitted).
+ :param external_model: Pre-built model to use instead of loading one from
+ ``MODEL_REGISTRY``.
+ :param embedded_in_vllm_worker: Run inside an already-initialized vLLM
+ worker instead of creating an own process group.
+ :param device: CUDA device override.
+ :raises RuntimeError: On a tensor-parallel size mismatch, or when embedded
+ mode is requested without an initialized ``torch.distributed``.
+ :raises NotImplementedError: In embedded mode when the global world size
+ differs from the tensor-parallel size.
+ """
+ self.config = config
+ self.embedded_in_vllm_worker = embedded_in_vllm_worker
+ if embedded_in_vllm_worker:
+ # Embedded mode: rank and device come from the surrounding vLLM worker.
+ from vllm.distributed.parallel_state import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
+
+ tp_ws = get_tensor_model_parallel_world_size()
+ tp_rank = get_tensor_model_parallel_rank()
+ if tp_ws != config.tensor_parallel_size:
+ raise RuntimeError(
+ f"tensor parallel world size {tp_ws} != "
+ f"LLMConfig.tensor_parallel_size {config.tensor_parallel_size}"
+ )
+ self.rank = tp_rank
+ _dev = device if device is not None else torch.device(
+ f"cuda:{torch.cuda.current_device()}"
+ )
+ if not dist.is_initialized():
+ raise RuntimeError(
+ "embedded_in_vllm_worker requires torch.distributed to be "
+ "initialized (vLLM worker)."
+ )
+ if dist.get_world_size() != tp_ws:
+ raise NotImplementedError(
+ "KV-prune compactor embedded in vLLM currently requires "
+ "dist.get_world_size() == tensor_parallel_size "
+ "(pipeline_parallel_size=1, data_parallel_size=1). "
+ f"Got dist.get_world_size()={dist.get_world_size()}, "
+ f"tp_ws={tp_ws}."
+ )
+ else:
+ self.rank = rank
+ _dev = device if device is not None else torch.device(f"cuda:{rank}")
+
+ self._device = _dev
+ assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
+ "LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
+ )
+ self._stop_token_ids = torch.tensor(
+ config.eos_token_ids, dtype=torch.int64, device=_dev
+ )
+ hf_config = config.hf_config
+ self.enforce_eager = config.enforce_eager
+ # The PDFA schedule always runs eagerly; only rank 0 logs the override.
+ if config.attention_schedule == KvpruneAttentionSchedule.PDFA:
+ if not self.enforce_eager and self.rank == 0:
+ logger.info(
+ "attention_schedule=PDFA: disabling compactor decode CUDA graphs "
+ "(FlashAttention decode path)."
+ )
+ self.enforce_eager = True
+ # Embedded in vLLM worker (TP>1): respect :attr:`LLMConfig.enforce_eager` from
+ # ``v1_tp_runner._apply_compactor_env_overrides``. Set
+ # ``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` to force eager if graph replay is unstable
+ # with shared vLLM VRAM / streams / NCCL on your stack.
+ if embedded_in_vllm_worker:
+ _tp_graph = os.environ.get(
+ "VLLM_KVPRUNE_TP_EMBEDDED_GRAPH", "1"
+ ).strip().lower()
+ if _tp_graph in ("0", "false", "no"):
+ if not self.enforce_eager:
+ logger.info(
+ "embedded_in_vllm_worker: VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0 → "
+ "forcing compactor enforce_eager=True (skip compactor CUDA graph "
+ "capture)."
+ )
+ self.enforce_eager = True
+ self.world_size = config.tensor_parallel_size
+ self.leverage_sketch_size = config.leverage_sketch_size
+ self.show_progress_bar = config.show_progress_bar
+ self.max_num_batches = config.max_num_seqs
+ self.max_model_len = config.max_model_len
+ self.num_layers = hf_config.num_hidden_layers
+ self.model_dtype = hf_config.torch_dtype
+ # NOTE(review): head_dim stays None when the HF config has no such attribute.
+ self.head_dim = getattr(hf_config, "head_dim", None)
+
+ init_kwargs = {}
+ if not embedded_in_vllm_worker:
+ # ``device_id`` is passed only if this torch version's
+ # init_process_group accepts it.
+ if "device_id" in inspect.signature(dist.init_process_group).parameters:
+ init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
+ if not dist.is_initialized():
+ dist.init_process_group(
+ "nccl",
+ f"tcp://localhost:{config.nccl_port}",
+ world_size=self.world_size,
+ rank=rank,
+ **init_kwargs,
+ )
+ else:
+ # Reuse an existing process group, but only if its size matches.
+ ws = dist.get_world_size()
+ if ws != self.world_size:
+ raise RuntimeError(
+ "torch.distributed is already initialized with "
+ f"world_size={ws}, but compactor ModelRunner expects "
+ f"tensor_parallel_size={self.world_size}. "
+ "Use tensor_parallel_size matching the active process group "
+ "(typically 1 when sharing weights with vLLM)."
+ )
+ torch.cuda.set_device(_dev)
+ # Default dtype/device are switched while model, sampler and cache are
+ # built; further down they are set to the saved dtype and to "cpu".
+ default_dtype = torch.get_default_dtype()
+ torch.set_default_dtype(hf_config.torch_dtype)
+ torch.set_default_device("cuda")
+ model_type = hf_config.model_type
+ if external_model is not None:
+ self.model = external_model
+ else:
+ self.model = MODEL_REGISTRY[model_type](hf_config)
+ self.model.load_model(
+ config.path, use_tqdm=self.is_master and self.show_progress_bar
+ )
+ self.sampler = Sampler()
+
+ # Allocated bytes before, and peak bytes after, the KV-less warmup; both
+ # values feed estimate_max_batched_tokens below.
+ pre_warmup_mem = torch.cuda.memory_stats().get("allocated_bytes.all.current", 0)
+ # No paged KV yet: FA-only varlen path (see :meth:`warmup`).
+ self.warmup(num_warmup_tokens=self.max_model_len, with_kv=False)
+ post_warmup_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+
+ self.kv_manager = KVCacheManager(
+ self.rank, config, device=str(self._device)
+ )
+ self.kv_manager.init_cache(self.model)
+
+ self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
+ torch.set_default_device("cpu")
+ torch.set_default_dtype(default_dtype)
+
+ self.batch_ready = batch_ready
+ self.peer_events = peer_events if peer_events is not None else []
+ # Embedded TP peers: session end is signaled via TP-group broadcast in
+ # maybe_release_peers (no multiprocessing.Event — not pickleable over RPC).
+ self._embedded_peer_continue = True
+ self.captured_graphs = {}
+ self.min_captured_len = {}
+ self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
+ self.max_model_len, pre_warmup_mem, post_warmup_peak
+ )
+ if self.is_master:
+ logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")
+ # Second warmup, now with the paged KV cache in place.
+ self.warmup(num_warmup_tokens=self.max_model_len, with_kv=True)
+
+ if not self.enforce_eager:
+ # Batch sizes 1, 2, 4, ... up to the largest power of two that does
+ # not exceed max_num_batches. Note the loop variable reuses the name
+ # ``bs`` of the list it iterates over.
+ bs = [1 << i for i in range(self.max_num_batches.bit_length())]
+ for bs in (
+ tqdm(bs, desc="Capturing CUDA Graphs")
+ if self.is_master and self.show_progress_bar
+ else bs
+ ):
+ for seq_len in [1024, 4096, 8192, 16384]:
+ self.capture_cudagraph(bs, seq_len)
+
+ if not self.captured_graphs:
+ logger.warning(
+ "No compactor CUDA graphs were captured (KV budget tight or "
+ "allocate_sequences failed during capture). Using eager decode "
+ "for this session."
+ )
+ self.enforce_eager = True
+
+ self.packed_args = PackedTensorArguments(
+ rank=self.rank,
+ max_batched_tokens=self.max_batched_tokens,
+ config=self.config,
+ device=self._device,
+ use_tp_group_for_collectives=embedded_in_vllm_worker,
+ )
+ # Make sure exit() runs at interpreter shutdown.
+ atexit.register(self.exit)
+
+    @torch.inference_mode()
+    def warmup(self, num_warmup_tokens: int, *, with_kv: bool):
+        """Run two dummy prefill passes over ``num_warmup_tokens`` EOS tokens.
+
+        The peak-memory counter is reset before every pass, so the value left
+        afterwards reflects a single forward pass plus logits computation.
+
+        :param num_warmup_tokens: Length of the single dummy sequence.
+        :param with_kv: When True, a temporary KV allocation (sequence id -1)
+            and the configured attention schedule are used; when False, no
+            batch mapping is set and the ``FA_PREFILL_TRITON_DECODE`` schedule
+            is used.
+
+        The attention context is always reset and the temporary KV allocation
+        always released, even when a forward pass raises.
+        """
+        warmup_seq_id = -1
+        sched = (
+            self.config.attention_schedule
+            if with_kv
+            else KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+        )
+        if self.rank == 0:
+            logger.info(
+                "Warming up compactor attention (%s KV init): schedule=%s",
+                "after" if with_kv else "before",
+                sched.name,
+            )
+        device = self._device
+        input_ids = torch.tensor(
+            [self.config.eos] * num_warmup_tokens, device=device, dtype=torch.int64
+        )
+        positions = torch.arange(num_warmup_tokens, device=device, dtype=torch.int64)
+        cu_seqlens_q = torch.tensor(
+            [0, num_warmup_tokens], device=device, dtype=torch.int32
+        )
+        cu_seqlens_k = torch.tensor(
+            [0, num_warmup_tokens], device=device, dtype=torch.int32
+        )
+        batch_mapping = None
+        max_bh_len = 0
+        try:
+            if with_kv:
+                success, batch_mapping = self.kv_manager.allocate_sequences(
+                    [warmup_seq_id], [num_warmup_tokens]
+                )
+                assert success
+                max_bh_len = int(
+                    self.kv_manager.paged_cache.bh_seq_lens.index_select(
+                        1, index=batch_mapping
+                    )
+                    .max()
+                    .item()
+                )
+            set_context(
+                is_prefill=True,
+                do_compression=False,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                cu_seqlens_q_host=(0, num_warmup_tokens),
+                cu_seqlens_k_host=(0, num_warmup_tokens),
+                max_seqlen_q=num_warmup_tokens,
+                max_seqlen_k=num_warmup_tokens,
+                batch_mapping=batch_mapping,
+                max_bh_len=max_bh_len,
+                attention_schedule=sched,
+            )
+            try:
+                for _ in range(2):
+                    torch.cuda.reset_peak_memory_stats()
+                    h = self.model(input_ids, positions)
+                    self.model.compute_logits(h)
+                    barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+                    if with_kv:
+                        # Zero the recorded lengths of the warmup slot so the
+                        # next pass starts from an empty cache entry.
+                        self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
+                            1, batch_mapping.to(torch.long), 0
+                        )
+            finally:
+                # Never leave the global attention context set after a failure.
+                reset_context()
+        finally:
+            # Release the warmup slot if one was registered, including the
+            # case where allocation succeeded only partially.
+            if with_kv and warmup_seq_id in self.kv_manager.seq_id_to_batch:
+                self.kv_manager.free_sequences([warmup_seq_id])
+
+    def exit(self):
+        """Release captured CUDA graphs and tear down the process group.
+
+        Safe to call more than once (it is also registered with ``atexit``);
+        only the first call does any work. When the runner is embedded in a
+        vLLM worker the process group is left alone.
+        """
+        if getattr(self, "_exited", False):
+            return
+        self._exited = True
+        try:
+            # The attribute may be missing if __init__ failed part-way.
+            graphs = getattr(self, "captured_graphs", None)
+            if graphs is not None:
+                graphs.clear()
+        finally:
+            # No ``return`` here: returning from ``finally`` would silently
+            # discard an exception raised in the ``try`` block.
+            owns_process_group = not getattr(self, "embedded_in_vllm_worker", False)
+            if owns_process_group and dist.is_initialized():
+                dist.destroy_process_group()
+
+    def loop(self):
+        """Serve peer sessions forever.
+
+        Waits on ``batch_ready`` with a one-second timeout and runs one peer
+        session each time the wait reports the event as set.
+        """
+        while True:
+            signalled = self.batch_ready.wait(1.0)
+            if not signalled:
+                continue
+            self._process_batches_peer()
+
+    @torch.inference_mode()
+    def run_prefill(
+        self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
+    ):
+        """Run one prefill forward pass and return the logits.
+
+        :param prefill_args: Packed prefill inputs (token ids, positions,
+            cumulative sequence lengths, compression settings).
+        :param batch_mapping: KV-cache batch slot of every sequence.
+        :return: Output of ``model.compute_logits`` for the prefill pass.
+        """
+
+        def _host_offsets(offsets):
+            # Cumulative lengths as a flat tuple of Python ints.
+            return tuple(int(v) for v in offsets.detach().cpu().view(-1).tolist())
+
+        def _as_int64(tensor):
+            # int32 token ids break vLLM-delegated embedding (expects long indices) on some paths.
+            if tensor.dtype == torch.int64:
+                return tensor
+            return tensor.long()
+
+        assert prefill_args.B > 0 and prefill_args.N > 0
+        bh_seq_lens = self.kv_manager.paged_cache.bh_seq_lens
+        max_bh_len = bh_seq_lens.index_select(1, index=batch_mapping).max().item()
+        compression_context = CompressionContext(
+            compression_method=prefill_args.compression_method,
+            compression_chunk_size=prefill_args.compression_chunk_size,
+            batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
+            max_tokens_to_retain=prefill_args.max_tokens_to_retain,
+            context_lens=prefill_args.context_lens.tolist(),
+            PHI=prefill_args.PHI,
+            sketch_dimension=self.leverage_sketch_size,
+            protected_first_tokens=prefill_args.protected_first,
+            protected_last_tokens=prefill_args.protected_last,
+            compression_ratio=prefill_args.compression_ratio,
+        )
+        q_offsets_host = _host_offsets(prefill_args.cu_seqlens_q)
+        k_offsets_host = _host_offsets(prefill_args.cu_seqlens_k)
+        set_context(
+            is_prefill=True,
+            do_compression=prefill_args.do_compression,
+            cu_seqlens_q=prefill_args.cu_seqlens_q,
+            cu_seqlens_k=prefill_args.cu_seqlens_k,
+            cu_seqlens_q_host=q_offsets_host,
+            cu_seqlens_k_host=k_offsets_host,
+            max_seqlen_q=prefill_args.max_seqlen_q,
+            max_seqlen_k=prefill_args.max_seqlen_k,
+            batch_mapping=batch_mapping,
+            max_bh_len=max_bh_len,
+            compression_context=compression_context,
+            STORE_STREAM=self.store_stream,
+            attention_schedule=self.config.attention_schedule,
+        )
+        token_ids = _as_int64(prefill_args.input_ids)
+        token_positions = _as_int64(prefill_args.positions)
+        hidden_states = self.model(token_ids, token_positions)
+        prefill_logits = self.model.compute_logits(hidden_states)
+        reset_context()
+        return prefill_logits
+
+    def maybe_broadcast(self, tensor: torch.Tensor, *, label: str = "tensor") -> None:
+        """Broadcast ``tensor`` from TP rank 0 when more than one rank exists.
+
+        With a single rank this does nothing. ``label`` is accepted but not
+        used by this method.
+        """
+        if self.world_size <= 1:
+            return None
+        broadcast_from_tp_rank0(tensor, use_tp_group=self.embedded_in_vllm_worker)
+        return None
+
+    def maybe_release_peers(self, do_release=False):
+        """Synchronise with peer ranks and optionally end their session.
+
+        Nothing happens with a single rank. In standalone mode the master
+        clears the peer events when ``do_release`` is set, then every rank
+        meets at a barrier. In embedded mode the master broadcasts a
+        continue flag (0 = release, 1 = keep going) which peers store in
+        ``_embedded_peer_continue`` before the barrier.
+        """
+        if self.world_size <= 1:
+            return
+
+        if not self.embedded_in_vllm_worker:
+            if self.is_master and do_release:
+                for peer_event in self.peer_events:
+                    peer_event.clear()
+            barrier_sync(use_tp_group=False)
+            return
+
+        continue_flag = torch.zeros(1, dtype=torch.int32, device=self._device)
+        if self.is_master:
+            continue_flag[0] = int(not do_release)
+        broadcast_from_tp_rank0(continue_flag, use_tp_group=True)
+        if not self.is_master:
+            self._embedded_peer_continue = bool(continue_flag[0].item())
+        barrier_sync(use_tp_group=True)
+        return
+
+    def _peer_outer_loop_active(self) -> bool:
+        """Tell a peer rank whether its session loop should keep running.
+
+        The ``batch_ready`` event decides when one is present; otherwise the
+        embedded continue flag decides, and a runner with neither is inactive.
+        """
+        ready_event = self.batch_ready
+        if ready_event is None:
+            if not self.embedded_in_vllm_worker:
+                return False
+            return self._embedded_peer_continue
+        return ready_event.is_set()
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        all_sequences: List[Sequence],
+        batch_compression_params: Optional[BatchCompressionParams] = None,
+    ):
+        """Run generation for ``all_sequences`` on the master rank.
+
+        Standalone peers are woken through their events first; a default
+        ``BatchCompressionParams()`` is used when none is supplied.
+        """
+        assert self.is_master, "generate can only be called on the master process"
+        wake_peers_by_event = not self.embedded_in_vllm_worker
+        if wake_peers_by_event:
+            for peer_event in self.peer_events:
+                peer_event.set()
+        params = batch_compression_params
+        if params is None:
+            params = BatchCompressionParams()
+        self._process_batches_master(all_sequences, params)
+
+    @property
+    def is_master(self):
+        """Whether this runner is rank 0."""
+        master_rank = 0
+        return self.rank == master_rank
+
+ @torch.inference_mode()
+ def _process_batches_master(
+ self,
+ all_sequences: List[Sequence],
+ batch_compression_params: BatchCompressionParams,
+ ):
+ """Drive prefill and decode for ``all_sequences`` until the scheduler is done.
+
+ Each iteration either prefills a new batch and merges it into the running
+ decode batch, or (when nothing is left to prefill) continues decoding.
+ With more than one rank, the decision to decode and the occupancy value
+ are broadcast to the peers through ``decode_flags``.
+
+ :param all_sequences: Sequences to generate for.
+ :param batch_compression_params: Batch-level compression settings.
+ :raises RuntimeError: If sequences are pending but no prefill batch can
+ be scheduled.
+ """
+ assert self.is_master
+ compression_details = f"Applying Compression Method: {batch_compression_params.compression_method}"
+ # Only log the method when at least one sequence actually compresses.
+ if any(seq.compression_params.compression_ratio < 1.0 for seq in all_sequences):
+ logger.info(compression_details)
+ scheduler = Scheduler(
+ all_sequences=all_sequences,
+ kv_manager=self.kv_manager,
+ use_tqdm=self.show_progress_bar,
+ )
+ decode_batch = DecodeBatchArguments()
+ # decode_flags[0] = run_decode (0/1), decode_flags[1] = occupancy.
+ decode_flags = torch.empty(2, dtype=torch.int32, device=self._device)
+ while not scheduler.is_finished():
+ sequences = scheduler.get_prefill_batch()
+ if not sequences:
+ if scheduler.pending_sequence_ids:
+ raise RuntimeError(
+ "KV-prune compactor cannot schedule any prefill (KV/token budget). "
+ f"max_batched_tokens={self.kv_manager.max_batched_tokens}, "
+ f"pending_sequences={len(scheduler.pending_sequence_ids)}. "
+ "Lower v1 gpu_memory_utilization / max_model_len, set "
+ "VLLM_KVPRUNE_RELEASE_V1_KV=1 to discard v1 KV (sleep+wake), "
+ "or free GPU memory. Diagnostics: "
+ f"{scheduler.diagnose_prefill_failure()}"
+ )
+ # Pending is empty: either finished or decode-only continuation.
+ if decode_batch.token_ids is None:
+ break
+ run_decode = True
+ occupancy = -1
+ else:
+ seq_ids_cpu = [seq.seq_id for seq in sequences]
+ scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
+ temps = torch.tensor(
+ [s.sampling_params.temperature for s in sequences],
+ dtype=torch.float32,
+ pin_memory=True,
+ ).to(device=self._device, non_blocking=True)
+ prefill_arguments = self.packed_args.build_prefill_args(
+ sequences, batch_compression_params=batch_compression_params
+ )
+ # Per-sequence capacity to reserve: prompt length + max new tokens.
+ max_ctx_lens = (
+ prefill_arguments.max_new_tokens + prefill_arguments.context_lens
+ )
+
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ seq_ids_cpu, max_ctx_lens.tolist()
+ )
+ assert success, "failed to allocate pages for sequences"
+
+ logits = self.run_prefill(prefill_arguments, batch_mapping)
+ # Must match prefill `positions` dtype (int64). `context_lens` is int32
+ # from the packed buffer; using int32 here breaks RoPE indexing
+ # (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
+ positions = prefill_arguments.context_lens.to(dtype=torch.int64)
+ token_ids = self.sampler(logits, temps)
+ # Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
+ # reads bh_seq_lens on the default stream and must not race.
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+ # TODO: synchronize page counts across dist
+ if self.world_size == 1:
+ self.kv_manager.reclaim_pages(
+ seq_ids_cpu, prefill_arguments.max_new_tokens
+ )
+ # with logging_redirect_tqdm():
+ # logger.info(
+ # f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
+ # )
+
+ # While sequences are still pending, occupancy is 66% of the number
+ # of sequences in the merged decode batch; otherwise -1.
+ # NOTE(review): how run_decode_loop interprets occupancy is not
+ # visible here - confirm.
+ if scheduler.any_pending_sequences():
+ num_pending_batches = (
+ 0
+ if decode_batch.token_ids is None
+ else decode_batch.token_ids.shape[0]
+ )
+ occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
+ else:
+ occupancy = -1
+ # Keep prefilling as long as the scheduler can fit another batch.
+ run_decode = not scheduler.can_prefill_another_batch()
+ decode_batch = decode_batch.update(
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ prefill_arguments.seq_ids,
+ temps,
+ occupancy,
+ )
+ if self.world_size > 1:
+ decode_flags[0] = int(run_decode)
+ decode_flags[1] = occupancy
+ self.maybe_broadcast(decode_flags, label="decode_flags")
+ if not run_decode:
+ continue
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+
+ decode_output, decode_batch = self.run_decode_loop(decode_batch)
+ # Sequences no longer present in the returned decode batch are finished.
+ finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
+ decode_batch.seq_ids.tolist()
+ )
+ scheduler.record_finished_sequence_ids(
+ finished_sequence_ids, update_status=True
+ )
+ self.kv_manager.free_sequences(finished_sequence_ids)
+ # Tell the peers whether the session ends after this round.
+ self.maybe_release_peers(scheduler.is_finished())
+ scheduler.update_sequences(
+ decode_output.output_tokens.tolist(),
+ decode_output.output_seq_ids.tolist(),
+ )
+ scheduler.close()
+
+ @torch.inference_mode()
+ def run_peer_session(self) -> None:
+ """Non-master TP ranks: run one peer session (used when embedded in vLLM).
+
+ When ``self.embedded_in_vllm_worker`` is set, ``_embedded_peer_continue``
+ is set to True. The batch loop itself lives in
+ ``_process_batches_peer()``.
+ """
+ if self.embedded_in_vllm_worker:
+ self._embedded_peer_continue = True
+ self._process_batches_peer()
+
+ @torch.inference_mode()
+ def _process_batches_peer(self):
+ """Batch loop for a non-master rank: prefill, then decode when told to.
+
+ Each iteration builds prefill arguments from ``self.packed_args``,
+ allocates KV pages for the batch, runs prefill, and then reads the
+ two-element ``decode_flags`` tensor through ``maybe_broadcast``
+ (element 0: whether to run the decode loop, element 1: the desired
+ batch occupancy). No sampling happens on this rank.
+
+ Raises:
+ AssertionError: if called on the master rank, or if
+ ``allocate_sequences`` reports failure.
+ """
+ assert not self.is_master
+ # The scheduler starts empty; sequence ids are registered per batch below.
+ scheduler = Scheduler([], kv_manager=self.kv_manager)
+ decode_batch = DecodeBatchArguments()
+ # Receive buffer for [run_decode, occupancy]; torch.empty leaves it
+ # uninitialised until the "decode_flags" broadcast fills it.
+ decode_flags = torch.empty(2, dtype=torch.int32, device=self._device)
+ while self._peer_outer_loop_active():
+ prefill_arguments = self.packed_args.build_prefill_args()
+
+ B = prefill_arguments.B
+ # Per-sequence context budget: prompt length plus requested new tokens.
+ max_ctx_lens = (
+ prefill_arguments.max_new_tokens + prefill_arguments.context_lens
+ )
+
+ seq_ids_cpu = prefill_arguments.seq_ids.tolist()
+ scheduler.add_running_sequence_ids(seq_ids_cpu)
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ seq_ids_cpu, max_ctx_lens.tolist()
+ )
+ assert success, "failed to allocate pages for sequences"
+
+ # The prefill result is discarded on this rank.
+ self.run_prefill(prefill_arguments, batch_mapping)
+ # Next decode position per sequence, cast to int64.
+ positions = prefill_arguments.context_lens.to(dtype=torch.int64)
+ self.maybe_broadcast(decode_flags, label="decode_flags")
+ run_decode = bool(decode_flags[0].item())
+ occupancy = int(decode_flags[1].item())
+ # Placeholder storage only; run_decode_loop broadcasts
+ # decode_batch.token_ids before using it (presumably overwriting
+ # these values on peer ranks - TODO confirm in maybe_broadcast).
+ token_ids = torch.empty(B, dtype=torch.int64, device=self._device)
+ decode_batch = decode_batch.update(
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ prefill_arguments.seq_ids,
+ None, # temps not used in peer process
+ occupancy,
+ )
+
+ if not run_decode:
+ continue
+ # Make the default stream wait for work queued on store_stream
+ # before decoding.
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+
+ _, decode_batch = self.run_decode_loop(decode_batch)
+ # Active ids no longer present in decode_batch.seq_ids are treated as
+ # finished and their KV pages are freed.
+ finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
+ decode_batch.seq_ids.tolist()
+ )
+ scheduler.record_finished_sequence_ids(finished_sequence_ids)
+ self.kv_manager.free_sequences(finished_sequence_ids)
+ self.maybe_release_peers()
+ scheduler.close()
+
+ @torch.inference_mode()
+ def run_decode_loop(
+ self,
+ decode_batch: DecodeBatchArguments,
+ ) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
+ """Step decoding until the batch drains or drops to the desired occupancy.
+
+ Each iteration broadcasts the current token ids, drops sequences whose
+ latest token is a stop token or whose position reached ``max_ctx_lens``,
+ computes logits for the remaining rows and, on the master rank only,
+ samples the next tokens.
+
+ Args:
+ decode_batch: per-sequence decode state; its tensors are filtered
+ and advanced in place.
+
+ Returns:
+ ``(output, decode_batch)``. On the master rank ``output`` holds the
+ concatenated CPU token ids and matching sequence ids collected
+ during the loop; on other ranks it is
+ ``DecodeBatchOutput(None, None)``.
+ """
+ if self.is_master:
+ # Seed the output buffers with the rows after the first
+ # num_stashed_batches entries (presumably the stashed rows were
+ # already reported by an earlier call - TODO confirm).
+ num_stashed_batches = decode_batch.num_stashed_batches
+ tok_buffer = [
+ decode_batch.token_ids[num_stashed_batches:].to(
+ "cpu", non_blocking=True
+ )
+ ]
+ seq_buffer = [
+ decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
+ ]
+ while True:
+ self.maybe_broadcast(decode_batch.token_ids, label="decode_token_ids")
+ # A row keeps running only while it is below its context budget and
+ # its latest token is not one of the stop tokens.
+ not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
+ running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
+ not_stopped
+ )
+ # Apply the same mask to every per-sequence tensor so rows stay aligned.
+ decode_batch.token_ids = torch.masked_select(
+ decode_batch.token_ids, running_batches
+ )
+ decode_batch.positions = torch.masked_select(
+ decode_batch.positions, running_batches
+ )
+ decode_batch.batch_mapping = torch.masked_select(
+ decode_batch.batch_mapping, running_batches
+ )
+ decode_batch.max_ctx_lens = torch.masked_select(
+ decode_batch.max_ctx_lens, running_batches
+ )
+ decode_batch.seq_ids = torch.masked_select(
+ decode_batch.seq_ids, running_batches
+ )
+ # Temperatures are only filtered on the master rank.
+ if self.is_master:
+ decode_batch.temps = torch.masked_select(
+ decode_batch.temps, running_batches
+ )
+ num_remaining = decode_batch.token_ids.numel()
+ # Leave the loop when nothing is left, or when the batch has shrunk to
+ # the desired occupancy; the surviving row count is recorded as
+ # num_stashed_batches for the next call.
+ if (
+ num_remaining == 0
+ or num_remaining <= decode_batch.desired_batch_occupancy
+ ):
+ decode_batch.num_stashed_batches = num_remaining
+ break
+ logits = self._decode_step_logits(decode_batch)
+
+ # Only the master samples; token_ids on other ranks are not updated here.
+ if self.is_master:
+ decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
+ tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
+ seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
+ decode_batch.positions += 1
+
+ if self.is_master:
+ # non_blocking D2H copies must finish before cat/tolist read CPU data.
+ torch.cuda.synchronize()
+ output = DecodeBatchOutput(
+ output_tokens=torch.cat(tok_buffer),
+ output_seq_ids=torch.cat(seq_buffer),
+ )
+ else:
+ output = DecodeBatchOutput(None, None)
+ return output, decode_batch
+
+    def _decode_logits_eager(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        batch_mapping: torch.Tensor,
+    ):
+        """Run one decode step without CUDA graphs and return the logits.
+
+        Installs a decode-mode context (no prefill, no compression) for
+        ``batch_mapping``, feeds the model int64 token ids and positions, and
+        passes the hidden states through ``self.model.compute_logits``.
+        """
+        set_context(
+            is_prefill=False,
+            do_compression=False,
+            batch_mapping=batch_mapping,
+            attention_schedule=self.config.attention_schedule,
+        )
+        # Only convert when the dtype is not already int64.
+        token_input = input_ids
+        if token_input.dtype != torch.int64:
+            token_input = token_input.long()
+        position_input = positions
+        if position_input.dtype != torch.int64:
+            position_input = position_input.long()
+        hidden_states = self.model(token_input, position_input)
+        return self.model.compute_logits(hidden_states)
+
+    @torch.inference_mode()
+    def _decode_step_logits(self, decode_batch: DecodeBatchArguments):
+        """Compute logits for one decode step, preferring a captured CUDA graph.
+
+        Eager execution is used when ``self.enforce_eager`` is set or no graphs
+        were captured. If graph replay raises, a warning is logged,
+        ``self.enforce_eager`` is switched on for the remaining steps, and the
+        step is retried eagerly.
+        """
+        step_inputs = (
+            decode_batch.token_ids,
+            decode_batch.positions,
+            decode_batch.batch_mapping,
+        )
+        graphs_usable = (not self.enforce_eager) and bool(self.captured_graphs)
+        if not graphs_usable:
+            return self._decode_logits_eager(*step_inputs)
+        try:
+            return self.run_graph_decode(*step_inputs)
+        except Exception as exc:
+            logger.warning(
+                "CUDA graph decode failed (%s); switching to eager decode for "
+                "remaining steps.",
+                exc,
+            )
+            # Sticky switch: later steps skip graph replay entirely.
+            self.enforce_eager = True
+            return self._decode_logits_eager(*step_inputs)
+
+ @torch.inference_mode()
+ def run_graph_decode(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ batch_mapping: torch.Tensor,
+ ):
+ """Replay a captured CUDA graph for one decode step and return its logits.
+
+ Looks up a graph for the current batch size and maximum position via
+ ``get_cuda_graph``; when none is compatible the step runs eagerly
+ instead. Otherwise the live inputs are copied into the graph's static
+ buffers, the graph is replayed, and the first ``bs`` rows of its logits
+ buffer are returned.
+ """
+ bs = input_ids.shape[0]
+ # int() pulls the max position to the host (presumably a device sync when
+ # positions lives on the GPU).
+ max_k = int(positions.max())
+ graph_dict = self.get_cuda_graph(bs, max_k)
+ if graph_dict is None:
+ return self._decode_logits_eager(input_ids, positions, batch_mapping)
+ set_context(
+ is_prefill=False,
+ do_compression=False,
+ batch_mapping=batch_mapping,
+ attention_schedule=self.config.attention_schedule,
+ )
+ # Only the first bs rows of the static buffers carry live data; the
+ # remaining batch_mapping slots are reset to RESERVED_BATCH first.
+ graph_dict["input_ids"][:bs] = input_ids
+ graph_dict["positions"][:bs] = positions
+ graph_dict["batch_mapping"].fill_(RESERVED_BATCH)
+ graph_dict["batch_mapping"][:bs] = batch_mapping
+ graph_dict["graph"].replay()
+ logits_out = graph_dict["logits"]
+ # NOTE(review): a leading-dimension slice is normally already contiguous,
+ # so this likely returns a view of the graph's output buffer rather than
+ # a copy - confirm callers consume it before the next replay.
+ return logits_out[:bs].contiguous()
+
+ @torch.inference_mode()
+ def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
+ """Capture a decode CUDA graph for ``batch_size`` rows and ``max_seqlen_k``.
+
+ Allocates temporary KV sequences with ids ``0..batch_size-1``, runs one
+ warm-up forward pass, captures ``model`` + ``compute_logits`` into a
+ ``torch.cuda.CUDAGraph``, stores the graph and its static buffers under
+ ``self.captured_graphs[batch_size][max_seqlen_k]`` and frees the temporary
+ sequences. If the KV allocation fails the capture is skipped with a
+ warning and nothing is stored.
+ """
+ barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+ device = torch.device("cuda")
+ logger.debug(
+ f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
+ )
+ # Static input buffers that run_graph_decode later fills before replay.
+ _g_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ _g_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
+ _g_hidden = None
+ key_split = num_splits_heuristic(
+ batch_size * self.kv_manager.num_kv_heads,
+ max_seq_len=max_seqlen_k,
+ num_sms=torch.cuda.get_device_properties(device).multi_processor_count,
+ max_splits=12,
+ )
+
+ # Temporary sequences used only for the capture; each is allocated with a
+ # fixed length of 256.
+ # NOTE(review): ids 0..batch_size-1 are reused here and freed at the end -
+ # presumably capture only runs while no real sequences hold these ids.
+ success, _g_batch_mapping = self.kv_manager.allocate_sequences(
+ list(range(batch_size)), [256] * batch_size
+ )
+ if not success:
+ # Shared GPU with vLLM: compactor KV pool is small; large batch capture
+ # often cannot reserve [256]*batch_size per sequence. Skip this graph.
+ logger.warning(
+ "Skipping CUDA graph capture for batch_size=%s max_seqlen_k=%s "
+ "(KV allocate_sequences failed; decode will use eager or other graphs).",
+ batch_size,
+ max_seqlen_k,
+ )
+ # Matches the second barrier_sync the success path performs before
+ # capture, so both paths call barrier_sync twice.
+ barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+ return
+
+ set_context(
+ is_prefill=False,
+ do_compression=False,
+ batch_mapping=_g_batch_mapping,
+ key_split=key_split,
+ attention_schedule=self.config.attention_schedule,
+ )
+ # Warm-up pass outside the graph; its outputs are discarded.
+ _gw = self.model(_g_input_ids, _g_positions)
+ self.model.compute_logits(_gw)
+ barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+ decode_graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(decode_graph):
+ _g_hidden = self.model(_g_input_ids, _g_positions)
+ _g_logits = self.model.compute_logits(_g_hidden)
+ # Keep the graph together with every tensor it reads or writes.
+ graph_vars = {
+ "graph": decode_graph,
+ "input_ids": _g_input_ids,
+ "positions": _g_positions,
+ "batch_mapping": _g_batch_mapping,
+ "hidden": _g_hidden,
+ "logits": _g_logits,
+ "key_split": key_split,
+ }
+ if batch_size not in self.captured_graphs:
+ self.captured_graphs[batch_size] = {}
+ self.min_captured_len[batch_size] = float("inf")
+
+ self.captured_graphs[batch_size][max_seqlen_k] = graph_vars
+ # Track the shortest captured sequence length for this batch size.
+ self.min_captured_len[batch_size] = min(
+ max_seqlen_k, self.min_captured_len[batch_size]
+ )
+ self.kv_manager.free_sequences(list(range(batch_size)))
+
+    def get_cuda_graph(
+        self, batch_size: int, max_seqlen_k: int
+    ) -> Optional[dict[str, Any]]:
+        """Pick the captured graph to replay, or ``None`` when none fits.
+
+        The smallest captured batch size that is at least ``batch_size`` is
+        chosen first. Within that batch size, the largest captured sequence
+        length not exceeding ``max_seqlen_k`` wins. Other batch sizes are not
+        consulted if that one has no suitable length.
+        """
+        graphs_by_bs = self.captured_graphs
+        if not graphs_by_bs:
+            return None
+        chosen_bs = min((bs for bs in graphs_by_bs if bs >= batch_size), default=None)
+        if chosen_bs is None:
+            return None
+        graphs_by_len = graphs_by_bs[chosen_bs]
+        chosen_len = max(
+            (length for length in graphs_by_len if length <= max_seqlen_k),
+            default=None,
+        )
+        if chosen_len is None:
+            return None
+        return graphs_by_len[chosen_len]
+
+
diff --git a/vllm/kvprune/core/scheduler.py b/vllm/kvprune/core/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..32365a3f8ce584f45bdb0d636a4a5ab41bd2ff93
--- /dev/null
+++ b/vllm/kvprune/core/scheduler.py
@@ -0,0 +1,215 @@
+import time
+from typing import Iterable, List
+
+from vllm.kvprune.core.memory_manager import KVCacheManager
+from vllm.kvprune.utils.sequence import Sequence, SequenceStatus
+from tqdm import tqdm
+
+
+def cdiv(a, b):
+    """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
+    padded = a + b - 1
+    return padded // b
+
+
+class Scheduler:
+    """
+    Simple sequence scheduler for prefill + decode with a paged KV cache.
+
+    The scheduler tracks three disjoint sets of sequence IDs:
+
+    * ``pending_sequence_ids`` - sequences that have not yet been started.
+    * ``active_sequence_ids`` - sequences currently running.
+    * ``finished_sequence_ids`` - sequences that have generated all tokens.
+
+    At prefill time, :meth:`get_prefill_batch` selects a subset of pending
+    sequences that can fit into the available KV cache and per-step token
+    budget, given the constraints from the associated :class:`KVCacheManager`.
+
+    The class also handles basic bookkeeping of sequence statuses.
+
+    Args:
+        :param all_sequences:
+            Iterable of :class:`Sequence` objects to be scheduled. Each
+            sequence must have a unique ``seq_id``. The iterable is consumed
+            exactly once, so a generator is accepted.
+        :param kv_manager:
+            A :class:`KVCacheManager` instance that this scheduler will use
+            to determine whether additional batches can be scheduled.
+        :param use_tqdm:
+            If True, one progress bar ("Completed Batches") is created with
+            one tick per sequence. It advances when sequences are recorded
+            as finished with ``update_status=True``, and
+            :meth:`update_sequences` replaces its description with a
+            throughput figure.
+    """
+
+    def __init__(
+        self,
+        all_sequences: Iterable[Sequence],
+        kv_manager: KVCacheManager,
+        *,
+        use_tqdm=False,
+    ):
+        # Walk ``all_sequences`` only once: it may be a one-shot iterator, and
+        # a second pass would leave the pending set empty.
+        self.allseq_mapping: dict[int, Sequence] = {s.seq_id: s for s in all_sequences}
+        self.pending_sequence_ids: set[int] = set(self.allseq_mapping)
+        self.active_sequence_ids: set[int] = set()
+        self.finished_sequence_ids: set[int] = set()
+        self.manager = kv_manager
+        self.use_tqdm = use_tqdm
+        self.start_time = time.perf_counter()
+        self.total_tokens_generated = 0
+        self.total_tokens_input = 0
+        self.pbar = None
+        if use_tqdm:
+            self.pbar = tqdm(
+                total=len(self.pending_sequence_ids),
+                desc="Completed Batches",
+            )
+
+    def get_prefill_batch(self) -> List[Sequence]:
+        """
+        Select a batch of pending sequences to prefill under KV/memory constraints.
+
+        The selection is greedy over ``pending_sequence_ids`` in iteration order.
+        A sequence is added to the batch if:
+
+        * The sum of its prompt length and the total prompt tokens selected so
+          far does not exceed ``manager.max_batched_tokens``, and
+        * There is at least one free KV "batch slot" left
+          (``manager.num_free_batches``), and
+        * The number of KV pages required by the sequence's prompt +
+          max_new_tokens (per KV head) is strictly less than the remaining
+          free pages.
+
+        Nothing is reserved on the manager; the free counts are only tracked
+        locally while the batch is assembled.
+
+        Returns:
+            :return List[Sequence]:
+                The list of :class:`Sequence` objects chosen for prefill in
+                this step. The caller is responsible for marking them as
+                active via :meth:`add_running_sequence_ids`.
+        """
+        total_tok, sequences = 0, []
+        num_free_batches, num_free_pages = (
+            self.manager.num_free_batches,
+            self.manager.num_free_pages,
+        )
+        for seq_id in self.pending_sequence_ids:
+            seq = self.allseq_mapping[seq_id]
+            prompt_length = seq.prompt_len
+            pages_needed = (
+                cdiv(
+                    prompt_length + seq.sampling_params.max_new_tokens,
+                    self.manager.page_size,
+                )
+                * self.manager.num_kv_heads
+            )
+            # NOTE(review): the page check is strict (<), so a request that
+            # would use exactly all remaining pages is not admitted - confirm
+            # this headroom is intentional.
+            if (
+                prompt_length + total_tok <= self.manager.max_batched_tokens
+                and num_free_batches > 0
+                and pages_needed < num_free_pages
+            ):
+                sequences.append(seq)
+                total_tok += prompt_length
+                num_free_pages -= pages_needed
+                num_free_batches -= 1
+        return sequences
+
+    def is_finished(self) -> bool:
+        """
+        Check whether all sequences have completed (nothing pending or active).
+        """
+        return not self.pending_sequence_ids and not self.active_sequence_ids
+
+    def any_pending_sequences(self) -> bool:
+        """
+        Check whether any sequences are still pending (not yet started).
+        """
+        return bool(self.pending_sequence_ids)
+
+    def add_running_sequence_ids(
+        self, active_sequence_ids: Iterable[int], *, update_status: bool = False
+    ):
+        """
+        Mark a set of sequences as active / running. This moves sequence IDs
+        from ``pending_sequence_ids`` into ``active_sequence_ids``. Optionally,
+        it also updates the per-sequence status and the input-token counter.
+
+        Args:
+            :param active_sequence_ids:
+                Iterable of sequence IDs that have been scheduled for prefill
+                or decode and should now be considered running. Consumed
+                once, so a generator is accepted.
+            :param update_status:
+                If True, set each corresponding :class:`Sequence`'s
+                ``status = SequenceStatus.RUNNING`` and add its
+                ``prompt_len`` to ``total_tokens_input``.
+        """
+        # Materialise first: the ids are needed for both the set update and
+        # the optional status pass.
+        newly_active = list(active_sequence_ids)
+        self.active_sequence_ids.update(newly_active)
+        self.pending_sequence_ids.difference_update(self.active_sequence_ids)
+        if update_status:
+            for seq_id in newly_active:
+                seq = self.allseq_mapping[seq_id]
+                seq.status = SequenceStatus.RUNNING
+                self.total_tokens_input += seq.prompt_len
+
+    def get_finished_sequence_ids_from_unfinished(
+        self, unfinished_sequence_ids: Iterable[int]
+    ) -> set[int]:
+        """
+        Infer which active sequences have finished given the
+        unfinished set (for decode steps where the caller knows
+        which sequences are still generating but not necessarily
+        which have just completed).
+
+        Args:
+            :param unfinished_sequence_ids:
+                Iterable of sequence IDs that are still running
+        Returns:
+            :return set[int]:
+                The active sequence IDs that are not in
+                ``unfinished_sequence_ids``. The scheduler's own sets are
+                not modified.
+        """
+        return self.active_sequence_ids.difference(unfinished_sequence_ids)
+
+    def record_finished_sequence_ids(
+        self, finished_sequence_ids: Iterable[int], *, update_status: bool = False
+    ):
+        """
+        Record that a set of sequences has finished generation.
+
+        This moves IDs from ``active_sequence_ids`` into
+        ``finished_sequence_ids``.
+
+        Args:
+            :param finished_sequence_ids:
+                Iterable of sequence IDs that have completed generation and
+                no longer require KV cache. Consumed once, so a generator is
+                accepted.
+            :param update_status:
+                If True, set each corresponding :class:`Sequence`'s
+                ``status = SequenceStatus.FINISHED`` and advance the
+                progress bar (when enabled) by one per sequence.
+        """
+        # Materialise first: a one-shot iterable would otherwise be exhausted
+        # by the first set operation and never reach the finished set.
+        finished = list(finished_sequence_ids)
+        self.active_sequence_ids.difference_update(finished)
+        self.finished_sequence_ids.update(finished)
+        if update_status:
+            for seq_id in finished:
+                self.allseq_mapping[seq_id].status = SequenceStatus.FINISHED
+                if self.pbar is not None:
+                    self.pbar.update(1)
+
+    def update_sequences(self, tokens: Iterable[int], seq_ids: Iterable[int]):
+        """
+        Append newly generated tokens to their corresponding sequences.
+
+        ``tokens`` and ``seq_ids`` are paired positionally; each pair adds one
+        token to that sequence and one to ``total_tokens_generated``. When a
+        progress bar exists, its description is replaced with the overall
+        throughput (input + generated tokens per second since construction).
+
+        Args:
+            :param tokens:
+                Iterable of generated token IDs.
+            :param seq_ids:
+                Iterable of sequence IDs aligned with ``tokens``.
+        """
+        cur_time = time.perf_counter()
+        for tok, seq_id in zip(tokens, seq_ids):
+            self.allseq_mapping[seq_id].add_new_token(tok)
+            self.total_tokens_generated += 1
+        if self.pbar is not None:
+            self.pbar.set_description(
+                f"Throughput: {(self.total_tokens_generated + self.total_tokens_input) / (cur_time - self.start_time):.2f} tok/s"
+            )
+
+    def close(self):
+        """Close the progress bar, if one was created."""
+        if self.pbar is not None:
+            self.pbar.close()
+
+    def can_prefill_another_batch(self) -> bool:
+        """Return True when :meth:`get_prefill_batch` would select anything."""
+        return len(self.get_prefill_batch()) > 0
diff --git a/vllm/kvprune/integration/__init__.py b/vllm/kvprune/integration/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1470f0ad554e2282b26acd66c47a98824c12245c
--- /dev/null
+++ b/vllm/kvprune/integration/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""KV-pruning integration: compactor ``LLMEngine`` sharing weights with :class:`~vllm.LLM`."""
+
+from vllm.kvprune.integration.compression_params import CompressionParams
+
+__all__ = ["CompressionParams"]
diff --git a/vllm/kvprune/integration/compactor_shared.py b/vllm/kvprune/integration/compactor_shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..148df4f06dd6397f57a96080ec340dc6d9eaa1d0
--- /dev/null
+++ b/vllm/kvprune/integration/compactor_shared.py
@@ -0,0 +1,140 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Construct compactor :class:`LLMEngine` sharing weight tensors with an in-process vLLM ``LLM``."""
+
+from __future__ import annotations
+
+import os
+
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.core.llm_engine import LLMEngine
+from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
+from vllm.kvprune.integration.vllm_model_access import extract_vllm_causal_lm
+from vllm.kvprune.integration.weight_tie import (
+ delegate_kvprune_compute_logits_to_vllm,
+ delegate_kvprune_embed_tokens_to_vllm,
+ tie_kvprune_rope_buffers_from_vllm,
+ tie_kvprune_weights_from_vllm,
+)
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def build_llm_config_for_compactor(vc: VllmConfig) -> LLMConfig:
+    """Translate a vLLM ``VllmConfig`` into the compactor's :class:`LLMConfig`.
+
+    Public wrapper around ``vllm_config_to_llm_config``.
+    """
+    compactor_config = vllm_config_to_llm_config(vc)
+    return compactor_config
+
+
+def create_compactor_engine_with_shared_weights(llm: object) -> LLMEngine:
+ """Single GPU, TP=1: compactor ``LLMEngine`` whose weights alias vLLM tensors.
+
+ Call after the vLLM ``LLM`` has loaded weights. Requires in-process executor
+ (``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
+
+ Args:
+ llm: object exposing ``llm_engine`` with a ``vllm_config``.
+
+ Returns:
+ An ``LLMEngine`` built from the adapted config with the tied model
+ passed as ``external_model``.
+
+ Raises:
+ RuntimeError: if ``llm`` has no ``llm_engine``.
+ ValueError: if ``tensor_parallel_size`` is not 1, if the model type is
+ not in ``MODEL_REGISTRY``, or (from ``int()``) if
+ ``VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS`` is not an integer.
+ AssertionError: if the adapted config has no ``hf_config``.
+ """
+ llm_engine = getattr(llm, "llm_engine", None)
+ if llm_engine is None:
+ raise RuntimeError("Expected ``llm.llm_engine``.")
+ vc: VllmConfig = llm_engine.vllm_config
+ if vc.parallel_config.tensor_parallel_size != 1:
+ raise ValueError(
+ "Shared-weight compactor backend requires tensor_parallel_size=1"
+ )
+
+ cfg = vllm_config_to_llm_config(vc)
+ # ``cfg.enforce_eager`` is for the compactor ``ModelRunner`` only (decode CUDA
+ # graphs), not v1. v1 graph capture is controlled solely by ``LLM(...,
+ # enforce_eager=...)`` / ``kvprune_compression=True`` on the entrypoint ``LLM``.
+ # Large vLLM max_num_seqs blows up compactor page-table GPU memory; sharing the GPU
+ # with v1 leaves little room for metadata + KV tensors. Default cap 32 so physical
+ # KV pages stay usable; set VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS=0 to disable cap,
+ # or raise (e.g. 128) if you have VRAM headroom.
+ # An empty value also skips the cap; the cap can only lower max_num_seqs.
+ _cap = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
+ if _cap:
+ lim = int(_cap)
+ if lim > 0:
+ cfg.max_num_seqs = min(cfg.max_num_seqs, lim)
+
+ # Compactor decode graphs (``enforce_eager=False``): honored for non-shared-weight
+ # engines. **Shared-weight** path (below) forces ``enforce_eager=True`` after
+ # delegating ``compute_logits`` to vLLM unless ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1``.
+ # Opt out of graphs for non-shared runs: ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1`` or
+ # ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0``.
+ # Precedence: an explicit ENFORCE_EAGER value wins; otherwise CUDA_GRAPH
+ # (default "1") decides.
+ _ce = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
+ if _ce in ("1", "true", "yes"):
+ cfg.enforce_eager = True
+ logger.info(
+ "KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 → "
+ "enforce_eager=True (skip compactor decode CUDA graphs)."
+ )
+ elif _ce in ("0", "false", "no"):
+ cfg.enforce_eager = False
+ logger.info(
+ "KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=0 → "
+ "enforce_eager=False (try compactor CUDA graph capture)."
+ )
+ else:
+ _dg = os.environ.get(
+ "VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1"
+ ).strip().lower()
+ if _dg in ("0", "false", "no"):
+ cfg.enforce_eager = True
+ logger.info(
+ "KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 → "
+ "enforce_eager=True (skip compactor decode CUDA graphs)."
+ )
+ else:
+ cfg.enforce_eager = False
+ logger.info(
+ "KV-prune compactor: default try decode CUDA graphs; ModelRunner "
+ "falls back to eager if capture yields none. Set "
+ "VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 or "
+ "VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 to skip capture."
+ )
+
+ hf = cfg.hf_config
+ assert hf is not None
+ model_type = hf.model_type
+ if model_type not in MODEL_REGISTRY:
+ raise ValueError(
+ f"Compactor MODEL_REGISTRY has no entry for model_type={model_type!r}; "
+ f"supported: {sorted(MODEL_REGISTRY)}"
+ )
+
+ vllm_model = extract_vllm_causal_lm(llm)
+ # Device and dtype are taken from the first vLLM parameter.
+ device = next(vllm_model.parameters()).device
+ dtype = next(vllm_model.parameters()).dtype
+
+ # Build compactor shell on CPU first. **Do not** call ``.to(device)`` before tying:
+ # that allocates a full second copy of weights on GPU; tying then frees the
+ # duplicate but peak memory can OOM on large models. Tie first so parameters
+ # alias vLLM tensors directly (no extra weight VRAM).
+ kv_model: nn.Module = MODEL_REGISTRY[model_type](hf)
+ tie_kvprune_weights_from_vllm(vllm_model, kv_model)
+ # Buffers (e.g. RoPE tables) not in ``named_parameters`` may still be on CPU.
+ kv_model.to(device=device, dtype=dtype)
+ tie_kvprune_rope_buffers_from_vllm(vllm_model, kv_model)
+ delegate_kvprune_embed_tokens_to_vllm(vllm_model, kv_model)
+ delegate_kvprune_compute_logits_to_vllm(vllm_model, kv_model)
+
+ # Compactor decode CUDA graphs capture ``model.forward`` + ``compute_logits`` in one
+ # graph. Here ``compute_logits`` is delegated to vLLM's LM head / LogitsProcessor
+ # (cublas GEMM, padded vocab, etc.). Embedding that in a nested capture commonly
+ # fails with ``CUBLAS_STATUS_EXECUTION_FAILED`` and invalidates stream capture
+ # (``cudaErrorStreamCaptureInvalidated``). Default: skip graphs for this integration.
+ # This overrides whatever enforce_eager value the env handling above chose.
+ _sw_graph = os.environ.get(
+ "VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH", "0"
+ ).strip().lower() in ("1", "true", "yes")
+ if not _sw_graph:
+ cfg.enforce_eager = True
+ logger.info(
+ "KV-prune shared-weight compactor: enforce_eager=True (skip compactor "
+ "decode CUDA graphs; logits delegated to vLLM). Set "
+ "VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1 only to attempt capture (often fails)."
+ )
+
+ return LLMEngine(cfg, external_model=kv_model)
diff --git a/vllm/kvprune/integration/compressed_generate.py b/vllm/kvprune/integration/compressed_generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e8db81b3a4af494602941aeee2f680f044055a5
--- /dev/null
+++ b/vllm/kvprune/integration/compressed_generate.py
@@ -0,0 +1,452 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""KV-pruning (compactor) path invoked from :meth:`vllm.entrypoints.llm.LLM.generate`."""
+
+from __future__ import annotations
+
+import os
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any
+
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer
+
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
+from vllm.kvprune.core.compression_bridge import (
+ compression_method_id_to_enum,
+ compression_method_str_to_id,
+)
+from vllm.kvprune.core.llm_engine import LLMEngine, _infer_stop_token_ids
+from vllm.kvprune.integration.compactor_shared import create_compactor_engine_with_shared_weights
+from vllm.kvprune.integration.compression_params import CompressionParams
+from vllm.logger import init_logger
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
+
+logger = init_logger(__name__)
+
+_MP_ENV = "VLLM_ENABLE_V1_MULTIPROCESSING"
+_RELEASE_V1_KV_ENV = "VLLM_KVPRUNE_RELEASE_V1_KV"
+
+
+def _maybe_release_v1_kv_for_compactor(llm: Any) -> None:
+    """Optionally run a v1 sleep/wake cycle before the compactor takes over.
+
+    Does nothing unless ``VLLM_KVPRUNE_RELEASE_V1_KV`` is ``1``/``true``/``yes``
+    (case-insensitive; this function treats an unset variable as ``"0"``).
+    When enabled it calls ``llm.sleep(level=1, mode="abort")`` followed by
+    ``llm.wake_up()``. The intent is to discard v1's KV cache so more GPU
+    memory is available to the compactor.
+
+    The cycle is best-effort: any exception is logged as a warning and
+    swallowed.
+    """
+    setting = os.environ.get(_RELEASE_V1_KV_ENV, "0").strip().lower()
+    release_requested = setting in ("1", "true", "yes")
+    if not release_requested:
+        return
+    try:
+        logger.info(
+            "%s=1: discarding v1 KV via sleep(level=1) then wake_up() "
+            "(reloads model weights to GPU).",
+            _RELEASE_V1_KV_ENV,
+        )
+        llm.sleep(level=1, mode="abort")
+        llm.wake_up()
+    except Exception as exc:
+        logger.warning("%s: sleep/wake failed: %s", _RELEASE_V1_KV_ENV, exc)
+
+
+def ensure_inprocess_engine_for_weight_sharing() -> None:
+    """Force ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` unless it is already ``"0"``.
+
+    The compactor needs to reach the vLLM model in the same process, so the
+    variable is overwritten in ``os.environ`` and the change is logged.
+    """
+    already_inprocess = os.environ.get(_MP_ENV, "1") == "0"
+    if already_inprocess:
+        return
+    os.environ[_MP_ENV] = "0"
+    logger.info(
+        "KV cache pruning: set %s=0 so the model stays in-process for "
+        "shared-weight compactor (no manual env needed).",
+        _MP_ENV,
+    )
+
+
+def _normalize_prompt_list(prompts: Any) -> list[Any]:
+    """Return ``prompts`` as a list of individual prompts.
+
+    A lone ``str`` or ``dict`` is wrapped in a one-element list; any other
+    value is treated as an iterable of prompts and copied into a new list.
+    """
+    is_single_prompt = isinstance(prompts, (str, dict))
+    return [prompts] if is_single_prompt else list(prompts)
+
+
+def _normalize_sampling_params(
+    sampling_params: SamplingParams | Sequence[SamplingParams] | None,
+    n: int,
+) -> list[SamplingParams]:
+    """Expand ``sampling_params`` to exactly ``n`` entries, one per prompt.
+
+    ``None`` yields ``n`` fresh default ``SamplingParams``; a single instance
+    is repeated (the same object ``n`` times); a sequence is copied to a list
+    and must already have length ``n``.
+
+    Raises:
+        ValueError: if a sequence is given whose length differs from ``n``.
+    """
+    if isinstance(sampling_params, SamplingParams):
+        return [sampling_params for _ in range(n)]
+    if sampling_params is None:
+        return [SamplingParams() for _ in range(n)]
+    per_prompt = list(sampling_params)
+    count = len(per_prompt)
+    if count == n:
+        return per_prompt
+    raise ValueError(
+        f"sampling_params length {count} != prompts length {n}"
+    )
+
+
+def _normalize_compression_params(
+    compression: CompressionParams | Sequence[CompressionParams] | None,
+    n: int,
+) -> list[CompressionParams]:
+    """Expand ``compression`` to exactly ``n`` entries, one per prompt.
+
+    ``None`` yields ``n`` fresh ``CompressionParams(compression_ratio=1.0)``;
+    a single instance is repeated (the same object ``n`` times); a sequence is
+    copied to a list and must already have length ``n``.
+
+    Raises:
+        ValueError: if a sequence is given whose length differs from ``n``.
+    """
+    if isinstance(compression, CompressionParams):
+        return [compression for _ in range(n)]
+    if compression is None:
+        return [CompressionParams(compression_ratio=1.0) for _ in range(n)]
+    per_prompt = list(compression)
+    count = len(per_prompt)
+    if count == n:
+        return per_prompt
+    raise ValueError(f"compression length {count} != prompts length {n}")
+
+
+def _any_compactor(comps: list[CompressionParams]) -> bool:
+    """Return True if at least one entry has ``compression_ratio`` below 1.0."""
+    for params in comps:
+        if params.compression_ratio < 1.0:
+            return True
+    return False
+
+
+_FORCE_COMPACTOR_PATH_ENV = "VLLM_KVPRUNE_FORCE_COMPACTOR_PATH"
+
+
+def _should_use_kvprune_compactor_path(comps: list[CompressionParams]) -> bool:
+    """Decide whether this request should go through the integrated compactor.
+
+    True when any prompt asks for compression (``compression_ratio < 1.0``).
+    Otherwise True only if ``VLLM_KVPRUNE_FORCE_COMPACTOR_PATH`` is set to
+    ``1``/``true``/``yes`` (case-insensitive), which allows exercising the
+    compactor code path with no actual pruning.
+    """
+    if not _any_compactor(comps):
+        forced = os.environ.get(_FORCE_COMPACTOR_PATH_ENV, "").strip().lower()
+        return forced in (
+            "1",
+            "true",
+            "yes",
+        )
+    return True
+
+
+def _to_compactor_sampling(sp: SamplingParams) -> CompactorSamplingParams:
+    """Convert vLLM sampling params into the compactor's sampling params.
+
+    Only ``temperature`` and ``max_tokens`` are carried over; a missing
+    ``max_tokens`` falls back to 16 new tokens.
+    """
+    token_budget = 16 if sp.max_tokens is None else sp.max_tokens
+    temperature = float(sp.temperature)
+    return CompactorSamplingParams(
+        temperature=temperature,
+        max_new_tokens=int(token_budget),
+    )
+
+
+def _to_sequence_compression(cp: CompressionParams) -> SequenceCompressionParams:
+    """Build per-sequence compression params from user ``CompressionParams``.
+
+    The ratio is coerced to ``float`` and the two protected-token counts to
+    ``int``.
+    """
+    ratio = float(cp.compression_ratio)
+    keep_first = int(cp.protected_first_tokens)
+    keep_last = int(cp.protected_last_tokens)
+    return SequenceCompressionParams(
+        compression_ratio=ratio,
+        protected_first_tokens=keep_first,
+        protected_last_tokens=keep_last,
+    )
+
+
+def _batch_compression_from_comps(comps: list[CompressionParams]) -> BatchCompressionParams:
+    """Derive batch-level compression params from per-prompt params.
+
+    The compression method of the first entry with ``compression_ratio < 1.0``
+    is used for the whole batch; if no entry compresses, default
+    ``BatchCompressionParams()`` are returned.
+    """
+    first_compressing = next(
+        (c for c in comps if c.compression_ratio < 1.0), None
+    )
+    if first_compressing is None:
+        return BatchCompressionParams()
+    method_id = compression_method_str_to_id(first_compressing.compression_method)
+    return BatchCompressionParams(
+        compression_method=compression_method_id_to_enum(method_id)
+    )
+
+
+def _kvprune_compactor_hf_tokenizer(llm: Any):
+    """Return (and cache on ``llm``) the HF tokenizer used by the compactor.
+
+    The tokenizer is loaded from a resolved on-disk model tree when possible:
+    a local directory containing ``config.json`` is used directly, otherwise
+    the Hub snapshot path returned by ``snapshot_download`` is used. If that
+    resolution fails for any reason, the raw model string is passed to
+    ``AutoTokenizer`` instead.
+
+    ``trust_remote_code`` follows the engine's ``model_config`` setting (the
+    value the user passed when constructing ``LLM``), with ``hf_config`` as a
+    secondary source. Reading it from ``hf_config`` alone ignored the user's
+    opt-in, because that attribute normally lives on ``model_config``.
+
+    The result is stored as ``llm._kvprune_compactor_hf_tokenizer`` so later
+    calls return the same object.
+    """
+    cached = getattr(llm, "_kvprune_compactor_hf_tokenizer", None)
+    if cached is not None:
+        return cached
+    mc = llm.llm_engine.vllm_config.model_config
+    model_s = str(mc.model)
+    src = model_s
+    try:
+        p = Path(model_s)
+        if p.is_dir() and (p / "config.json").is_file():
+            src = str(p.resolve())
+        else:
+            from huggingface_hub import snapshot_download
+
+            # NOTE(review): without ``allow_patterns`` this resolves the whole
+            # repository snapshot, not just tokenizer files - confirm that is
+            # acceptable when the weights are not already cached.
+            src = snapshot_download(repo_id=model_s, local_files_only=False)
+    except Exception as e:
+        # Best-effort path resolution: fall back to the raw model string, but
+        # leave a trace of why instead of failing silently.
+        logger.debug(
+            "KV-prune tokenizer: could not resolve %r to a local path (%s); "
+            "loading tokenizer from the model id directly.",
+            model_s,
+            e,
+        )
+        src = model_s
+    hf_cfg = mc.hf_config
+    # getattr on ``None`` simply yields the default, so a missing hf_config is fine.
+    _trust = bool(getattr(mc, "trust_remote_code", False)) or bool(
+        getattr(hf_cfg, "trust_remote_code", False)
+    )
+    tok = AutoTokenizer.from_pretrained(src, use_fast=True, trust_remote_code=_trust)
+    llm._kvprune_compactor_hf_tokenizer = tok
+    return tok
+
+
+def _prompt_to_compactor_input(prompt: Any) -> str | list[int]:
+    """Normalise one prompt into text or a fresh list of token ids.
+
+    Accepted forms:
+
+    * ``str`` - returned unchanged.
+    * non-empty ``list`` of ``int`` token ids - returned as a copy.
+    * ``dict`` with ``"prompt_token_ids"`` - the ids are returned as a new
+      list (any iterable is accepted); an empty value is rejected, matching
+      the bare-list form.
+    * ``dict`` with a ``str`` under ``"prompt"`` - that string is returned.
+
+    Raises:
+        TypeError: for an empty token-id prompt (bare list or dict form) or
+            any other unsupported prompt shape.
+    """
+    if isinstance(prompt, str):
+        return prompt
+    # Decoder-only `list[int]` token ids (see `vllm.inputs.PromptType`).
+    if isinstance(prompt, list):
+        if not prompt:
+            raise TypeError("Empty token-id prompt is not supported for compactor path.")
+        if all(isinstance(t, int) for t in prompt):
+            return list(prompt)
+    if isinstance(prompt, dict):
+        if "prompt_token_ids" in prompt:
+            # Always copy so the caller's list is never aliased, and apply the
+            # same emptiness rule as the bare-list form above.
+            ids = list(prompt["prompt_token_ids"])
+            if not ids:
+                raise TypeError(
+                    "Empty token-id prompt is not supported for compactor path."
+                )
+            return ids
+        p = prompt.get("prompt")
+        if isinstance(p, str):
+            return p
+    raise TypeError(
+        f"Unsupported prompt type for compactor path: {type(prompt)}. "
+        "Use str, list[int] token ids, or dict with 'prompt_token_ids' or 'prompt'."
+    )
+
+
+def _prompt_to_token_ids_for_tp(llm: Any, prompt: Any) -> list[int]:
+    """Turn one prompt into token ids on the driver for the TP collective path.
+
+    Text prompts are encoded with ``llm.get_tokenizer()``; prompts that are
+    already token ids are returned as a new list.
+    """
+    normalized = _prompt_to_compactor_input(prompt)
+    if not isinstance(normalized, str):
+        return list(normalized)
+    return llm.get_tokenizer().encode(normalized)
+
+
+def _compressed_generate_tp_collective(
+    llm: Any,
+    plist: list[Any],
+    sps: list[SamplingParams],
+    comps: list[CompressionParams],
+) -> list[RequestOutput]:
+    """TP>1: run the compactor on every worker through ``collective_rpc``.
+
+    Prompts are tokenised on the driver and checked against
+    ``max_model_len``. A plain-data payload (token ids, stop ids, sampling
+    and compression settings) is then sent to the
+    ``kvprune_v1_compressed_generate`` worker method. The result reported by
+    tensor-parallel rank 0 is converted into ``RequestOutput`` objects.
+
+    Raises:
+        NotImplementedError: if pipeline or data parallelism is enabled.
+        ValueError: if any prompt is longer than ``max_model_len``.
+        RuntimeError: if the RPC was cancelled (re-raised with a hint), or if
+            no result dict from tensor-parallel rank 0 is present.
+    """
+    vllm_cfg = llm.llm_engine.vllm_config
+    parallel = vllm_cfg.parallel_config
+    if not (
+        parallel.pipeline_parallel_size == 1 and parallel.data_parallel_size == 1
+    ):
+        raise NotImplementedError(
+            "KV-prune TP compression requires pipeline_parallel_size=1 and "
+            f"data_parallel_size=1 (got PP={parallel.pipeline_parallel_size}, "
+            f"DP={parallel.data_parallel_size})."
+        )
+
+    hf_config = vllm_cfg.model_config.hf_config
+    tokenizer = llm.get_tokenizer()
+    stop_token_ids = _infer_stop_token_ids(tokenizer, hf_config)
+
+    token_id_lists = [_prompt_to_token_ids_for_tp(llm, prompt) for prompt in plist]
+
+    length_limit = int(vllm_cfg.model_config.max_model_len)
+    for index, ids in enumerate(token_id_lists):
+        prompt_length = len(ids)
+        if prompt_length <= length_limit:
+            continue
+        raise ValueError(
+            f"KV-prune TP compressed generate: prompt {index} length {prompt_length} "
+            f"exceeds max_model_len ({length_limit}). Shorten the prompt or raise "
+            "max_model_len when constructing LLM()."
+        )
+
+    # Payload must be picklable for multiproc/Ray RPC: only plain Python data,
+    # no multiprocessing synchronization primitives (workers are separate
+    # processes).
+    sampling_payload = []
+    for sp in sps:
+        temperature = float(sp.temperature)
+        token_budget = sp.max_tokens if sp.max_tokens is not None else 16
+        sampling_payload.append(
+            {
+                "temperature": temperature,
+                "max_new_tokens": int(token_budget),
+            }
+        )
+    compression_payload = []
+    for c in comps:
+        compression_payload.append(
+            {
+                "compression_ratio": float(c.compression_ratio),
+                "compression_method": str(c.compression_method),
+                "protected_first_tokens": int(c.protected_first_tokens),
+                "protected_last_tokens": int(c.protected_last_tokens),
+            }
+        )
+    payload: dict[str, Any] = {
+        "eos_token_ids": stop_token_ids,
+        "prompt_token_ids": token_id_lists,
+        "sampling_params": sampling_payload,
+        "compression_params": compression_payload,
+    }
+
+    _maybe_release_v1_kv_for_compactor(llm)
+    try:
+        results = llm.llm_engine.collective_rpc(
+            "kvprune_v1_compressed_generate",
+            args=(payload,),
+        )
+    except RuntimeError as err:
+        if "cancelled" not in str(err).lower():
+            raise
+        raise RuntimeError(
+            "collective_rpc was cancelled (a GPU worker likely crashed). "
+            "Scroll up for the first worker traceback — often NCCL/CUDA before "
+            "TCPStore/Broken pipe on the driver."
+        ) from err
+    # Only the dict reported by tensor-parallel rank 0 is used.
+    master: dict[str, Any] | None = next(
+        (
+            r
+            for r in results
+            if isinstance(r, dict) and r.get("tensor_parallel_rank") == 0
+        ),
+        None,
+    )
+    if master is None:
+        raise RuntimeError(
+            "collective_rpc did not return a dict from tensor parallel rank 0."
+        )
+    return _tp_payload_to_request_outputs(llm, master)
+
+
+def _tp_payload_to_request_outputs(llm: Any, master: dict[str, Any]) -> list[RequestOutput]:
+    """Turn the rank-0 worker result into finished ``RequestOutput`` objects.
+
+    ``master`` must carry parallel lists under ``"prompt_token_ids"`` and
+    ``"completion_token_ids"``. Each pair becomes one request with a single
+    completion whose ``finish_reason`` is always ``"stop"`` and whose logprobs
+    are unset.
+    """
+    tokenizer = llm.get_tokenizer()
+    pairs = zip(master["prompt_token_ids"], master["completion_token_ids"])
+
+    request_outputs: list[RequestOutput] = []
+    for idx, (prompt_ids, completion_ids) in enumerate(pairs):
+        decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)
+        # Match ``_sequences_to_request_outputs``: when only special tokens were
+        # generated, the stripped text is blank although tokens exist, so decode
+        # again with the special tokens kept.
+        if not decoded.strip() and completion_ids:
+            decoded = tokenizer.decode(completion_ids, skip_special_tokens=False)
+        completion = CompletionOutput(
+            index=0,
+            text=decoded,
+            token_ids=list(completion_ids),
+            cumulative_logprob=None,
+            logprobs=None,
+            finish_reason="stop",
+        )
+        request_outputs.append(
+            RequestOutput(
+                request_id=f"kvprune-tp-{idx}",
+                prompt=None,
+                prompt_token_ids=list(prompt_ids),
+                prompt_logprobs=None,
+                outputs=[completion],
+                finished=True,
+            )
+        )
+    return request_outputs
+
+
+def _ensure_compactor_engine(llm: Any) -> LLMEngine:
+    """Return the compactor ``LLMEngine`` cached on ``llm``, building it on first use.
+
+    Raises:
+        ValueError: if the engine must be built and ``tensor_parallel_size`` is
+            not 1 (weight sharing is only done in the single-rank setup).
+    """
+    if llm._kvprune_compactor_engine is not None:
+        return llm._kvprune_compactor_engine
+
+    tp_size = llm.llm_engine.vllm_config.parallel_config.tensor_parallel_size
+    if tp_size != 1:
+        raise ValueError(
+            "KV-pruning compactor path requires tensor_parallel_size=1 "
+            "for shared weights."
+        )
+    llm._kvprune_compactor_engine = create_compactor_engine_with_shared_weights(llm)
+    logger.info("Initialized compactor LLMEngine with weights shared from vLLM.")
+    return llm._kvprune_compactor_engine
+
+
+def try_compressed_generate(
+ llm: Any,
+ prompts: Any,
+ sampling_params: SamplingParams | Sequence[SamplingParams] | None,
+ *,
+ compression: CompressionParams | Sequence[CompressionParams] | None,
+ use_tqdm: bool | Callable[..., tqdm] = True,
+ lora_request: Any = None,
+ priority: list[int] | None = None,
+ tokenization_kwargs: dict[str, Any] | None = None,
+) -> list[RequestOutput] | None:
+ """Return completions on the compactor engine, or ``None`` to use normal v1.
+
+ ``lora_request`` / ``priority`` / ``tokenization_kwargs`` are accepted for API
+ parity with :meth:`~vllm.entrypoints.llm.LLM.generate` but are not passed to the
+ compactor engine yet. ``use_tqdm`` is discarded in the same way.
+
+ Dispatch:
+
+ - ``tensor_parallel_size > 1``: always handled here through
+ ``_compressed_generate_tp_collective``. When
+ ``_should_use_kvprune_compactor_path`` rejects the request, every prompt is
+ given ``CompressionParams(compression_ratio=1.0)`` instead of returning
+ ``None``.
+ - ``tensor_parallel_size == 1``: returns ``None`` when
+ ``_should_use_kvprune_compactor_path`` rejects the request; otherwise runs
+ the prompts on the in-process compactor engine.
+ """
+ del lora_request, priority, tokenization_kwargs, use_tqdm
+
+ # NOTE(review): the helpers are not visible here; the length argument suggests
+ # they expand to one entry per prompt — confirm.
+ plist = _normalize_prompt_list(prompts)
+ sps = _normalize_sampling_params(sampling_params, len(plist))
+ comps = _normalize_compression_params(compression, len(plist))
+
+ pc = llm.llm_engine.vllm_config.parallel_config
+ # TP>1: every worker must run the same collective_rpc session. If all
+ # compression_ratio >= 1, the old code returned None and only the driver ran
+ # v1 _run_engine — other ranks never joined a matching collective, which can
+ # deadlock NCCL / leave workers unsynchronized (hang at "Processed prompts:").
+ if pc.tensor_parallel_size > 1:
+ if not _should_use_kvprune_compactor_path(comps):
+ comps = [CompressionParams(compression_ratio=1.0) for _ in plist]
+ elif not _should_use_kvprune_compactor_path(comps):
+ return None
+
+ # Advisory only: the request still runs when v1 CUDA graphs are enabled.
+ v1_eager = bool(
+ getattr(llm.llm_engine.vllm_config.model_config, "enforce_eager", False)
+ )
+ if not v1_eager:
+ logger.warning(
+ "KV-prune compression: v1 CUDA graphs are still enabled on this LLM. "
+ "The compactor does not reuse v1 graphs; capture wastes VRAM. "
+ "Set kvprune_compression=True, enforce_eager=True, or "
+ "VLLM_KVPRUNE_COMPRESSION_DEFAULT=1 before import vllm."
+ )
+
+ if pc.tensor_parallel_size > 1:
+ return _compressed_generate_tp_collective(llm, plist, sps, comps)
+
+ ensure_inprocess_engine_for_weight_sharing()
+ # The v1 KV release is requested only before the compactor engine is first
+ # created; later calls reuse the cached engine.
+ if llm._kvprune_compactor_engine is None:
+ _maybe_release_v1_kv_for_compactor(llm)
+ engine = _ensure_compactor_engine(llm)
+ # Convert the vLLM-side request objects into their compactor equivalents.
+ comp_sp = [_to_compactor_sampling(sp) for sp in sps]
+ seq_c = [_to_sequence_compression(c) for c in comps]
+ batch_c = _batch_compression_from_comps(comps)
+ comp_in = [_prompt_to_compactor_input(p) for p in plist]
+
+ # Only the returned sequences are used; the first return value is ignored.
+ _, seqs = engine.generate(
+ comp_in,
+ sampling_params=comp_sp,
+ batch_compression_params=batch_c,
+ per_sequence_compression_params=seq_c,
+ return_sequences=True,
+ )
+
+ return _sequences_to_request_outputs(seqs, engine)
+
+
+def _sequences_to_request_outputs(seqs: list[Any], engine: LLMEngine) -> list[RequestOutput]:
+    """Wrap finished compactor sequences as vLLM ``RequestOutput`` objects.
+
+    Each sequence yields one request with a single completion decoded by
+    ``engine.tokenizer``; ``finish_reason`` is always ``"stop"`` and logprobs are
+    unset.
+    """
+    tokenizer = engine.tokenizer
+
+    def _render(idx: int, seq: Any) -> RequestOutput:
+        decoded = tokenizer.decode(seq.completion_token_ids, skip_special_tokens=True)
+        # If every emitted id is “special” (e.g. EOS / chat boundary), the stripped
+        # string is empty although tokens exist. Decode again with the special
+        # tokens kept so the answer is not blank and stays debuggable.
+        if not decoded.strip() and seq.completion_token_ids:
+            decoded = tokenizer.decode(
+                seq.completion_token_ids, skip_special_tokens=False
+            )
+        completion = CompletionOutput(
+            index=0,
+            text=decoded,
+            token_ids=list(seq.completion_token_ids),
+            cumulative_logprob=None,
+            logprobs=None,
+            finish_reason="stop",
+        )
+        return RequestOutput(
+            request_id=f"kvprune-{idx}",
+            prompt=None,
+            prompt_token_ids=list(seq.prompt_token_ids),
+            prompt_logprobs=None,
+            outputs=[completion],
+            finished=True,
+        )
+
+    return [_render(idx, seq) for idx, seq in enumerate(seqs)]
+
diff --git a/vllm/kvprune/integration/compression_params.py b/vllm/kvprune/integration/compression_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26511afb5445522fc759e8acfbe379f0ff59936
--- /dev/null
+++ b/vllm/kvprune/integration/compression_params.py
@@ -0,0 +1,52 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Per-request KV compression for :meth:`vllm.LLM.generate` (``compression=`` kwarg)."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+@dataclass
+class CompressionParams:
+    """Per-prompt compression intent for :meth:`vllm.LLM.generate`.
+
+    If **any** prompt in the batch has ``compression_ratio < 1.0``, the **whole** batch
+    is run on the compactor ``LLMEngine`` (same stack as standalone compactor-vllm:
+    ``PagedKVCache`` + pruning kernels). If all prompts have ``compression_ratio >= 1.0``,
+    the batch stays on standard vLLM when ``tensor_parallel_size == 1``; with
+    tensor parallelism the batch is still routed through the compactor collective
+    with compression disabled.
+
+    ``compression_method`` follows :mod:`vllm.kvprune.core.compression_bridge` aliases:
+    ``none``, ``criticaladakv``, ``compactor``, ``snapkv`` (ignored when
+    ``compression_ratio`` is effectively 1).
+
+    ``protected_*`` map to compactor :class:`~vllm.kvprune.compression.compression_config.SequenceCompressionParams`
+    (defaults match standalone compactor-vllm-style usage).
+
+    Normalization done in ``__post_init__``:
+
+    - ``compression_method`` is stripped and lower-cased; a falsy value becomes
+      ``"compactor"``.
+    - A ratio within ``1e-9`` of 1.0 is snapped to exactly ``1.0`` and the method
+      is forced to ``"none"``.
+
+    Raises:
+        ValueError: if ``compression_ratio`` is outside ``(0, 1]``, if the method
+            is not a known alias, or if the method is ``"none"`` while the ratio
+            requests compression.
+    """
+
+    compression_ratio: float = 1.0
+    compression_method: str = "compactor"
+    protected_first_tokens: int = 16
+    protected_last_tokens: int = 64
+
+    def __post_init__(self) -> None:
+        if not 0.0 < self.compression_ratio <= 1.0:
+            raise ValueError(
+                f"compression_ratio must be in (0, 1], got {self.compression_ratio}"
+            )
+        self.compression_method = (
+            self.compression_method or "compactor"
+        ).strip().lower()
+        # Imported lazily, as in the original, so importing this module does not
+        # pull in the compactor core.
+        from vllm.kvprune.core.compression_bridge import VALID_ALIASES_FOR_SAMPLING
+
+        if self.compression_method not in VALID_ALIASES_FOR_SAMPLING:
+            raise ValueError(
+                f"compression_method must be one of {sorted(VALID_ALIASES_FOR_SAMPLING)}, "
+                f"got {self.compression_method!r}"
+            )
+        if self.compression_ratio >= 1.0 - 1e-9:
+            # Snap to exactly 1.0 so code that tests ``compression_ratio < 1.0``
+            # agrees with the "none" method chosen here. Otherwise a ratio such
+            # as 0.9999999999 would request compression with no method.
+            self.compression_ratio = 1.0
+            self.compression_method = "none"
+        elif self.compression_method == "none":
+            raise ValueError(
+                "When compression_ratio < 1.0, compression_method cannot be 'none'."
+            )
diff --git a/vllm/kvprune/integration/config_adapter.py b/vllm/kvprune/integration/config_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ada90dacb5b429d0a60e77a8e058f0ce9559d1
--- /dev/null
+++ b/vllm/kvprune/integration/config_adapter.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Build :class:`vllm.kvprune.config.engine_config.LLMConfig` from :class:`VllmConfig`."""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+from vllm.config import VllmConfig
+from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def _attention_schedule_from_env() -> KvpruneAttentionSchedule:
+    """Resolve :class:`KvpruneAttentionSchedule` from env.
+
+    Primary (``VLLM_KVPRUNE_ATTENTION_SCHEDULE``), used whenever it is non-empty:
+
+    - ``fa_triton`` — FA prefill, Triton decode (default). Aliases: ``fa_prefill``,
+      ``default``.
+    - ``pdtriton`` — Triton prefill + Triton decode. Aliases: ``triton``,
+      ``triton_prefill``, ``compactor_prefill``, ``pd_triton``.
+    - ``pdfa`` — FA prefill + FA decode (KV stores still Triton). Aliases:
+      ``fa_full``, ``fa_both``.
+
+    Legacy (``VLLM_KVPRUNE_ATTENTION_BACKEND``), consulted only when the primary
+    variable is unset or empty: ``flash``/``fa`` → ``fa_triton``,
+    ``compactor``/``triton`` → ``pdtriton``.
+
+    With both variables unset the result is ``FA_PREFILL_TRITON_DECODE``. Unknown
+    values log a warning and fall back to the same default.
+    """
+    default = KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+
+    s = os.environ.get("VLLM_KVPRUNE_ATTENTION_SCHEDULE", "").strip().lower()
+    if s:
+        if s in ("fa_triton", "fa_prefill", "default"):
+            return default
+        if s in ("pdtriton", "pd_triton", "triton", "triton_prefill", "compactor_prefill"):
+            return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+        if s in ("pdfa", "fa_full", "fa_both"):
+            return KvpruneAttentionSchedule.PDFA
+        logger.warning(
+            "Unknown VLLM_KVPRUNE_ATTENTION_SCHEDULE=%r; using FA_PREFILL_TRITON_DECODE",
+            s,
+        )
+        return default
+
+    # The primary variable is unset, so fall back to the legacy one. Previously
+    # the empty string matched the first alias tuple and this code never ran.
+    v = os.environ.get("VLLM_KVPRUNE_ATTENTION_BACKEND", "").strip().lower()
+    if not v:
+        # Keep the documented default when nothing is configured.
+        return default
+    if v in ("flash", "fa", "flash_attention", "flashattention"):
+        return default
+    if v in ("compactor", "triton", "compactor_triton"):
+        return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+    logger.warning(
+        "Unknown VLLM_KVPRUNE_ATTENTION_BACKEND=%r; using FA_PREFILL_TRITON_DECODE", v
+    )
+    return default
+
+
+def _compactor_kvcache_page_size(vllm_block_size: int | None) -> int:
+    """Tokens per physical KV page for compactor :class:`LLMConfig`.
+
+    ``compactor-vllm`` uses ``kvcache_page_size=128`` by default. Keeping that page
+    size is important for correctness comparisons when validating the integrated
+    ``kvprune`` backend against standalone compactor, especially for ``pdtriton``
+    where paged-KV layout and page-padding behavior are part of the observed
+    divergence on DCU.
+
+    Override with ``VLLM_KVPRUNE_PAGE_SIZE``:
+
+    - unset: use standalone-compactor-compatible ``128``
+    - positive integer: use that exact page size (must be divisible by 32)
+    - ``vllm`` / ``inherit`` / ``block``: derive from vLLM ``block_size`` and round up
+      to the next multiple of 32 (the older integrated behavior); a missing or
+      non-positive ``block_size`` yields ``128``
+
+    Raises:
+        ValueError: if the override is neither a keyword above nor a positive
+            multiple of 32.
+    """
+    env_v = os.environ.get("VLLM_KVPRUNE_PAGE_SIZE", "").strip().lower()
+    if not env_v:
+        return 128
+
+    if env_v in ("vllm", "inherit", "block"):
+        bs = 128 if vllm_block_size is None else int(vllm_block_size)
+        if bs <= 0:
+            return 128
+        # Round up to the next multiple of 32 (unchanged when already aligned).
+        return ((bs + 31) // 32) * 32
+
+    try:
+        page_size = int(env_v)
+    except ValueError:
+        # Name the variable and the accepted values instead of surfacing the
+        # bare "invalid literal for int()" message.
+        raise ValueError(
+            "VLLM_KVPRUNE_PAGE_SIZE must be a positive multiple of 32 or one of "
+            f"'vllm'/'inherit'/'block', got {env_v!r}."
+        ) from None
+    if page_size <= 0 or page_size % 32 != 0:
+        raise ValueError(
+            "VLLM_KVPRUNE_PAGE_SIZE must be a positive multiple of 32, "
+            f"got {page_size}."
+        )
+    return page_size
+
+
+def vllm_config_to_llm_config(vc: VllmConfig) -> LLMConfig:
+    """Translate a vLLM ``VllmConfig`` into the compactor's :class:`LLMConfig`.
+
+    Model name, length limit, memory utilization, TP size and HF config are taken
+    from ``vc``. The page size and attention schedule come from the environment
+    helpers in this module. EOS fields are left as placeholders (``eos=-1``,
+    ``eos_token_ids=None``).
+    """
+    model_cfg = vc.model_config
+    cache_cfg = vc.cache_config
+    parallel_cfg = vc.parallel_config
+    sched_cfg = vc.scheduler_config
+
+    vllm_block = cache_cfg.block_size
+    if vllm_block is None:
+        vllm_block = 16
+    seq_cap = getattr(sched_cfg, "max_num_seqs", 256)
+
+    # Local checkpoint directory: forward so compactor skips redundant Hub fetches.
+    model_name = str(model_cfg.model)
+    local_dir: str | None = None
+    if model_name:
+        try:
+            candidate = Path(model_name)
+            if candidate.is_dir() and (candidate / "config.json").is_file():
+                local_dir = str(candidate.resolve())
+        except OSError:
+            # Best effort: a path that cannot be inspected is treated as not local.
+            pass
+
+    kv_page = _compactor_kvcache_page_size(vllm_block)
+    schedule = _attention_schedule_from_env()
+    logger.info(
+        "kvprune compactor config: attention_schedule=%s, kvcache_page_size=%d",
+        schedule.name,
+        kv_page,
+    )
+
+    llm_kwargs = dict(
+        model=model_name,
+        path=local_dir,
+        nccl_port=1218,
+        max_num_seqs=seq_cap,
+        max_model_len=model_cfg.max_model_len,
+        gpu_memory_utilization=cache_cfg.gpu_memory_utilization,
+        tensor_parallel_size=parallel_cfg.tensor_parallel_size,
+        # Do **not** forward ``model_config.enforce_eager`` (v1) here. The two
+        # flags are independent: v1 uses its flag only to skip *v1*
+        # ``capture_model()``, while :attr:`LLMConfig.enforce_eager` governs only
+        # *compactor* decode CUDA graphs. Callers adjust it afterwards through
+        # ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` /
+        # ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``.
+        enforce_eager=False,
+        hf_config=model_cfg.hf_config,
+        eos=-1,
+        eos_token_ids=None,
+        kvcache_page_size=kv_page,
+        leverage_sketch_size=48,
+        attention_schedule=schedule,
+        attention_backend=None,
+    )
+    return LLMConfig(**llm_kwargs)
diff --git a/vllm/kvprune/integration/v1_tp_runner.py b/vllm/kvprune/integration/v1_tp_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc8fc6b5245f9316e574133fbe4f5f91533e02d
--- /dev/null
+++ b/vllm/kvprune/integration/v1_tp_runner.py
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""TP>1: one kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner` per vLLM worker.
+
+Invoked via v1 ``collective_rpc("kvprune_v1_compressed_generate", ...)`` so every tensor-
+parallel rank participates in the same compactor forward/broadcast sequence as the
+standalone multi-process compactor.
+
+Compactor decode CUDA graphs (when not ``enforce_eager``) capture the full decode step
+including ``compute_logits``. To force eager on embedded TP workers, set
+``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` or ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``.
+
+Peer/master session boundaries use TP-group ``broadcast``/``barrier`` (see
+``ModelRunner.maybe_release_peers``), not ``multiprocessing.Event`` — RPC payloads must
+be picklable across worker processes.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
+from vllm.kvprune.core.compression_bridge import (
+ compression_method_id_to_enum,
+ compression_method_str_to_id,
+)
+from vllm.kvprune.core.model_runner import ModelRunner
+from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
+from vllm.kvprune.utils.kv_dist import barrier_sync
+from vllm.kvprune.integration.weight_tie import (
+ delegate_kvprune_compute_logits_to_vllm,
+ delegate_kvprune_embed_tokens_to_vllm,
+ tie_kvprune_rope_buffers_from_vllm,
+ tie_kvprune_weights_from_vllm,
+)
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.kvprune.utils.sequence import Sequence
+
+_ATTR = "_kvprune_tp_embedded_runner"
+
+
+def _apply_compactor_env_overrides(cfg: Any) -> None:
+    """Apply the compactor environment overrides to ``cfg`` in place.
+
+    Mirrors the caps used by
+    :func:`~vllm.kvprune.integration.compactor_shared.create_compactor_engine_with_shared_weights`:
+
+    - ``VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS`` (default ``"32"``): a positive value
+      lowers ``cfg.max_num_seqs`` to at most that number.
+    - ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``: ``1/true/yes`` or ``0/false/no``
+      sets ``cfg.enforce_eager`` directly.
+    - Otherwise ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` (default ``"1"``) decides:
+      eager is enforced only when it is ``0/false/no``.
+    """
+    affirmative = ("1", "true", "yes")
+    negative = ("0", "false", "no")
+
+    cap_text = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
+    if cap_text:
+        cap = int(cap_text)
+        if cap > 0:
+            cfg.max_num_seqs = min(cfg.max_num_seqs, cap)
+
+    eager_text = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
+    if eager_text in affirmative:
+        cfg.enforce_eager = True
+        return
+    if eager_text in negative:
+        cfg.enforce_eager = False
+        return
+
+    # No explicit eager setting: derive it from the CUDA-graph switch.
+    graph_text = os.environ.get("VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1").strip().lower()
+    cfg.enforce_eager = graph_text in negative
+
+
+def _build_sequences(payload: dict[str, Any]) -> list[Sequence]:
+    """Rebuild compactor ``Sequence`` objects from the plain-value RPC payload.
+
+    The payload lists ``"prompt_token_ids"``, ``"sampling_params"`` and
+    ``"compression_params"`` are read in parallel by index. When the protected
+    head and tail together cover the whole prompt, the sequence's
+    ``compression_ratio`` is reset to ``1.0``.
+    """
+    all_prompt_ids: list[list[int]] = payload["prompt_token_ids"]
+    sampling_dicts: list[dict[str, Any]] = payload["sampling_params"]
+    compression_dicts: list[dict[str, Any]] = payload["compression_params"]
+
+    def _one(idx: int, ids: list[int]) -> Sequence:
+        sampling_d = sampling_dicts[idx]
+        sampling = CompactorSamplingParams(
+            temperature=float(sampling_d["temperature"]),
+            max_new_tokens=int(sampling_d["max_new_tokens"]),
+        )
+        compression_d = compression_dicts[idx]
+        compression = SequenceCompressionParams(
+            compression_ratio=float(compression_d["compression_ratio"]),
+            protected_first_tokens=int(compression_d.get("protected_first_tokens", 16)),
+            protected_last_tokens=int(compression_d.get("protected_last_tokens", 64)),
+        )
+        protected = compression.protected_first_tokens + compression.protected_last_tokens
+        if protected >= len(ids):
+            # Every token is protected, so there is nothing left to prune.
+            compression.compression_ratio = 1.0
+        return Sequence(
+            prompt_token_ids=list(ids),
+            sampling_params=sampling,
+            compression_params=compression,
+        )
+
+    return [_one(idx, ids) for idx, ids in enumerate(all_prompt_ids)]
+
+
+def _batch_compression_from_payload(payload: dict[str, Any]) -> BatchCompressionParams:
+    """Pick the batch-wide compression method from the RPC payload.
+
+    The first entry of ``payload["compression_params"]`` whose ratio is below
+    ``1.0`` decides the method (``"none"`` if it names none). When no entry
+    compresses, default ``BatchCompressionParams`` are returned.
+    """
+    first_compressing = next(
+        (
+            entry
+            for entry in payload["compression_params"]
+            if float(entry["compression_ratio"]) < 1.0
+        ),
+        None,
+    )
+    if first_compressing is None:
+        return BatchCompressionParams()
+
+    method_name = str(first_compressing.get("compression_method", "none"))
+    method_id = compression_method_str_to_id(method_name)
+    return BatchCompressionParams(
+        compression_method=compression_method_id_to_enum(method_id)
+    )
+
+
+def _get_or_create_runner(worker: Any, payload: dict[str, Any]) -> ModelRunner:
+ """Return this worker's embedded kvprune ``ModelRunner``, creating it once.
+
+ The runner is cached on ``worker`` under ``_ATTR``. Once it exists it is
+ returned as is and ``payload`` is not consulted again, so the EOS ids of the
+ first call stay in effect.
+
+ Raises:
+ NotImplementedError: if pipeline or data parallelism is enabled.
+ RuntimeError: if the live TP world size differs from the configured one.
+ ValueError: if the HF ``model_type`` has no entry in ``MODEL_REGISTRY``.
+ IndexError: if ``payload["eos_token_ids"]`` is empty (its first sorted
+ element is used as ``cfg.eos``).
+ """
+ existing = getattr(worker, _ATTR, None)
+ if existing is not None:
+ return existing
+
+ from vllm.distributed.parallel_state import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
+
+ vc = worker.vllm_config
+ pc = vc.parallel_config
+ if pc.pipeline_parallel_size != 1 or pc.data_parallel_size != 1:
+ raise NotImplementedError(
+ "KV-prune TP compressed generate requires pipeline_parallel_size=1 and "
+ f"data_parallel_size=1; got PP={pc.pipeline_parallel_size}, "
+ f"DP={pc.data_parallel_size}."
+ )
+
+ # Sanity check: the distributed state must agree with the configuration.
+ tp_ws = get_tensor_model_parallel_world_size()
+ if tp_ws != pc.tensor_parallel_size:
+ raise RuntimeError(
+ f"parallel_state TP world size {tp_ws} != config.tensor_parallel_size "
+ f"{pc.tensor_parallel_size}"
+ )
+
+ hf = vc.model_config.hf_config
+ model_type = getattr(hf, "model_type", None)
+ if model_type not in MODEL_REGISTRY:
+ raise ValueError(
+ f"KV-prune TP path: unsupported model_type={model_type!r}; "
+ f"registry has {sorted(MODEL_REGISTRY)}"
+ )
+
+ cfg = vllm_config_to_llm_config(vc)
+ # Replace the placeholder EOS fields with the ids computed on the driver:
+ # de-duplicated, sorted, and with the smallest id used as ``cfg.eos``.
+ eos_ids = payload["eos_token_ids"]
+ cfg.eos_token_ids = sorted({int(x) for x in eos_ids})
+ cfg.eos = int(cfg.eos_token_ids[0])
+ _apply_compactor_env_overrides(cfg)
+
+ # Build the kvprune model, then point its parameters at the vLLM storage.
+ vllm_model = worker.get_model()
+ kv_model: nn.Module = MODEL_REGISTRY[model_type](hf)
+ tie_kvprune_weights_from_vllm(vllm_model, kv_model)
+
+ # Device and dtype are taken from the first vLLM parameter. The RoPE buffers
+ # are copied only after ``.to(...)``, and the embedding and logits paths are
+ # then delegated to vLLM.
+ dev = next(vllm_model.parameters()).device
+ dtype = next(vllm_model.parameters()).dtype
+ kv_model.to(device=dev, dtype=dtype)
+ tie_kvprune_rope_buffers_from_vllm(vllm_model, kv_model)
+ delegate_kvprune_embed_tokens_to_vllm(vllm_model, kv_model)
+ delegate_kvprune_compute_logits_to_vllm(vllm_model, kv_model)
+
+ tp_rank = get_tensor_model_parallel_rank()
+ # NOTE(review): the runner device comes from the current CUDA device, not from
+ # ``dev`` above — presumably they coincide on a vLLM worker; confirm.
+ device = torch.device(f"cuda:{torch.cuda.current_device()}")
+
+ # NOTE(review): rank 0 is built with ``peer_events=[]`` and the other ranks
+ # with ``batch_ready=None``; the meaning of these arguments is defined by
+ # ``ModelRunner``, which is not visible here.
+ if tp_rank == 0:
+ runner = ModelRunner(
+ cfg,
+ rank=0,
+ peer_events=[],
+ external_model=kv_model,
+ embedded_in_vllm_worker=True,
+ device=device,
+ )
+ else:
+ runner = ModelRunner(
+ cfg,
+ rank=tp_rank,
+ batch_ready=None,
+ external_model=kv_model,
+ embedded_in_vllm_worker=True,
+ device=device,
+ )
+
+ # Cache on the worker so later sessions reuse the same runner.
+ setattr(worker, _ATTR, runner)
+ return runner
+
+
+def run_kvprune_tp_compressed_generate(worker: Any, payload: dict[str, Any]) -> dict[str, Any]:
+    """Execute one compressed generation session on this worker (all TP ranks).
+
+    Every rank builds the runner and the sequences, then waits on a TP-group
+    barrier. Rank 0 drives generation and returns the prompt and completion
+    token ids; the other ranks run a peer session and return an acknowledgement.
+    """
+    from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
+
+    rank = get_tensor_model_parallel_rank()
+    runner = _get_or_create_runner(worker, payload)
+    seqs = _build_sequences(payload)
+    batch_params = _batch_compression_from_payload(payload)
+
+    # All ranks finish their setup before any of them starts the session.
+    barrier_sync(use_tp_group=True)
+
+    if rank != 0:
+        runner.run_peer_session()
+        return {"tensor_parallel_rank": int(rank), "ok": True}
+
+    runner.generate(seqs, batch_params)
+    prompt_ids = [list(seq.prompt_token_ids) for seq in seqs]
+    completion_ids = [list(seq.completion_token_ids) for seq in seqs]
+    return {
+        "tensor_parallel_rank": 0,
+        "prompt_token_ids": prompt_ids,
+        "completion_token_ids": completion_ids,
+    }
diff --git a/vllm/kvprune/integration/vllm_model_access.py b/vllm/kvprune/integration/vllm_model_access.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b23c91f958c5376061c4bb499e1868491350f1c
--- /dev/null
+++ b/vllm/kvprune/integration/vllm_model_access.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Access the in-process vLLM model weights for compactor weight sharing."""
+
+from __future__ import annotations
+
+import torch.nn as nn
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def extract_vllm_causal_lm(llm: object) -> nn.Module:
+    """Return the root ``nn.Module`` holding transformer + lm_head from a v1 ``LLM``.
+
+    Walks ``llm.llm_engine.model_executor.driver_worker.worker`` and calls its
+    ``get_model()``. This needs an in-process ``model_executor``, i.e. an
+    ``LLMEngine`` built with ``multiprocess_mode=False`` (set
+    ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
+
+    Raises:
+        RuntimeError: if any link of the chain is missing or ``get_model`` is
+            not callable.
+    """
+    # Attribute to follow at each hop, with the error raised when it is absent.
+    hops = (
+        (
+            "llm_engine",
+            "Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``).",
+        ),
+        (
+            "model_executor",
+            "model_executor is unavailable (multiprocess engine mode). "
+            "Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
+            "in-process weight sharing.",
+        ),
+        (
+            "driver_worker",
+            "Executor has no driver_worker (unexpected executor type for weight sharing).",
+        ),
+        ("worker", "Worker wrapper has no worker loaded."),
+    )
+    node = llm
+    for attr, message in hops:
+        node = getattr(node, attr, None)
+        if node is None:
+            raise RuntimeError(message)
+
+    model_getter = getattr(node, "get_model", None)
+    if not callable(model_getter):
+        raise RuntimeError("Worker does not expose get_model().")
+    return model_getter()
diff --git a/vllm/kvprune/integration/weight_tie.py b/vllm/kvprune/integration/weight_tie.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d2356e763dd970e8d0002ae5d7df205d56e5f13
--- /dev/null
+++ b/vllm/kvprune/integration/weight_tie.py
@@ -0,0 +1,192 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Share vLLM parameter storage with compactor ``MODEL_REGISTRY`` models (TP=1 engine and per-rank TP>1 workers)."""
+
+from __future__ import annotations
+
+import types
+
+import torch
+import torch.nn as nn
+
+from vllm.kvprune.utils.context import get_context
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def tie_kvprune_weights_from_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+    *,
+    strict: bool = True,
+) -> int:
+    """Alias compactor parameters onto the vLLM tensors of the same name.
+
+    For every kvprune parameter whose name also exists in ``vllm_model``, the
+    parameter's ``.data`` is rebound to the vLLM parameter's ``.data``, so both
+    models use the same storage.
+
+    Args:
+        vllm_model: Model returned by ``worker.get_model()`` (e.g. ``Qwen3ForCausalLM``).
+        kvprune_model: Instance from ``vllm.kvprune.models.MODEL_REGISTRY``.
+        strict: If True, raise when a kvprune parameter name is missing from
+            ``vllm_model``; if False, such parameters are skipped.
+
+    Returns:
+        The number of parameters tied.
+
+    Raises:
+        ValueError: on a missing name in strict mode, on any shape mismatch
+            (regardless of ``strict``), or when nothing could be tied.
+    """
+    vllm_params = dict(vllm_model.named_parameters())
+    kv_params = dict(kvprune_model.named_parameters())
+
+    tied = 0
+    for name, kv_param in kv_params.items():
+        source = vllm_params.get(name)
+        if source is None:
+            if not strict:
+                continue
+            raise ValueError(
+                f"kvprune parameter {name!r} not found in vLLM model; "
+                "architecture/layout may differ (disable strict tying only "
+                "for expert debugging)."
+            )
+        if source.shape != kv_param.shape:
+            raise ValueError(
+                f"Shape mismatch for {name}: vllm {source.shape} vs kvprune {kv_param.shape}"
+            )
+        # Rebind the storage rather than copy, so both models see one tensor.
+        kv_param.data = source.data
+        tied += 1
+
+    if not tied:
+        raise ValueError(
+            "No parameters were tied — check that vLLM and kvprune model types match "
+            "and use the same state_dict names."
+        )
+    logger.info("Tied %d parameters from vLLM into compactor model (shared storage).", tied)
+    return tied
+
+
+def tie_kvprune_rope_buffers_from_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+) -> int:
+    """Copy RoPE ``cos_sin_cache`` buffers from vLLM into kvprune.
+
+    :func:`tie_kvprune_weights_from_vllm` only aliases :class:`~torch.nn.Parameter`
+    tensors, while RoPE tables live in buffers. kvprune's simplified
+    ``RotaryEmbedding`` can disagree with vLLM's ``rope_parameters`` (YaRN, etc.),
+    so the vLLM table is copied in after ``.to(device, dtype)`` to keep the Q/K
+    rotation aligned with the main model.
+
+    Two layouts are accepted for a buffer of the same name: identical shapes, or
+    kvprune ``[max_len, 1, rotary_dim]`` against vLLM ``[max_len, rotary_dim]``
+    (copied through ``unsqueeze(1)``). A kvprune buffer with no vLLM counterpart
+    is left untouched with a warning.
+
+    Returns:
+        The number of buffers copied.
+
+    Raises:
+        ValueError: if the two shapes match neither accepted layout.
+    """
+    vllm_buffers = dict(vllm_model.named_buffers())
+    copied = 0
+    for name, kv_buf in kvprune_model.named_buffers():
+        if "cos_sin_cache" not in name:
+            continue
+        if name not in vllm_buffers:
+            logger.warning(
+                "kvprune RoPE buffer %r not found in vLLM; leaving kvprune cache",
+                name,
+            )
+            continue
+
+        vllm_buf = vllm_buffers[name]
+        if vllm_buf.shape == kv_buf.shape:
+            source = vllm_buf
+        elif kv_buf.dim() == 3 and vllm_buf.dim() == 2:
+            # kvprune carries an extra singleton dimension in the middle.
+            expected = (vllm_buf.shape[0], 1, vllm_buf.shape[1])
+            if tuple(kv_buf.shape) != expected:
+                raise ValueError(
+                    f"cos_sin_cache shape mismatch for {name!r}: "
+                    f"vLLM {tuple(vllm_buf.shape)} vs kvprune {tuple(kv_buf.shape)}"
+                )
+            source = vllm_buf.unsqueeze(1)
+        else:
+            raise ValueError(
+                f"Unsupported cos_sin_cache layout for {name!r}: "
+                f"vLLM {tuple(vllm_buf.shape)} vs kvprune {tuple(kv_buf.shape)}"
+            )
+        kv_buf.copy_(source)
+        copied += 1
+
+    if copied:
+        logger.info(
+            "Copied %d RoPE cos_sin_cache buffer(s) from vLLM into kvprune model.",
+            copied,
+        )
+    return copied
+
+
+def delegate_kvprune_embed_tokens_to_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+) -> bool:
+    """Use vLLM's ``model.embed_tokens`` forward for kvprune (TP-safe token→shard mapping).
+
+    Even with tied weights, kvprune's simplified contiguous
+    ``VocabParallelEmbedding`` (``vocab_start = rank * partition``) can disagree
+    with vLLM's padded vocabulary and org/added shard ranges. That produces
+    invalid indices for ``F.embedding`` on non-zero TP ranks (``index_copy_`` /
+    device-side assert). Delegating the forward to vLLM's embedding module keeps
+    masks and indices aligned with the main model while the parameters remain
+    shared storage.
+
+    Returns:
+        ``True`` when the forward was replaced. ``False`` when either model has
+        no ``model`` attribute (silently) or no ``model.embed_tokens`` (with a
+        warning).
+    """
+    if not (hasattr(vllm_model, "model") and hasattr(kvprune_model, "model")):
+        return False
+
+    vllm_embed = getattr(vllm_model.model, "embed_tokens", None)
+    kv_embed = getattr(kvprune_model.model, "embed_tokens", None)
+    if vllm_embed is None or kv_embed is None:
+        logger.warning(
+            "delegate_kvprune_embed_tokens_to_vllm: embed_tokens missing; skipped"
+        )
+        return False
+
+    def _forward(_self_unused: nn.Module, x):
+        # The bound kvprune module is ignored; vLLM's embedding does the lookup.
+        return vllm_embed(x)
+
+    kv_embed.forward = types.MethodType(_forward, kv_embed)
+    logger.info(
+        "kvprune model.embed_tokens forward delegated to vLLM (correct vocab-parallel masks)."
+    )
+    return True
+
+
+def delegate_kvprune_compute_logits_to_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+) -> bool:
+    """Route ``kvprune_model.compute_logits`` through vLLM's ``compute_logits``.
+
+    Standalone compactor used :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`
+    with ``F.linear`` + TP gather. vLLM applies :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
+    (gather/all-gather, padded-vocab trim, quant hooks). A mismatch here commonly
+    produces garbage token distributions while the rest of the stack looks fine.
+
+    After weight tying, ``vllm_model.compute_logits(hidden)`` uses the same
+    lm_head storage as kvprune; only the *application* path matches production
+    vLLM.
+
+    Returns:
+        ``True`` when the method was replaced, ``False`` (with a warning) when
+        ``vllm_model`` has no callable ``compute_logits``.
+    """
+    vllm_compute = getattr(vllm_model, "compute_logits", None)
+    if not callable(vllm_compute):
+        logger.warning(
+            "delegate_kvprune_compute_logits_to_vllm: vLLM model has no compute_logits; skipped"
+        )
+        return False
+
+    def _compute_logits(_self: nn.Module, hidden_states):
+        # Match kvprune :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`:
+        # prefill logits are for the **last** token of each packed sequence only.
+        ctx = get_context()
+        if ctx.is_prefill and ctx.cu_seqlens_q is not None:
+            last_rows = (ctx.cu_seqlens_q[1:] - 1).to(torch.long)
+            total_tokens = hidden_states.shape[0]
+            if total_tokens > 0:
+                last_rows = last_rows.clamp(min=0, max=total_tokens - 1)
+            hidden_states = hidden_states[last_rows]
+        # vLLM lm_head + gather expect contiguous activations; non-contiguous
+        # views have caused garbage logits under TP in edge cases.
+        return vllm_model.compute_logits(hidden_states.contiguous())
+
+    kvprune_model.compute_logits = types.MethodType(_compute_logits, kvprune_model)
+    return True
diff --git a/vllm/kvprune/kv_cache/__init__.py b/vllm/kvprune/kv_cache/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ddb214b7e6e1b22d8a1e2892643accb28d38d3
--- /dev/null
+++ b/vllm/kvprune/kv_cache/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Paged KV cache helpers and Triton KV store."""
+
+from vllm.kvprune.kv_cache.store_kv_cache import (
+ decode_store_kv,
+ prefill_store_all_kv,
+ prefill_store_topk_kv,
+)
+
+__all__ = [
+ "decode_store_kv",
+ "prefill_store_all_kv",
+ "prefill_store_topk_kv",
+]
diff --git a/vllm/kvprune/kv_cache/page_table.py b/vllm/kvprune/kv_cache/page_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fbf00b43b2be3b6fa0cbb161cb6706cc194b0bc
--- /dev/null
+++ b/vllm/kvprune/kv_cache/page_table.py
@@ -0,0 +1,313 @@
+import heapq
+import logging
+from enum import Enum, auto
+from typing import List, Optional, Union
+
+import torch
+from vllm.kvprune.config.constants import RESERVED_BATCH
+from vllm.kvprune.kv_cache.write_page_table import scatter_to_page_table
+
+logger = logging.getLogger(__name__)
+
+
def cdiv(a, b):
    """Return the ceiling of ``a / b`` using only integer floor division.

    Intended for a positive divisor ``b``. ``a`` may be anything that
    supports ``+``, ``-`` and ``//`` with ``b`` (plain ints or tensors).
    """
    biased_numerator = a + b - 1
    return biased_numerator // b
+
+
def next_multiple(a, b):
    """Round ``a`` up to the nearest multiple of ``b`` (``b`` positive)."""
    # Ceil-divide first, then scale back up to a multiple of ``b``.
    quotient_rounded_up = (a + b - 1) // b
    return quotient_rounded_up * b
+
+
class KVAllocationStatus(Enum):
    """Result codes reported by KV-cache allocation requests."""

    # A (layer, head) would need more logical pages than one head may own.
    EXCEEDS_MAX_SEQUENCE_LENGTH = 1
    # Some layer does not currently have enough free physical pages.
    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = 2
    # NOTE(review): presumably "no free batch slot" -- not produced by any
    # code visible next to this enum; confirm against callers.
    EXCEEDS_MAX_NUM_BATCHES = 3
    # The request was satisfied.
    SUCCESS = 4
+
+
class PagedKVCache(torch.nn.Module):
    """
    Global paged KV cache.

    This module manages:
      * A global K/V backing buffer for all layers:
        ``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
        where the first dimension indexes K vs V.
      * A per-layer page table:
        ``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
        mapping logical (batch, kv-head, logical_page) to a physical page ID
        in the global K/V buffer.
      * Per-layer, per-(batch, kv-head) logical sequence lengths
        ``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
        the number of allocated pages ``bh_num_pages`` for each (layer, batch,
        head).
      * A page allocator implemented as a min-heap of free physical pages
        per layer, plus free batch indices.

    Pages are of fixed size ``page_size`` tokens.

    Args:
        :param num_layers:
            Number of transformer layers that will use this cache.
        :param max_logical_pages_per_head:
            Maximum number of logical pages that can be assigned to a single
            (batch, kv-head) pair.
        :param num_pages:
            Total number of physical pages available in the global cache per
            layer. The global K/V buffers are of length
            ``num_pages * page_size`` along the token dimension.
        :param page_size:
            Number of tokens stored per page.
        :param H_kv:
            Number of KV heads per layer.
        :param head_dim:
            Head dimension for K/V.
        :param max_num_batches:
            Maximum number of concurrent batches / sequences supported. One
            extra batch index is added on top of this value and reserved for
            internal use (``RESERVED_BATCH``).
        :param dtype:
            Data type of K/V entries (e.g. ``torch.float16`` or
            ``torch.bfloat16``).
        :param device:
            Device on which to allocate the cache (string, torch.device, or
            int; defaults to ``"cuda"``).
    """

    def __init__(
        self,
        num_layers: int,
        max_logical_pages_per_head: int,
        num_pages: int,
        page_size: int,  # tokens per page
        H_kv: int,
        head_dim: int,
        max_num_batches: int,
        dtype: torch.dtype,
        device: Union[str, torch.device, int] = "cuda",
    ):
        super().__init__()
        self.n_pages = num_pages
        self.num_layers = num_layers
        self.page_size: int = int(page_size)
        self.H_kv = int(H_kv)
        self.max_pages_per_head = max_logical_pages_per_head
        # One extra row so the reserved slot does not reduce the caller's budget.
        max_num_batches += 1
        self.max_num_batches = max_num_batches
        self.head_dim = head_dim
        cache_shape = (2, num_layers, num_pages * page_size, head_dim)
        self.kv_cache = torch.empty(cache_shape, dtype=dtype, device=device)

        # Uninitialised on purpose: only entries below bh_num_pages are valid.
        self.page_table = torch.empty(
            (num_layers, max_num_batches, H_kv, self.max_pages_per_head),
            device=device,
            dtype=torch.int32,
        )

        # Per-(layer, batch, head) logical sequence length in tokens.
        self.bh_seq_lens = torch.zeros(
            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
        )
        # Per-(layer, batch, head) number of logical pages currently allocated.
        self.bh_num_pages = torch.zeros(
            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
        )

        # Page allocator: one min-heap of free physical page ids per layer.
        self.free_pages: List[List[int]] = [
            list(range(num_pages)) for _ in range(num_layers)
        ]
        for free_pages in self.free_pages:
            heapq.heapify(free_pages)
        # Free batch slots, popped from the end; the reserved slot is never
        # handed out.
        self.free_batches: List[int] = list(reversed(range(max_num_batches)))
        self.free_batches.remove(RESERVED_BATCH)
        # Physical page ids owned by each batch, per layer (used for freeing).
        self.pages_indices_per_batch: List[List[set[int]]] = [
            [set() for _ in range(num_layers)] for _ in range(max_num_batches)
        ]

    def new_batch(self) -> Optional[int]:
        """
        Reserve a new batch slot.

        A batch slot corresponds to a row in ``bh_seq_lens`` /
        ``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
        heads. A slot is handed out only if a free batch index exists and
        every layer still has at least ``H_kv`` free pages.

        Returns:
            :return Optional[int]:
                Newly reserved batch index (removed from ``free_batches``),
                or ``None`` if no capacity is available.
        """
        if self.free_batches and all(
            self.H_kv <= len(fp) for fp in self.free_pages
        ):
            return self.free_batches.pop()
        return None

    def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
        """
        Ensure enough pages are allocated to handle ``add_tokens`` new tokens.

        Args:
            :param batch_index:
                Batch index to reserve space for.
            :param add_tokens:
                Number of additional tokens to reserve capacity for.
                All heads in this batch and all layers reserve
                the same number of extra tokens.

        Returns:
            :return KVAllocationStatus:
                ``SUCCESS`` if capacity is (now) sufficient;
                ``EXCEEDS_MAX_SEQUENCE_LENGTH`` if some head would need more
                than ``max_pages_per_head`` pages;
                ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` if some layer lacks
                free pages. Nothing is modified unless ``SUCCESS`` is
                returned.
        """
        cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
        curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
        curr_cap_tokens = curr_pages * self.page_size  # [L, H]
        need_tokens = cur_bh_lens + add_tokens  # [L, H]
        if (need_tokens <= curr_cap_tokens).all():
            return KVAllocationStatus.SUCCESS
        # Heads can have different lengths, so a head may already hold a full
        # page (or more) of slack while another head needs pages. Clamp the
        # per-head shortfall at zero: a negative page count would under-count
        # the pages taken from the heap, shrink bh_num_pages, and misalign the
        # prefix-sum offsets used when scattering the new page ids.
        missing_tokens = (need_tokens - curr_cap_tokens).clamp(min=0)
        add_pages = cdiv(missing_tokens, self.page_size)
        new_total_pages = curr_pages + add_pages
        if (new_total_pages > self.max_pages_per_head).any():
            return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH

        # CPU-side allocation. Check every layer first so a failure leaves
        # the heaps untouched.
        pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
        for layer_index in range(self.num_layers):
            if pages_per_layer_cpu[layer_index] > len(self.free_pages[layer_index]):
                return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES

        new_phys_pages = []
        for layer_index in range(self.num_layers):
            this_layer_pages = [
                heapq.heappop(self.free_pages[layer_index])
                for _ in range(pages_per_layer_cpu[layer_index])
            ]
            self.pages_indices_per_batch[batch_index][layer_index] |= set(
                this_layer_pages
            )
            new_phys_pages.extend(this_layer_pages)

        # Allocate on the cache's own device rather than a hard-coded "cuda",
        # so caches created on e.g. "cuda:1" keep all operands together.
        new_phys_pages = torch.tensor(
            new_phys_pages, dtype=torch.int32, device=self.page_table.device
        )

        scatter_to_page_table(
            add_pages=add_pages,
            new_phys_pages=new_phys_pages,
            curr_pages=curr_pages,
            page_table=self.page_table[:, batch_index],
            max_pages_per_head=self.max_pages_per_head,
        )

        self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
            self.bh_num_pages.dtype
        )
        return KVAllocationStatus.SUCCESS

    def reclaim_pages(
        self,
        batch_index: int,
        future_reserve_tokens: int = 0,
    ) -> int:
        """
        Reclaim unused pages for a single batch index. This shrinks the KV
        allocation for the batch down to the minimum number of pages needed
        to hold the current (plus optional future) sequence length.

        Args:
            :param batch_index:
                Batch index whose pages should be compacted.
            :param future_reserve_tokens:
                Optional number of extra tokens to keep capacity for, beyond
                the current sequence length. This can reduce churn when
                sequences are expected to grow slightly in the near future.

        Returns:
            :return int:
                Approximate number of bytes freed across both K and V
                (``0`` when nothing could be reclaimed).
        """
        device = self.bh_seq_lens.device
        _, B, H = self.bh_seq_lens.shape
        assert 0 <= batch_index < B

        seq = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens  # [L, H]
        alloc = self.bh_num_pages[:, batch_index, :]  # [L, H]
        pt = self.page_table[:, batch_index, :, :].reshape(-1)  # [L * H * P]

        # Pages actually needed: ceil_div(seq, page_size), capped at alloc.
        used_pages = cdiv(seq, self.page_size)
        used_pages = torch.minimum(used_pages, alloc)

        # Logical page indices [0..P-1], broadcast over [L, H, P].
        p = torch.arange(
            self.max_pages_per_head, device=device, dtype=torch.int32
        ).view(1, 1, self.max_pages_per_head)

        # Allocated slots: p < alloc.
        alloc_mask = p < alloc.unsqueeze(-1)  # [L, H, P]
        # Slots to free: allocated and p in [used_pages, alloc).
        free_mask = alloc_mask & (p >= used_pages.unsqueeze(-1))
        free_mask_flat = free_mask.view(-1)  # [L * H * P]
        if not free_mask_flat.any():
            return 0

        # Flat indices of the freed page-table slots.
        idx = free_mask_flat.nonzero(as_tuple=False).squeeze(-1)

        # Physical page ids stored in those slots.
        freed_pages = pt[idx]
        # Layout is [L, H, P] -> flat index = ((l * H) + h) * P + p, so the
        # layer of a slot is flat_index // (H * P).
        freed_layers = (idx // (H * self.max_pages_per_head)).to(torch.int32)
        freed_pages = freed_pages.tolist()
        layer_mapping = freed_layers.tolist()
        self.bh_num_pages[:, batch_index, :] = used_pages
        for page, layer in zip(freed_pages, layer_mapping):
            self.pages_indices_per_batch[batch_index][layer].remove(page)
            heapq.heappush(self.free_pages[layer], page)
        # Factor 2 accounts for K and V.
        approximate_bytes_freed = (
            len(freed_pages)
            * (self.page_size * self.head_dim * self.kv_cache.element_size())
            * 2
        )
        return approximate_bytes_freed

    def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
        """
        Return every page owned by ``batch_index`` in ``layer_index`` to that
        layer's free heap and clear the ownership record.
        """
        for phys in self.pages_indices_per_batch[batch_index][layer_index]:
            heapq.heappush(self.free_pages[layer_index], int(phys))

        self.pages_indices_per_batch[batch_index][layer_index] = set()

    def free_batch(self, batch_index: int) -> None:
        """
        Free all resources associated with a batch index.

        Pages of every layer go back to the free heaps, the per-head lengths
        and page counts of the batch are zeroed, and the index becomes
        available to :meth:`new_batch` again.

        Args:
            :param batch_index:
                Batch index to release. Must have been previously allocated
                via :meth:`new_batch`.
        """
        for layer in range(self.num_layers):
            self._free_batch_layer(layer, batch_index)
        self.bh_seq_lens[:, batch_index].zero_()
        self.bh_num_pages[:, batch_index].zero_()
        self.free_batches.append(batch_index)

    def layer_slices(
        self, layer: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Return layer-local views needed by the attention module.

        For a given ``layer`` index, this method returns the slices of the
        global K/V cache, page table, and per-(batch, head) sequence lengths
        corresponding to that layer.

        Args:
            :param layer:
                Layer index ``l`` in ``[0, num_layers)``.

        Returns:
            :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                ``(k, v, pt, bh)``: ``kv_cache[0, layer]``,
                ``kv_cache[1, layer]``, ``page_table[layer]`` and
                ``bh_seq_lens[layer]``.
        """
        assert 0 <= layer < self.num_layers
        k = self.kv_cache[0, layer]
        v = self.kv_cache[1, layer]
        pt = self.page_table[layer]
        bh = self.bh_seq_lens[layer]
        return k, v, pt, bh
diff --git a/vllm/kvprune/kv_cache/store_kv_cache.py b/vllm/kvprune/kv_cache/store_kv_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..dda3737fb7ae3e3c4124da3b30270a6275adf080
--- /dev/null
+++ b/vllm/kvprune/kv_cache/store_kv_cache.py
@@ -0,0 +1,473 @@
+import torch
+import triton
+import triton.language as tl
+from vllm.kvprune.config.constants import (
+ TRITON_RESERVED_BATCH as _TRITON_RESERVED_BATCH,
+)
+
+
@triton.jit
def _prefill_store_topk_kv_kernel(
    key,
    value,  # [N_total, H, D] (D stride assumed 1)
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices_topk,  # [B, MAX_SEL] int32 (across all heads), read as a flat row-major buffer
    # Lengths & page table:
    bh_lens,  # [B, H] int32 (contiguous), indexed by LOCAL batch; incremented atomically
    page_table,  # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    sk_n,
    sk_h,  # strides for key,value. D stride assumed 1
    sv_n,
    sv_h,
    # Runtime ints
    MAX_SEL,  # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # how many selected tokens each program processes
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    # Copy the first `num_tokens_to_retain[b]` ranked (token, head) entries of
    # each local batch from key/value into the paged K/V caches.
    # Grid: axis 0 = local batch, axis 1 = tile of K_TILE ranked entries.
    b_local = tl.program_id(0)
    tile_id = tl.program_id(1)
    offs = tl.arange(0, D)
    # how many tokens we actually keep for this batch
    k_total = tl.load(num_tokens_to_retain + b_local)
    if k_total == 0:
        return
    # map to true batch row in the page table
    b_true = tl.load(batch_mapping + b_local)
    # Rows mapped to the reserved sentinel are skipped entirely.
    if b_true == TRITON_RESERVED_BATCH:
        return
    base = tile_id * K_TILE
    # process up to K_TILE tokens
    for j in tl.range(0, K_TILE):
        sel_idx = base + j
        # Stay within both the retained count and the ranked row width.
        if sel_idx < k_total and sel_idx < MAX_SEL:
            # flattened selection: sel = token * H + head
            sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
            tok = sel // HKV
            head = sel - (tok * HKV)
            # Atomically reserve one slot in (b_local, head). Slots are handed
            # out in whatever order programs reach this point, so entries of a
            # head are not stored in rank order.
            len_ptr = bh_lens + b_local * HKV + head
            pos = tl.atomic_add(len_ptr, 1)  # old length (int32)
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            # translate logical page to physical page (page table uses the TRUE batch row)
            pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            # destination row and element offset (int64 via `phys`)
            dst_row = phys * PAGE_SIZE + off
            dst_off = dst_row * D + offs
            # load one vector from [N_total, H, D]
            # NOTE(review): `tok` is used directly as the row into key/value,
            # so indices are presumably global token ids -- confirm with caller.
            k_src = key + tok * sk_n + head * sk_h + offs
            v_src = value + tok * sv_n + head * sv_h + offs
            tl.store(
                k_cache + dst_off,
                tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
+
+
+def prefill_store_topk_kv(
+ *,
+ new_keys: torch.Tensor, # [N_total, H, D]
+ new_vals: torch.Tensor, # [N_total, H, D]
+ indices_topk: torch.Tensor, # [B, MAX_SEL] int32 (global flattened token*H + head)
+ candidate_counts: torch.Tensor, # [B] int32, valid candidates in indices_topk
+ num_tokens_to_retain: torch.Tensor, # [B] int32
+ page_table: torch.Tensor, # [B_total, H, N_LOGICAL_PAGES_MAX] int32
+ batch_mapping: torch.Tensor, # [B] int32 (local -> true batch rows)
+ bh_lens: torch.Tensor, # [B, H] int32 (contiguous), UPDATED atomically
+ k_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
+ v_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
+ PAGE_SIZE: int,
+ PAD_TO_PAGE_SIZE: bool = True,
+ cu_seqlens_k: torch.Tensor | None = None,
+ K_TILE: int = 16,
+ TRITON_RESERVED_BATCH: int = None,
+):
+ assert new_keys.shape == new_vals.shape
+ N_total, H, D = new_keys.shape
+ B = indices_topk.shape[0]
+ assert page_table.shape[1] == H
+ assert bh_lens.shape == (B, H)
+ assert new_keys.device == k_cache.device == v_cache.device
+ assert page_table.is_contiguous(), "page table must be contiguous."
+ assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
+ assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
+ assert k_cache.is_contiguous() and v_cache.is_contiguous()
+ assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
+ "new_keys/new_vals last dim must be contiguous."
+ )
+ assert (D & (D - 1)) == 0, "D must be a power of 2"
+ page_table = page_table.to(torch.int32)
+ bh_lens = bh_lens.to(torch.int32)
+ batch_mapping = batch_mapping.to(torch.int32)
+ indices_topk = indices_topk.to(torch.int32)
+ candidate_counts = candidate_counts.to(torch.int32)
+ num_tokens_to_retain = num_tokens_to_retain.to(torch.int32)
+
+ # strides (elements) for [N_total, H, D]
+ sk_n, sk_h, _ = new_keys.stride()
+ sv_n, sv_h, _ = new_vals.stride()
+
+ # tile second grid dim
+ MAX_SEL = indices_topk.shape[-1]
+ N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
+ grid = (B, max(1, N_TILES))
+ if TRITON_RESERVED_BATCH is None:
+ TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH
+ _prefill_store_topk_kv_kernel[grid](
+ key=new_keys,
+ value=new_vals,
+ batch_mapping=batch_mapping,
+ num_tokens_to_retain=num_tokens_to_retain,
+ indices_topk=indices_topk,
+ bh_lens=bh_lens,
+ page_table=page_table,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ sk_n=sk_n,
+ sk_h=sk_h,
+ sv_n=sv_n,
+ sv_h=sv_h,
+ MAX_SEL=int(MAX_SEL),
+ HKV=H,
+ N_LOGICAL_PAGES_MAX=page_table.shape[2],
+ D=D,
+ PAGE_SIZE=PAGE_SIZE,
+ K_TILE=K_TILE,
+ TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
+ )
+ if PAD_TO_PAGE_SIZE:
+ assert cu_seqlens_k is not None
+ assert indices_topk.is_contiguous()
+ assert page_table.is_contiguous()
+ _prefill_store_topk_pad_kernel[(B, H)](
+ key=new_keys,
+ value=new_vals,
+ batch_mapping=batch_mapping,
+ candidate_counts=candidate_counts,
+ num_tokens_to_retain=num_tokens_to_retain,
+ indices=indices_topk,
+ local_lens=bh_lens,
+ page_table_flat=page_table,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ cu_seqlens_k=cu_seqlens_k,
+ sk_n=sk_n,
+ sk_h=sk_h,
+ sv_n=sv_n,
+ sv_h=sv_h,
+ MAX_SEL=int(MAX_SEL),
+ H=H, # type: ignore
+ N_LOGICAL_PAGES_MAX=page_table.shape[2], # type: ignore
+ D=D, # type: ignore
+ PAGE_SIZE=PAGE_SIZE, # type: ignore
+ TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
+ )
+
+
@triton.jit
def _prefill_store_topk_pad_kernel(
    key,  # [N_total, H, D]
    value,  # [N_total, H, D]
    batch_mapping,  # [B] int32 (local b -> true batch)
    candidate_counts,  # [B] int32
    num_tokens_to_retain,  # [B] int32
    indices,  # [B, MAX_SEL] int32 (across all heads)
    local_lens,  # [B, H] int32 (contiguous), rewritten with the padded length
    page_table_flat,  # [B_total*H*N_LOGICAL_PAGES_MAX] int32
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    cu_seqlens_k,
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    MAX_SEL,
    # Constexprs
    H: tl.constexpr,  # number of KV heads
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    # Top up the last, partially filled page of one (local batch, head) with
    # extra ranked candidates. Grid: axis 0 = local batch, axis 1 = KV head.
    b_local = tl.program_id(0)
    h = tl.program_id(1)
    offs_d = tl.arange(0, D)
    L = tl.load(local_lens + b_local * H + h)
    # L % PAGE_SIZE; nothing to do when the head already ends on a page boundary.
    modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
    if modulo_page_size == 0:
        return
    # Slots still free in the last page.
    need = PAGE_SIZE - modulo_page_size
    b_true = tl.load(batch_mapping + b_local)
    if b_true == TRITON_RESERVED_BATCH:
        return
    pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
    written_tokens = 0
    # Scanning starts right after the entries the first pass already stored.
    idx = tl.load(num_tokens_to_retain + b_local)
    candidate_count = tl.load(candidate_counts + b_local)
    this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
        cu_seqlens_k + b_local
    )
    # Never let a head hold more entries than the sequence has tokens.
    max_additional = this_batch_ctx_len - L
    # Sequential scan over the remaining candidates; only those that belong
    # to head `h` are consumed.
    while (written_tokens < need and idx < candidate_count) and (
        written_tokens < max_additional
    ):
        # candidate head
        cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
        cand_h = cand_idx % H
        if cand_h == h:
            tok = cand_idx // H
            pos = L + written_tokens
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            phys = tl.load(page_table_flat + pt_base + lp).to(tl.int32)

            # Row index is formed in int32 and widened only afterwards.
            dst_row = phys * PAGE_SIZE + off
            dst_off = dst_row.to(tl.int64) * D + offs_d

            k_src = key + tok * sk_n + h * sk_h + offs_d
            v_src = value + tok * sv_n + h * sv_h + offs_d

            tl.store(
                k_cache + dst_off,
                tl.load(k_src),
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src),
            )

            written_tokens += 1
        idx += 1
    # Plain (non-atomic) store: this program is the only writer of (b_local, h).
    tl.store(local_lens + b_local * H + h, L + written_tokens)
+
+
@triton.jit
def _prefill_store_all_kv_kernel(
    key,
    value,  # [N, H, D] (D contiguous)
    cu_seqlens_k,  # [B + 1] int32
    batch_mapping,  # [B] int32 (local b -> true batch index)
    bh_lens,  # [B * HKV] int32; only READ here (start offset per head)
    pt_flat,  # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    # source strides (elements)
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    # constexpr
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # number of (token, head) pairs processed per program
):
    # Append every token of every sequence to the paged caches, unpruned.
    # Grid: axis 0 = local batch, axis 1 = tile of K_TILE (token, head) pairs.
    # The kernel does not write bh_lens; the host wrapper advances it after
    # the launch.
    # NOTE(review): unlike the other store kernels, this one has no
    # reserved-batch check on b_true.
    pid_b = tl.program_id(0)
    pid_blk = tl.program_id(1)

    start = tl.load(cu_seqlens_k + pid_b)
    end = tl.load(cu_seqlens_k + pid_b + 1)
    num_toks_this_batch = end - start
    if num_toks_this_batch <= 0:
        return

    total_elems = num_toks_this_batch * HKV

    # base linear index in (token, head) grid for this program
    base = pid_blk * K_TILE

    offs_d = tl.arange(0, D)

    # Iterate K_TILE elements in this tile
    for i in tl.range(0, K_TILE):
        idx = base + i
        # Tiles past the end of this sequence do nothing.
        if idx < total_elems:
            # map linear idx -> (t, h)
            t = idx // HKV
            h = idx - t * HKV

            len_idx = pid_b * HKV + h
            L0 = tl.load(bh_lens + len_idx)

            # Token t of this call lands t slots after the head's current length.
            token_idx_in_cache = L0 + t
            lp = token_idx_in_cache // PAGE_SIZE  # logical page
            off_in_pg = token_idx_in_cache - lp * PAGE_SIZE  # pos in page

            # physical page
            b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
            pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
            phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)

            row = phys * PAGE_SIZE + off_in_pg
            dst_off = row * D + offs_d

            # Global row of the token in the packed [N, H, D] source.
            n_global = (start + t).to(tl.int64)

            # Use strides for non-contiguous [N, H, D] (D stride == 1)
            k_src = key + n_global * sk_n + h * sk_h + offs_d
            v_src = value + n_global * sv_n + h * sv_h + offs_d

            tl.store(k_cache + dst_off, tl.load(k_src))
            tl.store(v_cache + dst_off, tl.load(v_src))
+
+
def prefill_store_all_kv(
    *,
    new_keys: torch.Tensor,
    new_values: torch.Tensor,  # [N, H_kv, D]
    cu_seqlens_k: torch.Tensor,  # [B + 1] int32
    max_seqlen_k: int,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,  # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
    bh_lens: torch.Tensor,  # [B, H_kv] int32 (UPDATED)
    batch_mapping: torch.Tensor,  # [B] int32 (local->true)
    PAGE_SIZE: int,
    K_TILE: int = 32,  # how many (token, head) pairs per program
):
    """Append all new K/V tokens of a prefill batch to the paged caches.

    Launches ``_prefill_store_all_kv_kernel`` over a ``(B, n_tiles)`` grid
    and then advances ``bh_lens`` in place by each sequence's token count
    (the same amount for every head of a sequence).

    Args:
        new_keys, new_values: packed K/V of shape ``[N, H_kv, D]`` with a
            contiguous last dimension; ``D`` must be a power of two.
        cu_seqlens_k: cumulative sequence lengths, ``[B + 1]``.
        max_seqlen_k: longest sequence in the batch; sizes the tile grid.
        k_cache, v_cache: contiguous destination caches.
        page_table: contiguous logical-to-physical page map.
        bh_lens: per-(local batch, head) lengths, ``[B, H_kv]``, updated.
        batch_mapping: local batch -> page-table row.
        PAGE_SIZE: tokens per page.
        K_TILE: (token, head) pairs handled per program.

    Raises:
        AssertionError: if a shape, stride or contiguity precondition is
            violated (checked before anything is written).
    """
    # Shape checks mirror the sibling store wrappers in this module.
    assert new_keys.shape == new_values.shape and new_keys.ndim == 3, (
        "new_keys/new_values must have the same shape [N, H_kv, D]"
    )
    assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
        "last dim must be contiguous"
    )
    assert page_table.is_contiguous(), "page table must be contiguous"
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
    assert k_cache.is_contiguous() and v_cache.is_contiguous()

    _, HKV, D = new_keys.shape
    B = batch_mapping.shape[0]
    assert (D & (D - 1)) == 0, "D must be a power of 2"
    # The in-place length update at the end broadcasts [B, 1] over [B, H_kv],
    # and the kernel reads cu_seqlens_k[b] and cu_seqlens_k[b + 1].
    assert bh_lens.shape == (B, HKV), "bh_lens must be [B, H_kv]"
    assert cu_seqlens_k.shape[0] == B + 1, "cu_seqlens_k must have B + 1 entries"

    sk_n, sk_h, _ = new_keys.stride()
    sv_n, sv_h, _ = new_values.stride()
    # Enough tiles to cover the longest sequence; shorter ones exit early.
    n_tiles = (max_seqlen_k * HKV + K_TILE - 1) // K_TILE
    grid = (B, n_tiles)
    _prefill_store_all_kv_kernel[grid](
        new_keys,
        new_values,
        cu_seqlens_k,
        batch_mapping,
        bh_lens,
        page_table,
        k_cache,
        v_cache,
        sk_n=sk_n,
        sk_h=sk_h,
        sv_n=sv_n,
        sv_h=sv_h,
        HKV=HKV,
        N_LOGICAL_PAGES_MAX=page_table.shape[-1],
        D=D,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
    )
    # The kernel only reads bh_lens; lengths are advanced here, per sequence.
    bh_lens += cu_seqlens_k.diff()[:, None]
+
+
@triton.jit
def _decode_store_kv_kernel(
    key,
    value,
    batch_mapping,  # [B] int32
    bh_lens,  # [B*HKV] int32, incremented by one per (batch, head)
    page_table,  # [B_total*HKV*N_LOGICAL_PAGES_MAX]
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    sk_b,
    sk_h,
    sv_b,
    sv_h,
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    # Append one new K/V vector for a single (local batch, head) pair.
    # Grid: axis 0 = local batch row, axis 1 = KV head.
    pid_b = tl.program_id(0)
    h = tl.program_id(1)
    mapped_b = tl.load(batch_mapping + pid_b)
    # Rows mapped to the reserved sentinel are skipped before any read/write.
    if mapped_b == TRITON_RESERVED_BATCH:
        return
    offs_d = tl.arange(0, D)

    # Current length is the slot index of the new token.
    length = tl.load(bh_lens + pid_b * HKV + h)
    logical_page = length // PAGE_SIZE
    internal_offset = length - logical_page * PAGE_SIZE

    # Lengths are indexed by local row, the page table by the mapped row.
    pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
    physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)

    dst_row = physical_page * PAGE_SIZE + internal_offset

    # Source addressing using strides (D stride == 1)
    k_src = key + pid_b * sk_b + h * sk_h + offs_d
    v_src = value + pid_b * sv_b + h * sv_h + offs_d

    dst_off = dst_row * D + offs_d
    tl.store(k_cache + dst_off, tl.load(k_src))
    tl.store(v_cache + dst_off, tl.load(v_src))
    # Plain store: this program is the only writer of (pid_b, h).
    tl.store(bh_lens + pid_b * HKV + h, length + 1)
+
+
def decode_store_kv(
    *,
    key: torch.Tensor,  # [B, HKV, D]
    value: torch.Tensor,  # [B, HKV, D]
    batch_mapping: torch.Tensor,  # [B] int32
    bh_lens: torch.Tensor,  # [B, HKV] or flattened [B*HKV] int32
    page_table: torch.Tensor,  # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,  # [N_PAGES*PAGE_SIZE, D]
    PAGE_SIZE: int,
    TRITON_RESERVED_BATCH: int = None,
):
    """Append one decode-step K/V vector per (batch, head) to the paged caches.

    Validates the inputs, then launches ``_decode_store_kv_kernel`` on a
    ``(len(batch_mapping), HKV)`` grid. The kernel advances ``bh_lens`` by one
    for every row it processes; rows mapped to the reserved batch id are
    skipped.

    Args:
        key, value: new K/V of shape ``[B, HKV, D]`` with a contiguous last
            dimension; ``D`` must be a power of two.
        batch_mapping: local batch -> page-table row.
        bh_lens: per-(local batch, head) lengths, updated in place.
        page_table: contiguous logical-to-physical page map.
        k_cache, v_cache: contiguous destination caches.
        PAGE_SIZE: tokens per page.
        TRITON_RESERVED_BATCH: sentinel batch id to skip; falls back to the
            module-level constant when ``None``.

    Raises:
        AssertionError: if a shape, stride or contiguity check fails.
    """
    shapes_ok = key.shape == value.shape and key.ndim == 3
    assert shapes_ok, "key/value must be [B, HKV, D]"
    _, HKV, D = key.shape
    last_dim_dense = key.stride(-1) == 1 and value.stride(-1) == 1
    assert last_dim_dense, "key/value last dim must be contiguous."
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert (D & (D - 1)) == 0, "D must be a power of 2"

    key_stride_b, key_stride_h, _ = key.stride()
    val_stride_b, val_stride_h, _ = value.stride()

    # Resolve the sentinel once; the module default applies only for None.
    if TRITON_RESERVED_BATCH is None:
        reserved_batch = _TRITON_RESERVED_BATCH
    else:
        reserved_batch = TRITON_RESERVED_BATCH

    num_rows = int(batch_mapping.shape[0])
    _decode_store_kv_kernel[(num_rows, HKV)](
        key=key,
        value=value,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_b=key_stride_b,
        sk_h=key_stride_h,
        sv_b=val_stride_b,
        sv_h=val_stride_h,
        HKV=HKV,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=D,
        PAGE_SIZE=PAGE_SIZE,
        TRITON_RESERVED_BATCH=reserved_batch,
    )
diff --git a/vllm/kvprune/kv_cache/write_page_table.py b/vllm/kvprune/kv_cache/write_page_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99c4e1f566af65c4586c47c727ae671f9c801d7
--- /dev/null
+++ b/vllm/kvprune/kv_cache/write_page_table.py
@@ -0,0 +1,110 @@
+import torch
+import triton
+import triton.language as tl
+
+
def scatter_to_page_table(
    add_pages: torch.Tensor,  # [L, H] int32
    new_phys_pages: torch.Tensor,  # [N]
    curr_pages: torch.Tensor,  # [L, H] int32
    page_table: torch.Tensor,  # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
    max_pages_per_head: int,
):
    """
    Append newly allocated physical pages into a layered page table via Triton.

    For each (layer ``l``, head ``h``), the next ``add_pages[l, h]`` entries
    of ``new_phys_pages`` are written to logical pages
    ``curr_pages[l, h] .. curr_pages[l, h] + add_pages[l, h] - 1`` of
    ``page_table[l, h]``.

    Args:
        :param add_pages:
            Tensor of shape ``[L, H]`` (int32) indicating how many pages to
            append for each (layer, head). Values are expected to be
            non-negative.
        :param new_phys_pages:
            1D tensor of shape ``[N]`` (int32) containing physical page IDs
            for all (layer, head) pairs, concatenated in row-major (L, H)
            order. ``N`` must equal ``add_pages.sum()``.
        :param curr_pages:
            Tensor of shape ``[L, H]`` (int32) with the current logical page
            counts per (layer, head) before this update.
        :param page_table:
            Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
            the logical to physical page mapping. The last dimension is
            logically indexed as logical_page in [0, max_pages_per_head).
        :param max_pages_per_head:
            Maximum number of logical pages permitted per (layer, head). The
            kernel skips writes beyond this bound.

    Returns:
        None. The function updates ``page_table`` in-place.
    """
    L, H = add_pages.shape
    if L == 0 or H == 0:
        return
    # Keep every operand on the device the inputs live on instead of a
    # hard-coded "cuda", so caches on e.g. "cuda:1" work as well.
    device = add_pages.device
    add_flat = add_pages.to(torch.int32).contiguous().view(-1)
    curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)
    new_phys_pages = new_phys_pages.to(device=device, dtype=torch.int32)
    # Exclusive prefix sum: entry i is where (l, h) pair i starts reading in
    # new_phys_pages. Entry 0 stays zero.
    cum_page_heads = torch.zeros(L * H + 1, device=device, dtype=torch.int32)
    torch.cumsum(add_flat, 0, out=cum_page_heads[1:])
    # The page table may be a strided view, so pass its element strides.
    stride_pl, stride_ph, stride_pp = page_table.stride()
    grid = (L, H)
    _scatter_pages_kernel_lh[grid](
        add_flat,
        cum_page_heads,
        new_phys_pages,
        curr_flat,
        page_table,
        stride_pl,
        stride_ph,
        stride_pp,
        L=L,
        H=H,
        max_pages_per_head=max_pages_per_head,
    )
+
+
@triton.jit
def _scatter_pages_kernel_lh(
    add_pages,  # int32 [L*H], pages to append per (l,h)
    cum_page_heads,  # int32 [L*H + 1], base offset in flat_new_phys per (l,h)
    flat_new_phys,  # int32 [total_pages]
    curr_pages,  # int32 [L*H], existing logical pages per (l,h)
    page_table_ptr,  # int32* base pointer to page_table
    stride_pl,  # int, stride for layer dim
    stride_ph,  # int, stride for head dim
    stride_pp,  # int, stride for page dim
    L: tl.constexpr,
    H: tl.constexpr,
    max_pages_per_head: tl.constexpr,
):
    # Append the new physical page ids of one (layer, head) pair to its row
    # of the page table. Grid: axis 0 = layer, axis 1 = head.
    layer_idx = tl.program_id(0)
    h = tl.program_id(1)
    # Guard against a grid larger than (L, H).
    if layer_idx >= L or h >= H:
        return

    lh = layer_idx * H + h
    ap = tl.load(add_pages + lh)
    # Nothing to append for this pair.
    if ap <= 0:
        return

    # Where this pair's ids start inside flat_new_phys.
    base = tl.load(cum_page_heads + lh)
    cp = tl.load(curr_pages + lh)

    # Append ap pages: logical pages [cp .. cp+ap)
    for i in tl.range(0, ap):
        phys = tl.load(flat_new_phys + base + i)
        lp = cp + i
        # Writes beyond the per-head capacity are dropped silently.
        if lp < max_pages_per_head:
            offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
            tl.store(page_table_ptr + offset, phys)
+
+
+# TODO: write reclaim kernel
@triton.jit
def reclaim_page_kernel():
    # Placeholder for a future page-reclaim kernel; it takes no arguments and
    # has an empty body.
    pass
+
+
def reclaim_pages(
    batch_index: int,
    bh_seq_lens: torch.Tensor,
    bh_num_pages: torch.Tensor,
    page_table: torch.Tensor,
):
    """Placeholder for a kernel-backed page reclaim.

    Not implemented yet: the arguments are ignored, nothing is modified and
    ``None`` is returned.
    """
    return None
diff --git a/vllm/kvprune/kvprune_to_vllm.md b/vllm/kvprune/kvprune_to_vllm.md
new file mode 100644
index 0000000000000000000000000000000000000000..8e361a886de9711edf275742a6a4508e9284b5c0
--- /dev/null
+++ b/vllm/kvprune/kvprune_to_vllm.md
@@ -0,0 +1,68 @@
+# KV-prune 与上游 vLLM 的集成说明
+
+本文说明:**剪枝/压缩(Compactor)功能**在「官网 vLLM 主仓库」里改动了哪些位置、是否只有少量文件、以及随 vLLM 版本升级时如何预期合并成本。
+
+## 1. 是否「仅仅」改了少数几个脚本?
+
+**核心运行时接线**确实集中在少数几个**非** `vllm/kvprune/` 下的文件;功能主体在 `vllm/kvprune/` 包内独立维护。
+
+| 路径 | 作用简述 |
+|------|-----------|
+| `vllm/env_override.py` | 在 `import vllm` 最早阶段设置与 kvprune 相关的默认环境变量(如 v1 多进程默认、压缩默认开关、可选释放 v1 KV 等)。 |
+| `vllm/__init__.py` | 对外导出 `CompressionParams`(懒加载至 `vllm.kvprune.integration.compression_params`)。 |
+| `vllm/entrypoints/llm.py` | `kvprune_compression` 参数、`generate(..., compression=...)`、v1 `enforce_eager` / `num_gpu_blocks_override` 策略、懒加载 compactor、委托 `compressed_generate`。 |
+| `vllm/v1/worker/gpu_worker.py` | `kvprune_v1_compressed_generate`:供 `collective_rpc` 调用的 TP 多卡压缩生成入口。 |
+| `tests/conftest.py` | 测试在导入 vLLM 前覆盖部分 `VLLM_KVPRUNE_*` 默认值,避免全量测试默认走压缩路径。 |
+| `vllm/envs.py` | 集中注册 `VLLM_KVPRUNE_*` 环境变量。 |
+
+**此外(可选/示例,非引擎必需):**
+
+- `examples/offline_inference/` 下若干 `*kvprune*` 示例脚本:演示用法,不参与核心引擎加载。
+
+**结论:**
+- **「官网 vLLM 主包」里与 kvprune 强相关的改动,主要就是上表 5 个主包文件 + 测试根配置 `tests/conftest.py`**(若把测试也算进「集成面」,共 6 处)。
+- **算法、Compactor、TP 内嵌 runner 等**均在 `vllm/kvprune/`(及该目录下的 `integration/`)中,与上游 diff 相对隔离。
+
+## 2. 随 vLLM 版本更新,是否「很容易」同步剪枝压缩功能?
+
+**相对容易的部分:**
+
+- **集成面小**:合并冲突主要出现在上述少数文件,而不是遍布整个 executor / attention / model 层。
+- **逻辑内聚**:大量代码在 `vllm/kvprune/`,可整体移植或 `git` 三方合并时以子树为主处理。
+
+**仍需人工跟进的点(不能假设「自动无痛」):**
+
+- **`entrypoints/llm.py` 属于高频变更文件**:上游每次大版本可能重构 `LLM` 构造参数、`generate` 签名或引擎初始化;需要**逐次解决冲突**并回归压缩路径。
+- **`v1/worker/gpu_worker.py`** 同样会随 executor / RPC 接口变动;`collective_rpc` 方法名或 worker 基类若有变化,需对齐。
+- **`env_override.py`** 若上游调整导入顺序或新增全局默认环境变量,需避免覆盖冲突或行为打架。
+- **vLLM v1 内部 API**(如 `worker.get_model()`、`vllm_config` 结构)若变更,`vllm/kvprune/integration/*` 也可能要跟着改——这类改动**不在**上表所列的集成文件里,但仍是**集成层**维护成本。
+
+**建议同步流程(简版):**
+
+1. 在新上游 tag 上先合并/应用 `vllm/kvprune/` 目录。
+2. 再手动合并上表所列的 5 个主包文件 + `tests/conftest.py`。
+3. 跑与 kvprune 相关的测试与至少一条离线 `compression` 示例。
+4. 关注发行说明中 `LLM`、`EngineArgs`、`gpu_worker`、多进程默认的破坏性变更。
+
+## 3. 与「深度改内核」方案的区别
+
+当前设计**没有**在 `model_executor` 的统一注意力路径上大规模插入 kvprune 钩子(相关辅助逻辑主要在 `vllm/kvprune` 内部)。因此:
+
+- **上游同步时**,通常不必与 FlashAttention / 每层模型代码逐文件对打;
+- **代价是**:功能边界以「共享权重 + compactor 引擎 + 可选 TP RPC」为主,与「原生 KV 算子级一体化」的改动面不同。
+
+---
+
+## 4. 目录重建说明(与 `compactor-vllm` 对齐)
+
+`vllm/kvprune/` 以 `vllm/compactor-vllm/src/compactor_vllm/` 为算法与内核基线整体迁入(`compactor_vllm` → `vllm.kvprune`),再叠加上游集成层:
+
+- **集成**:`integration/*`(`compressed_generate`、`compactor_shared`、`config_adapter`、`v1_tp_runner`、`weight_tie` 等)仍负责「同 `LLM.generate` 前端、双后端」。
+- **TP / 调度**:`core/model_runner.py`、`utils/tp_utils.py`、`utils/tp_collectives.py`、`utils/kv_dist.py` 等保留 vLLM 内嵌 TP 与 `collective_rpc` 路径。
+- **三种 attention 模式**:`config/engine_config.py` 的 `KvpruneAttentionSchedule` + `integration/config_adapter.py` 的环境变量解析;`layers/attention.py` + `attention/fa_paged_bridge.py` 实现 `fa_triton` / `pdtriton` / `pdfa`。
+
+临时备份目录 `vllm/kvprune_legacy_save/` 可在确认无误后手动删除。
+
+---
+
+*文档随仓库维护;若集成文件列表有增删,请同步更新本节表格。*
diff --git a/vllm/kvprune/layers/__init__.py b/vllm/kvprune/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b10a0da49360a96886ab956e04e8977f0ebd842f
--- /dev/null
+++ b/vllm/kvprune/layers/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Layers from upstream compactor (attention, linear, MoE, …).
+
+Prefer importing concrete modules, e.g. ``from vllm.kvprune.layers.attention import ...``.
+"""
+
+__all__: list[str] = []
diff --git a/vllm/kvprune/layers/activation.py b/vllm/kvprune/layers/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a19e488cf3f5d25670fcdc8f4a17161ca64e1010
--- /dev/null
+++ b/vllm/kvprune/layers/activation.py
@@ -0,0 +1,13 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SiluAndMul(nn.Module):
+    """Gated activation: SiLU of the first half of the last dim times the second half."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return ``silu(a) * b`` where ``a, b`` are the two halves of ``x``.
+
+        Args:
+            x: Tensor that is split into two chunks along its last dimension.
+
+        Returns:
+            Tensor whose last dimension is the size of one chunk.
+        """
+        first_half, second_half = torch.chunk(x, 2, dim=-1)
+        activated = F.silu(first_half)
+        return activated * second_half
diff --git a/vllm/kvprune/layers/attention.py b/vllm/kvprune/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..564da676a6a5e029d41641cb7120f4ece0101d4c
--- /dev/null
+++ b/vllm/kvprune/layers/attention.py
@@ -0,0 +1,208 @@
+from typing import Optional
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from torch import nn
+
+from vllm.kvprune.attention.fa_paged_bridge import (
+ flash_decode_from_paged,
+ flash_prefill_from_paged,
+)
+from vllm.kvprune.attention.sparse_decode_kernel import head_sparse_decode_attention
+from vllm.kvprune.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+from vllm.kvprune.compression.common import extract_and_store_top_kv
+from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
+from vllm.kvprune.kv_cache.store_kv_cache import decode_store_kv, prefill_store_all_kv
+from vllm.kvprune.utils.context import Context, get_context
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+
+
+class Attention(nn.Module):
+    """Attention layer that stores K/V into a paged cache and dispatches to one
+    of several prefill / decode attention kernels.
+
+    The cache tensors, page table, per-row length tensor and page size are not
+    constructor arguments; they start as ``None`` and are expected to be
+    attached to the instance by other code before decoding.
+    """
+
+    def __init__(
+        self,
+        num_heads,
+        head_dim,
+        scale,
+        num_kv_heads,
+    ):
+        """Record the head geometry and softmax scale; cache state starts unset.
+
+        Args:
+            num_heads: Number of query heads handled by this layer.
+            head_dim: Size of each head.
+            scale: Softmax scale forwarded to the attention kernels.
+            num_kv_heads: Number of key/value heads (coerced to ``int``).
+        """
+        super().__init__()
+        self.num_heads: int = num_heads
+        self.head_dim = head_dim
+        self.scale: float = scale
+        self.num_kv_heads = int(num_kv_heads)
+
+        # Cache state, populated externally (not visible in this file section).
+        self.k_cache: Optional[torch.Tensor] = None
+        self.v_cache: Optional[torch.Tensor] = None
+        self.page_table: Optional[torch.Tensor] = None
+        self.bh_seq_lens: Optional[torch.Tensor] = None
+        self.page_size: Optional[int] = None
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        scores: Optional[torch.Tensor] = None,
+    ):
+        """Store the new K/V and compute attention for the current step.
+
+        The step type and kernel choice come from the global context returned
+        by ``get_context()`` (``is_prefill``, ``do_compression``,
+        ``attention_schedule``).
+
+        Args:
+            q: Query tensor forwarded to the selected attention kernel.
+            k: New key tensor, written to the cache and used for prefill.
+            v: New value tensor, written to the cache and used for prefill.
+            scores: Optional per-token scores; when given during a compressing
+                prefill they drive ``extract_and_store_top_kv``.
+
+        Returns:
+            The output of whichever attention kernel was selected.
+        """
+        context: Context = get_context()
+        batch_mapping = context.batch_mapping
+        # Working copy of the length rows selected by ``batch_mapping``; it is
+        # written back into ``self.bh_seq_lens`` at the end of this call.
+        seq_lens = (
+            None
+            if self.bh_seq_lens is None
+            else self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
+        )
+        sched = context.attention_schedule
+        use_triton_prefill_attn = (
+            sched == KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+        )
+        use_fa_decode = sched == KvpruneAttentionSchedule.PDFA
+
+        if context.is_prefill:
+            # Snapshot of the lengths taken before any store call; the
+            # cache-aware prefill kernels below receive this snapshot.
+            # NOTE(review): presumably the store helpers advance ``seq_lens``
+            # in place — confirm against their implementations.
+            seq_lens_copy = seq_lens.clone() if seq_lens is not None else None
+            if (
+                self.k_cache is not None
+                and context.do_compression
+                and scores is not None
+            ):
+                # Compressing prefill: store only the top-scoring K/V entries.
+                compression_context = context.compression_context
+                assert compression_context is not None
+                maybe_execute_in_stream(
+                    extract_and_store_top_kv,
+                    scores=scores,
+                    cu_seqlens_k=context.cu_seqlens_k,
+                    max_k_len=context.max_seqlen_k,
+                    top_k=compression_context.max_tokens_to_retain,
+                    H=int(self.num_kv_heads),
+                    new_keys=k,
+                    new_vals=v,
+                    num_tokens_to_retain=compression_context.batch_tokens_to_retain,
+                    page_table=self.page_table,
+                    batch_mapping=batch_mapping,
+                    bh_lens=seq_lens,
+                    k_cache=self.k_cache,
+                    v_cache=self.v_cache,
+                    PAGE_SIZE=self.page_size,
+                    PAD_TO_PAGE_SIZE=True,
+                    STORE_STREAM=context.STORE_STREAM,
+                )
+            elif self.k_cache is not None:
+                # No compression (or no scores): store every new K/V entry.
+                maybe_execute_in_stream(
+                    prefill_store_all_kv,
+                    new_keys=k,
+                    new_values=v,
+                    cu_seqlens_k=context.cu_seqlens_k,
+                    max_seqlen_k=context.max_seqlen_k,
+                    k_cache=self.k_cache,
+                    v_cache=self.v_cache,
+                    page_table=self.page_table,
+                    bh_lens=seq_lens,
+                    batch_mapping=batch_mapping,
+                    PAGE_SIZE=self.page_size,
+                    STORE_STREAM=context.STORE_STREAM,
+                )
+
+            if use_triton_prefill_attn:
+                # Wait for the store stream only when compression is active.
+                if context.do_compression and context.STORE_STREAM is not None:
+                    torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+                assert seq_lens_copy is not None
+                o = causal_sparse_varlen_with_cache(
+                    q,
+                    k,
+                    v,
+                    self.k_cache,
+                    self.v_cache,
+                    seq_lens_bh=seq_lens_copy,
+                    global_page_table=self.page_table,
+                    batch_mapping=batch_mapping,
+                    cu_seqlens_q=context.cu_seqlens_q,
+                    max_seqlen_q=context.max_seqlen_q,
+                    max_seqlen_k_cache=context.max_bh_len,
+                    HKV=int(self.num_kv_heads),
+                    PAGE_SIZE=self.page_size,
+                    sm_scale=self.scale,
+                )
+            elif context.do_compression:
+                if context.STORE_STREAM is not None:
+                    torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+                assert seq_lens_copy is not None
+                o = flash_prefill_from_paged(
+                    q,
+                    k,
+                    v,
+                    self.k_cache,
+                    self.v_cache,
+                    seq_lens_bh_before=seq_lens_copy,
+                    global_page_table=self.page_table,
+                    batch_mapping=batch_mapping,
+                    cu_seqlens_q=context.cu_seqlens_q,
+                    max_seqlen_q=context.max_seqlen_q,
+                    PAGE_SIZE=self.page_size,
+                    HKV=int(self.num_kv_heads),
+                    sm_scale=self.scale,
+                )
+            else:
+                # Plain causal varlen attention over the new tokens only; the
+                # paged cache is not read on this path.
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    max_seqlen_q=context.max_seqlen_q,
+                    cu_seqlens_q=context.cu_seqlens_q,
+                    max_seqlen_k=context.max_seqlen_k,
+                    cu_seqlens_k=context.cu_seqlens_k,
+                    softmax_scale=self.scale,
+                    causal=True,
+                )
+        else:
+            # Decode step: the store is called directly (not via the store
+            # stream helper) before the decode kernel reads the cache.
+            assert self.k_cache is not None, "KV Cache must be initialized for decoding"
+            decode_store_kv(
+                key=k,
+                value=v,
+                batch_mapping=batch_mapping,
+                bh_lens=seq_lens,
+                page_table=self.page_table,
+                k_cache=self.k_cache,
+                v_cache=self.v_cache,
+                PAGE_SIZE=self.page_size,
+            )
+
+            if use_fa_decode:
+                assert seq_lens is not None
+                o = flash_decode_from_paged(
+                    q,
+                    self.k_cache,
+                    self.v_cache,
+                    seq_lens_bh=seq_lens,
+                    global_page_table=self.page_table,
+                    batch_mapping=batch_mapping,
+                    PAGE_SIZE=self.page_size,
+                    HKV=int(self.num_kv_heads),
+                    sm_scale=self.scale,
+                )
+            else:
+                o = head_sparse_decode_attention(
+                    q,
+                    self.k_cache,
+                    self.v_cache,
+                    seq_lens,
+                    self.page_table,
+                    batch_mapping,
+                    int(self.num_kv_heads),
+                    self.page_size,
+                    self.scale,
+                    key_split=context.key_split,
+                )
+
+        if self.bh_seq_lens is not None:
+            # Write the working lengths back into the persistent tensor at the
+            # rows named by ``batch_mapping``. During prefill the copy goes
+            # through the store stream helper; during decode no stream is given.
+            longbm = batch_mapping.to(
+                device=self.bh_seq_lens.device, dtype=torch.long
+            )
+            maybe_execute_in_stream(
+                self.bh_seq_lens.index_copy_,
+                0,
+                longbm,
+                seq_lens,
+                STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
+            )
+        return o
diff --git a/vllm/kvprune/layers/embed_head.py b/vllm/kvprune/layers/embed_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1c19ab17b708dd8d2b0c3e7f17e946cc1983ca
--- /dev/null
+++ b/vllm/kvprune/layers/embed_head.py
@@ -0,0 +1,111 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from vllm.kvprune.utils.context import get_context
+from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+from torch import nn
+
+
+class VocabParallelEmbedding(nn.Module):
+    """Embedding table whose rows are split evenly across tensor-parallel ranks."""
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+    ):
+        """Allocate this rank's slice of the embedding table.
+
+        Args:
+            num_embeddings: Total number of rows; must be divisible by the
+                tensor-parallel world size.
+            embedding_dim: Width of each embedding row.
+        """
+        super().__init__()
+        rank = tensor_parallel_rank_for_sharding()
+        world = tensor_parallel_world_size_for_sharding()
+        self.tp_rank = rank
+        self.tp_size = world
+        assert num_embeddings % world == 0
+        rows_per_rank = num_embeddings // world
+        first_row = rows_per_rank * rank
+        self.num_embeddings = num_embeddings
+        self.num_embeddings_per_partition = rows_per_rank
+        # Half-open range [vocab_start_idx, vocab_end_idx) of ids owned here.
+        self.vocab_start_idx = first_row
+        self.vocab_end_idx = first_row + rows_per_rank
+        self.weight = nn.Parameter(torch.empty(rows_per_rank, embedding_dim))
+        self.weight.weight_loader = self.weight_loader
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy this rank's row slice of ``loaded_weight`` into ``param``."""
+        target = param.data
+        rows = target.size(0)
+        target.copy_(loaded_weight.narrow(0, self.tp_rank * rows, rows))
+
+    def forward(self, x: torch.Tensor):
+        """Look up embeddings for token ids ``x``.
+
+        With more than one rank, ids outside this rank's range are remapped to
+        row 0 and their output rows are zeroed before the all-reduce.
+        """
+        if self.tp_size <= 1:
+            return F.embedding(x, self.weight)
+        owned = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
+        local_ids = owned * (x - self.vocab_start_idx)
+        y = owned.unsqueeze(1) * F.embedding(local_ids, self.weight)
+        # NOTE(review): the return value is ignored, so this relies on the
+        # helper reducing ``y`` in place — confirm in tp_collectives.
+        tensor_parallel_all_reduce(y)
+        return y
+
+
+class ParallelLMHead(VocabParallelEmbedding):
+    """LM head with TP vocab sharding.
+
+    When embedded in a vLLM worker, logits must be gathered on the **tensor-
+    parallel** process group (see :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
+    not the default :func:`torch.distributed.gather` — otherwise shard order / group
+    mismatch yields garbage logits and decoded gibberish.
+
+    After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer vocab),
+    matching :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
+    removal of padded vocabulary columns.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+        *,
+        org_vocab_size: int | None = None,
+    ):
+        """Build the sharded head.
+
+        Args:
+            num_embeddings: Total (possibly padded) vocabulary size.
+            embedding_dim: Hidden size of the incoming activations.
+            bias: Must be ``False``; a bias is not supported.
+            org_vocab_size: Unpadded vocabulary size used to truncate logits.
+                Defaults to ``num_embeddings``.
+        """
+        assert not bias
+        super().__init__(num_embeddings, embedding_dim)
+        # Original (unpadded) vocab size for logits truncation; defaults to num_embeddings.
+        self.org_vocab_size = (
+            int(org_vocab_size) if org_vocab_size is not None else num_embeddings
+        )
+
+    def forward(self, x: torch.Tensor):
+        """Project hidden states to logits.
+
+        During prefill only the last token of each sequence is projected; the
+        positions come from ``cu_seqlens_q`` in the global context. With more
+        than one TP rank the per-rank logits are gathered, and the result may
+        be ``None`` on ranks that are not the gather destination.
+        """
+        context = get_context()
+        if context.is_prefill:
+            cu = context.cu_seqlens_q
+            last_indices = (cu[1:] - 1).to(torch.long)
+            n_tok = x.shape[0]
+            if n_tok > 0:
+                # Keep indices inside ``x`` even if a sequence has zero length.
+                last_indices = last_indices.clamp(min=0, max=n_tok - 1)
+            x = x[last_indices].contiguous()
+        logits = F.linear(x, self.weight)
+        if self.tp_size > 1:
+            logits = self._gather_logits_tp(logits)
+        if logits is not None and logits.shape[-1] > self.org_vocab_size:
+            logits = logits[..., : self.org_vocab_size]
+        return logits
+
+    def _gather_logits_tp(self, logits: torch.Tensor) -> torch.Tensor | None:
+        """Gather per-rank logits along the vocabulary dimension onto rank 0.
+
+        The vLLM tensor-parallel gather is preferred. If vLLM's distributed
+        helpers cannot be imported, or model parallelism is not initialized,
+        the plain ``torch.distributed.gather`` path is used. Any other failure
+        of the vLLM path is logged (it used to be swallowed silently) before
+        falling back, so a group mismatch does not go unnoticed.
+
+        Returns:
+            The concatenated logits on the destination rank, ``None`` elsewhere.
+        """
+        import logging
+
+        try:
+            from vllm.distributed.parallel_state import model_parallel_is_initialized
+            from vllm.distributed.communication_op import (
+                tensor_model_parallel_gather,
+            )
+
+            if model_parallel_is_initialized():
+                return tensor_model_parallel_gather(logits, dst=0, dim=-1)
+        except ImportError:
+            # vLLM distributed helpers are unavailable; use the fallback below.
+            pass
+        except Exception:
+            logging.getLogger(__name__).warning(
+                "tensor-parallel logits gather failed; "
+                "falling back to torch.distributed.gather",
+                exc_info=True,
+            )
+        all_logits = (
+            [torch.empty_like(logits) for _ in range(self.tp_size)]
+            if self.tp_rank == 0
+            else None
+        )
+        dist.gather(logits, all_logits, 0)
+        return torch.cat(all_logits, -1) if self.tp_rank == 0 else None
diff --git a/vllm/kvprune/layers/layernorm.py b/vllm/kvprune/layers/layernorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dabaad38ce9dec79b9e7c40a1405809c9235f3c
--- /dev/null
+++ b/vllm/kvprune/layers/layernorm.py
@@ -0,0 +1,49 @@
+import torch
+from torch import nn
+
+
+class RMSNorm(nn.Module):
+    """Root-mean-square layer normalization with a learned per-channel scale.
+
+    The statistics are computed in float32 and the result is cast back to the
+    input dtype. Inputs are never modified in place: ``Tensor.float()`` and
+    ``Tensor.to(dtype)`` return the *same* tensor when the dtype already
+    matches, so the normalization step is done out of place to avoid
+    overwriting float32 inputs or aliasing the returned residual.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        """Create the layer.
+
+        Args:
+            hidden_size: Size of the last (normalized) dimension.
+            eps: Value added to the mean square before the reciprocal root.
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+
+    def rms_forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Normalize ``x`` over its last dimension and apply the scale."""
+        orig_dtype = x.dtype
+        x = x.float()
+        var = x.pow(2).mean(dim=-1, keepdim=True)
+        # Out of place: for float32 input ``x`` is still the caller's tensor.
+        x = x * torch.rsqrt(var + self.eps)
+        # ``x`` is now a private temporary, so the in-place scale is safe.
+        return x.to(orig_dtype).mul_(self.weight)
+
+    def add_rms_forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Add ``residual`` to ``x``, then normalize the sum.
+
+        Returns:
+            ``(normalized, new_residual)`` where ``new_residual`` is the
+            un-normalized sum cast to the input dtype.
+        """
+        orig_dtype = x.dtype
+        x = x.float() + residual.float()
+        residual = x.to(orig_dtype)
+        var = x.pow(2).mean(dim=-1, keepdim=True)
+        # Out of place so ``residual`` (which may alias ``x``) stays intact.
+        x = x * torch.rsqrt(var + self.eps)
+        x = x.to(orig_dtype).mul_(self.weight)
+        return x, residual
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Dispatch to the plain or the fused add-and-normalize variant."""
+        if residual is None:
+            return self.rms_forward(x)
+        return self.add_rms_forward(x, residual)
diff --git a/vllm/kvprune/layers/linear.py b/vllm/kvprune/layers/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..be86096d2b1694c866f170fc572b6068511dee85
--- /dev/null
+++ b/vllm/kvprune/layers/linear.py
@@ -0,0 +1,158 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+from torch import nn
+
+
+def divide(numerator, denominator):
+    """Return ``numerator // denominator``, asserting the division is exact."""
+    quotient, remainder = divmod(numerator, denominator)
+    assert remainder == 0
+    return quotient
+
+
+class LinearBase(nn.Module):
+    """Common state for the tensor-parallel linear layers.
+
+    Subclasses must provide ``weight_loader`` (it is attached to every
+    parameter created here) and ``forward``.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+        tp_dim: int | None = None,
+    ):
+        """Allocate the ``(output_size, input_size)`` weight and optional bias.
+
+        Args:
+            input_size: Number of input features held by this rank.
+            output_size: Number of output features held by this rank.
+            bias: Whether to allocate a bias of length ``output_size``.
+            tp_dim: Weight dimension that is sharded, or ``None``.
+        """
+        super().__init__()
+        self.tp_dim = tp_dim
+        self.tp_rank = tensor_parallel_rank_for_sharding()
+        self.tp_size = tensor_parallel_world_size_for_sharding()
+
+        weight = nn.Parameter(torch.empty(output_size, input_size))
+        weight.weight_loader = self.weight_loader
+        self.weight = weight
+
+        if not bias:
+            self.register_parameter("bias", None)
+            return
+        bias_param = nn.Parameter(torch.empty(output_size))
+        bias_param.weight_loader = self.weight_loader
+        self.bias = bias_param
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Abstract; concrete layers implement the projection."""
+        raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+    """Linear layer whose full weight is kept on every rank (no sharding)."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        """Allocate an unsharded ``(output_size, input_size)`` weight."""
+        super().__init__(input_size, output_size, bias=bias)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy the checkpoint tensor into ``param`` unchanged."""
+        destination = param.data
+        destination.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply ``x @ weight.T + bias``."""
+        return F.linear(x, self.weight, bias=self.bias)
+
+
+class ColumnParallelLinear(LinearBase):
+    """Linear layer whose output features (weight dim 0) are split across ranks."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        """Allocate this rank's ``output_size / tp_size`` output rows."""
+        world = tensor_parallel_world_size_for_sharding()
+        local_out = divide(output_size, world)
+        super().__init__(input_size, local_out, bias, tp_dim=0)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy this rank's slice of ``loaded_weight`` along ``tp_dim``."""
+        axis = self.tp_dim
+        local = param.data
+        width = local.size(axis)
+        local.copy_(loaded_weight.narrow(axis, self.tp_rank * width, width))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Project ``x`` onto this rank's output slice; no communication."""
+        return F.linear(x, self.weight, bias=self.bias)
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+    """Several column-parallel projections fused into one weight along dim 0."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = False,
+    ):
+        """Allocate one weight covering ``sum(output_sizes)`` output features."""
+        # Set before super().__init__ so the list exists for the loader.
+        self.output_sizes = output_sizes
+        super().__init__(input_size, sum(output_sizes), bias)
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int
+    ):
+        """Load sub-projection ``loaded_shard_id`` into its slot of ``param``.
+
+        The slot starts after the per-rank share of all earlier sub-projections;
+        the checkpoint tensor is chunked across ranks along ``tp_dim``.
+        """
+        world, axis = self.tp_size, self.tp_dim
+        offset = sum(self.output_sizes[:loaded_shard_id]) // world
+        length = self.output_sizes[loaded_shard_id] // world
+        slot = param.data.narrow(axis, offset, length)
+        slot.copy_(loaded_weight.chunk(world, axis)[self.tp_rank])
+
+
+class QKVParallelLinear(ColumnParallelLinear):
+    """Fused Q/K/V projection, column-parallel over heads."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: int | None = None,
+        bias: bool = False,
+    ):
+        """Allocate the fused weight for this rank's query and KV heads.
+
+        Args:
+            hidden_size: Input feature size.
+            head_size: Size of one attention head.
+            total_num_heads: Query heads across all ranks.
+            total_num_kv_heads: KV heads across all ranks; falls back to
+                ``total_num_heads`` when ``None`` (or 0).
+            bias: Whether to allocate a bias.
+        """
+        world = tensor_parallel_world_size_for_sharding()
+        total_num_kv_heads = total_num_kv_heads or total_num_heads
+        self.head_size = head_size
+        self.num_heads = divide(total_num_heads, world)
+        self.num_kv_heads = divide(total_num_kv_heads, world)
+        fused_out = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
+        super().__init__(hidden_size, fused_out, bias)
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str
+    ):
+        """Load the ``"q"``, ``"k"`` or ``"v"`` checkpoint tensor into its slot.
+
+        The local layout along ``tp_dim`` is ``[q | k | v]``.
+        """
+        assert loaded_shard_id in ["q", "k", "v"]
+        q_width = self.num_heads * self.head_size
+        kv_width = self.num_kv_heads * self.head_size
+        layout = {
+            "q": (0, q_width),
+            "k": (q_width, kv_width),
+            "v": (q_width + kv_width, kv_width),
+        }
+        offset, length = layout[loaded_shard_id]
+        slot = param.data.narrow(self.tp_dim, offset, length)
+        slot.copy_(loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank])
+
+
+class RowParallelLinear(LinearBase):
+    """Linear layer whose input features (weight dim 1) are split across ranks.
+
+    Each rank computes a partial product; the partials are summed with a
+    tensor-parallel all-reduce. The bias is not sharded and is added on rank 0
+    only, so it is counted once in the reduced sum.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        """Allocate this rank's ``input_size / tp_size`` input columns."""
+        tp_size = tensor_parallel_world_size_for_sharding()
+        super().__init__(divide(input_size, tp_size), output_size, bias, 1)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Load ``loaded_weight`` into ``param``.
+
+        The weight is sliced along ``tp_dim`` for this rank. The bias is 1-D
+        and has no ``tp_dim`` axis, so it is copied whole; previously this
+        raised when the loader was invoked for the bias.
+        """
+        param_data = param.data
+        if param_data.dim() <= self.tp_dim:
+            param_data.copy_(loaded_weight)
+            return
+        shard_size = param_data.size(self.tp_dim)
+        start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute the partial product and reduce it across ranks."""
+        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
+        if self.tp_size > 1:
+            # NOTE(review): return value ignored; assumes an in-place reduce.
+            tensor_parallel_all_reduce(y)
+        return y
diff --git a/vllm/kvprune/layers/moe.py b/vllm/kvprune/layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d487e650eb9b58396a6eb1998a7202c1baf3986
--- /dev/null
+++ b/vllm/kvprune/layers/moe.py
@@ -0,0 +1,177 @@
+import torch
+import torch.distributed as dist
+from vllm.kvprune.triton_kernels.matmul_ogs import matmul_ogs
+from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+from torch import nn
+
+
+def divide(numerator, denominator):
+    """Return ``numerator // denominator``, asserting the division is exact."""
+    quotient, remainder = divmod(numerator, denominator)
+    assert remainder == 0
+    return quotient
+
+
+class TritonFusedMoeLinearBase(nn.Module):
+    """Common state for the per-expert (fused MoE) linear layers.
+
+    The weight parameter has shape ``(num_experts, out_features, in_features)``
+    but is a transposed view of a contiguous
+    ``(num_experts, in_features, out_features)`` buffer. Subclasses provide
+    ``weight_loader`` and ``forward``.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+        tp_dim: int | None = None,
+    ) -> None:
+        """Allocate the per-expert weight and optional ``(E, out)`` bias."""
+        super().__init__()
+        self.tp_dim = tp_dim
+        self.tp_rank = tensor_parallel_rank_for_sharding()
+        self.tp_size = tensor_parallel_world_size_for_sharding()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.num_experts = num_experts
+
+        storage = torch.empty((num_experts, in_features, out_features))
+        weight = nn.Parameter(storage.transpose(-1, -2))
+        weight.weight_loader = self.weight_loader
+        self.weight = weight
+
+        if not bias:
+            self.register_parameter("bias", None)
+            return
+        bias_param = nn.Parameter(torch.empty((num_experts, out_features)))
+        bias_param.weight_loader = self.weight_loader
+        self.bias = bias_param
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Abstract; concrete layers implement the projection."""
+        raise NotImplementedError
+
+
+class ReplicatedTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+    """Per-expert linear layer whose full weights are kept on every rank."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+    ) -> None:
+        """Allocate unsharded per-expert weights."""
+        super().__init__(in_features, out_features, num_experts, bias)
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+    ):
+        """Copy the checkpoint tensor for expert ``expert_idx`` unchanged."""
+        param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Run the grouped matmul over the experts.
+
+        The contiguous ``(E, in, out)`` view ``w`` is what gets passed to
+        ``matmul_ogs``, matching the row- and column-parallel variants.
+        Previously ``w`` was computed and checked but the non-contiguous
+        ``self.weight`` was passed instead.
+        """
+        w = self.weight.transpose(-1, -2)
+        assert w.is_contiguous()
+        return matmul_ogs(
+            x,
+            w,
+            self.bias,
+            **kwargs,
+        )
+
+
+class RowParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+    """Per-expert linear layer whose input features are split across ranks.
+
+    Each rank produces a partial result that is summed with a tensor-parallel
+    all-reduce. The bias is not sharded and is applied on rank 0 only, the
+    same way ``RowParallelLinear`` does it, so it is counted once in the sum.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+    ) -> None:
+        """Allocate this rank's ``in_features / tp_size`` input slice."""
+        tp_size = (
+            tensor_parallel_world_size_for_sharding()
+            if dist.is_initialized()
+            else 1
+        )
+        super().__init__(
+            divide(in_features, tp_size), out_features, num_experts, bias, 2
+        )
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+    ):
+        """Load the checkpoint tensor for expert ``expert_idx``.
+
+        The 3-D weight takes this rank's column slice of ``loaded_weight``.
+        The 2-D ``(E, out)`` bias has no input axis, so it is copied whole;
+        previously ``param.size(2)`` raised for the bias.
+        """
+        if param.dim() < 3:
+            param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
+            return
+        shard_size = param.size(2)
+        start_idx = self.tp_rank * shard_size
+        local_shard = loaded_weight[:, start_idx : start_idx + shard_size]
+        param.data[expert_idx].copy_(local_shard, non_blocking=True)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Run the grouped matmul on the local slice and reduce across ranks."""
+        w = self.weight.transpose(-1, -2)
+        assert w.is_contiguous()
+        # Only one rank contributes the bias to the reduced sum.
+        bias = self.bias if self.tp_rank == 0 else None
+        y = matmul_ogs(
+            x,
+            w,
+            bias,
+            **kwargs,
+        )
+        if self.tp_size > 1:
+            # NOTE(review): return value ignored; assumes an in-place reduce.
+            tensor_parallel_all_reduce(y)
+        return y
+
+
+class ColumnParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+    """Per-expert linear layer whose output features are split across ranks."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+    ) -> None:
+        """Allocate this rank's ``out_features / tp_size`` output slice."""
+        tp_size = (
+            tensor_parallel_world_size_for_sharding()
+            if dist.is_initialized()
+            else 1
+        )
+        super().__init__(
+            in_features, divide(out_features, tp_size), num_experts, bias, 1
+        )
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+    ):
+        """Load this rank's output slice for expert ``expert_idx``.
+
+        ``narrow`` on dim 0 works for both the 2-D weight and the 1-D bias;
+        the previous ``[a:b, :]`` slice raised on a 1-D bias tensor.
+        """
+        shard_size = param.size(1)
+        start_idx = self.tp_rank * shard_size
+        local_shard = loaded_weight.narrow(0, start_idx, shard_size)
+        param.data[expert_idx].copy_(local_shard, non_blocking=True)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Run the grouped matmul on the local output slice; no communication."""
+        w = self.weight.transpose(-1, -2)
+        assert w.is_contiguous()
+        y = matmul_ogs(
+            x,
+            w,
+            self.bias,
+            **kwargs,
+        )
+        return y
+
+
+class MergedColumnParallelTritonFusedMoeLinear(ColumnParallelTritonFusedMoeLinear):
+    """Several column-parallel per-expert projections fused along the output dim."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_feature_list: list[int],
+        num_experts: int,
+        bias: bool = False,
+    ):
+        """Allocate one weight covering ``sum(out_feature_list)`` outputs."""
+        # Set before super().__init__ so the list exists for the loader.
+        self.out_feature_list = out_feature_list
+        super().__init__(in_features, sum(out_feature_list), num_experts, bias)
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        expert_idx: int,
+        shard_id: int,
+    ):
+        """Load sub-projection ``shard_id`` of expert ``expert_idx``.
+
+        The destination slot is taken along ``tp_dim`` of ``param``; the
+        per-expert checkpoint tensor lacks the leading expert axis, so it is
+        chunked across ranks along ``tp_dim - 1``.
+        """
+        world = self.tp_size
+        offset = sum(self.out_feature_list[:shard_id]) // world
+        length = self.out_feature_list[shard_id] // world
+        slot = param.data.narrow(self.tp_dim, offset, length)
+        rank_chunks = loaded_weight.chunk(world, dim=self.tp_dim - 1)
+        slot[expert_idx].copy_(rank_chunks[self.tp_rank], non_blocking=True)
diff --git a/vllm/kvprune/layers/rotary_embedding.py b/vllm/kvprune/layers/rotary_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..506616f912a57ff1dcf2543d62ec096b258e31d6
--- /dev/null
+++ b/vllm/kvprune/layers/rotary_embedding.py
@@ -0,0 +1,94 @@
+import math
+from functools import lru_cache
+
+import torch
+from torch import nn
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    """Rotate the two halves of ``x``'s last dimension by ``cos`` / ``sin``.
+
+    The arithmetic runs in float32 and the result is cast back to ``x.dtype``.
+    """
+    first, second = x.float().chunk(2, dim=-1)
+    rotated_first = first * cos - second * sin
+    rotated_second = second * cos + first * sin
+    rotated = torch.cat((rotated_first, rotated_second), dim=-1)
+    return rotated.to(x.dtype)
+
+
+class RotaryEmbedding(nn.Module):
+    """Rotary position embedding backed by a precomputed cos/sin table.
+
+    The table is built once in ``__init__`` for every position below
+    ``max_position_embeddings`` and stored as the non-persistent buffer
+    ``cos_sin_cache``.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        rope_scaling: tuple | None,
+    ) -> None:
+        """Precompute the cos/sin table.
+
+        Args:
+            head_size: Attention head size.
+            rotary_dim: Number of rotated channels; must equal ``head_size``.
+            max_position_embeddings: Number of positions in the table.
+            base: Base of the geometric frequency progression.
+            rope_scaling: ``None`` for no scaling, or a 5-tuple
+                ``(rope_type, factor, low_freq_factor, high_freq_factor,
+                original_max_position_embeddings)``; only ``"llama3"`` is
+                accepted as ``rope_type``.
+        """
+        super().__init__()
+        self.head_size = head_size
+        assert rotary_dim == head_size
+        # One inverse frequency per channel pair.
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
+        )
+        if rope_scaling is not None:
+            (
+                rope_type,
+                factor,
+                low_freq_factor,
+                high_freq_factor,
+                original_max_position_embeddings,
+            ) = rope_scaling
+            assert rope_type == "llama3"
+            old_context_len = original_max_position_embeddings
+            low_freq_wavelen = old_context_len / low_freq_factor
+            high_freq_wavelen = old_context_len / high_freq_factor
+            wavelen = 2 * math.pi / inv_freq
+
+            # Wavelengths longer than the low-frequency bound are divided by
+            # ``factor``; the rest are left as they are for now.
+            inv_freq_llama = torch.where(
+                wavelen > low_freq_wavelen, inv_freq / factor, inv_freq
+            )
+            # Interpolation weight used for the band between the two bounds.
+            smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            smoothed_inv_freq = (
+                1 - smooth_factor
+            ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+            # Product of two boolean masks acts as a logical AND: wavelengths
+            # that are neither below the high bound nor above the low bound.
+            is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(
+                wavelen > low_freq_wavelen
+            )
+            inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        # Outer product: one row of angles per position.
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        # cos and sin are concatenated on the last dim and a singleton axis is
+        # inserted at dim 1, giving (positions, 1, rotary_dim).
+        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    # @torch.compile
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Rotate ``query`` and ``key`` by the angles of ``positions``.
+
+        Args:
+            positions: Indices into the precomputed table.
+            query: Query tensor to rotate.
+            key: Key tensor to rotate.
+
+        Returns:
+            The rotated ``(query, key)`` pair.
+        """
+        cos_sin = self.cos_sin_cache[positions]
+        # First half of the last dim holds cos, second half sin.
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        query = apply_rotary_emb(query, cos, sin)
+        key = apply_rotary_emb(key, cos, sin)
+        return query, key
+
+
+@lru_cache(1)
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: float,
+    rope_scaling: tuple | None = None,
+):
+    """Return a ``RotaryEmbedding`` for the given configuration.
+
+    The most recent result is cached, so consecutive calls with identical
+    arguments share one module instance. Arguments must be hashable.
+    """
+    return RotaryEmbedding(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        rope_scaling,
+    )
diff --git a/vllm/kvprune/layers/sampler.py b/vllm/kvprune/layers/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0761b7c79bc7dc511180078c4d059c3423b3f8f
--- /dev/null
+++ b/vllm/kvprune/layers/sampler.py
@@ -0,0 +1,27 @@
+import torch
+from torch import nn
+
+
+class Sampler(nn.Module):
+    """Pick one token id per row: greedy where temperature is 0, sampled otherwise."""
+
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+        """Return the chosen index along the last dimension of ``logits``.
+
+        Rows with temperature ``0.0`` take the plain argmax. Other rows are
+        divided by their temperature and perturbed with the log of
+        exponential noise before the argmax. ``logits`` is not modified.
+
+        Args:
+            logits: Scores with one row per request.
+            temperatures: One temperature per row (flattened with ``view(-1)``).
+        """
+        flat_temps = temperatures.view(-1)
+        scores = logits.float()
+        stochastic = ~(flat_temps == 0.0)
+
+        if stochastic.any():
+            row_temps = flat_temps[stochastic].unsqueeze(-1)
+            tempered = scores[stochastic].div(row_temps)
+            # Clamp keeps the log finite when the exponential draw is ~0.
+            noise = torch.empty_like(tempered).exponential_(1).clamp_min_(1e-10).log()
+            # Clone first: ``float()`` may have returned ``logits`` itself.
+            scores = scores.clone()
+            scores[stochastic] = tempered - noise
+
+        return scores.argmax(dim=-1)
diff --git a/vllm/kvprune/layers/triton_helpers.py b/vllm/kvprune/layers/triton_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c1a31669bac31c9fcef53259f1211e6de19bc37
--- /dev/null
+++ b/vllm/kvprune/layers/triton_helpers.py
@@ -0,0 +1,101 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _masked_index_select_kernel(
+    X_ptr,
+    IDX_ptr,
+    OUT_ptr,
+    N,
+    stride_xn,
+    stride_xh,
+    stride_ob,
+    stride_oh,
+):
+    # Each program handles a single output element (b, h): it copies
+    # X[IDX[b], h] into OUT[b, h], or writes 0 when IDX[b] is outside [0, N).
+    b = tl.program_id(0)  # output row
+    h = tl.program_id(1)  # column within the row
+    idx = tl.load(IDX_ptr + b)  # source row for this output row
+    valid = (idx >= 0) & (idx < N)
+    out_ptrs = OUT_ptr + b * stride_ob + h * stride_oh
+
+    if not valid:
+        # Out-of-range index: the output element is zero-filled.
+        tl.store(out_ptrs, 0)
+    else:
+        x_ptrs = X_ptr + idx * stride_xn + h * stride_xh
+        vals = tl.load(x_ptrs)
+        tl.store(out_ptrs, vals)
+
+
+def masked_index_select_triton_dim0(
+    input: torch.Tensor, index: torch.Tensor
+) -> torch.Tensor:
+    """Gather rows of ``input`` by ``index``, zero-filling out-of-range rows.
+
+    Args:
+        input: 2-D tensor of shape ``[N, H]``.
+        index: 1-D tensor of ``B`` row indices; entries ``< 0`` or ``>= N``
+            produce an all-zero output row.
+
+    Returns:
+        A new ``[B, H]`` tensor with ``input``'s dtype and device.
+    """
+    assert input.ndim == 2 and index.ndim == 1
+    num_rows, width = input.shape
+    num_out = index.numel()
+    out = torch.empty((num_out, width), dtype=input.dtype, device=input.device)
+    # One program per output element.
+    launch_grid = (num_out, width)
+    _masked_index_select_kernel[launch_grid](
+        input,
+        index,
+        out,
+        num_rows,
+        input.stride(0),
+        input.stride(1),
+        out.stride(0),
+        out.stride(1),
+    )
+    return out
+
+
+@triton.jit
+def _masked_index_copy_kernel(
+    DST_ptr,
+    IDX_ptr,
+    SRC_ptr,
+    N,
+    stride_dn,
+    stride_dh,
+    stride_sb,
+    stride_sh,
+):
+    # Each program handles a single source element (b, h): it copies
+    # SRC[b, h] into DST[IDX[b], h], and does nothing when IDX[b] is
+    # outside [0, N).
+    b = tl.program_id(0)  # source row
+    h = tl.program_id(1)  # column within the row
+    idx = tl.load(IDX_ptr + b)  # destination row for this source row
+    valid = (idx >= 0) & (idx < N)
+    if valid:
+        src_ptrs = SRC_ptr + b * stride_sb + h * stride_sh
+        dst_ptrs = DST_ptr + idx * stride_dn + h * stride_dh
+        tl.store(dst_ptrs, tl.load(src_ptrs))
+
+
+def masked_index_copy_triton_dim0(
+    dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor
+):
+    """Scatter rows of ``src`` into ``dst`` in place, skipping invalid indices.
+
+    Behaves like ``dst.index_copy_(0, index, src)`` except that rows whose
+    index is ``< 0`` or ``>= dst.shape[0]`` are not written.
+
+    Args:
+        dst: 2-D destination of shape ``[N, H]``; modified in place.
+        index: 1-D tensor of ``B`` destination row indices.
+        src: 2-D source of shape ``[B, H]``.
+    """
+    assert dst.ndim == 2 and src.ndim == 2 and index.ndim == 1
+    num_dst_rows, width = dst.shape
+    num_src_rows, src_width = src.shape
+    assert src_width == width and index.numel() == num_src_rows
+    # One program per source element.
+    launch_grid = (num_src_rows, width)
+    _masked_index_copy_kernel[launch_grid](
+        dst,
+        index,
+        src,
+        num_dst_rows,
+        dst.stride(0),
+        dst.stride(1),
+        src.stride(0),
+        src.stride(1),
+    )
diff --git a/vllm/kvprune/models/__init__.py b/vllm/kvprune/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c88b70db7a0829df28ae30963c1d62e419185e
--- /dev/null
+++ b/vllm/kvprune/models/__init__.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import logging
+
+from vllm.kvprune.models.llama3 import LlamaForCausalLM
+from vllm.kvprune.models.qwen3 import Qwen3ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+# Maps a model-type key to the kvprune model class implementing it.
+MODEL_REGISTRY = dict(
+    llama=LlamaForCausalLM,
+    qwen3=Qwen3ForCausalLM,
+)
+
+# The MoE model is optional: if importing it fails for any reason, it is left
+# out of the registry and the failure is only reported at debug level.
+try:
+    from vllm.kvprune.models.qwen3_moe import Qwen3MoeForCausalLM
+except Exception as exc:
+    logger.debug("Skipping qwen3_moe registration due to import error: %s", exc)
+else:
+    MODEL_REGISTRY.update(qwen3_moe=Qwen3MoeForCausalLM)
diff --git a/vllm/kvprune/models/llama3.py b/vllm/kvprune/models/llama3.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aa75492868aa1ae5e2e1d0205ed4d13d3454d4b
--- /dev/null
+++ b/vllm/kvprune/models/llama3.py
@@ -0,0 +1,281 @@
+import os
+from glob import glob
+
+import torch
+import torch.distributed as dist
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import LlamaConfig
+
+from vllm.kvprune.compression import (
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from vllm.kvprune.layers.activation import SiluAndMul
+from vllm.kvprune.layers.attention import Attention
+from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from vllm.kvprune.layers.layernorm import RMSNorm
+from vllm.kvprune.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.kvprune.layers.rotary_embedding import get_rope
+from vllm.kvprune.utils.context import get_context
+
+
class LlamaAttention(nn.Module):
    """Llama self-attention layer with optional prefill-time KV compression.

    Query and KV heads are divided evenly by ``dist.get_world_size()``; the
    per-rank counts are stored in ``num_heads`` / ``num_kv_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: dict | None = None,
    ) -> None:
        """Build the projections, rotary embedding and attention kernel.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads (across all ranks).
            num_kv_heads: Total number of key/value heads (across all ranks).
            max_position: Maximum position passed to ``get_rope``.
            head_dim: Per-head dimension; when falsy it is derived as
                ``hidden_size // num_heads``.
            qkv_bias: Whether the fused QKV projection has a bias.
            rope_theta: RoPE base passed to ``get_rope``.
            rope_scaling: Optional scaling dict; when given it must contain
                ``rope_type``, ``factor``, ``low_freq_factor``,
                ``high_freq_factor`` and ``original_max_position_embeddings``.
        """
        super().__init__()
        world_size = dist.get_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        if head_dim:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )

        # The scaling dict is flattened into a tuple, in this fixed key order.
        # NOTE(review): a rope_scaling dict without these keys (i.e. anything
        # other than the llama3-style layout) raises KeyError here.
        scaling_args = None
        if rope_scaling is not None:
            scaling_args = tuple(
                rope_scaling[key]
                for key in (
                    "rope_type",
                    "factor",
                    "low_freq_factor",
                    "high_freq_factor",
                    "original_max_position_embeddings",
                )
            )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=scaling_args,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        During a prefill step with compression enabled, the pre-RoPE and
        post-RoPE compression hooks are invoked and the resulting ``scores``
        are handed to the attention kernel; otherwise ``scores`` is ``None``.
        """
        ctx = get_context()
        compress = ctx.is_prefill and ctx.do_compression

        fused_qkv = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)

        scores = None
        if compress:
            scores = apply_prerope_compression(q, k, v, ctx)

        q, k = self.rotary_emb(positions, q, k)

        # NOTE(review): unlike Qwen3Attention, no ``wo_weight`` is injected
        # into the compression context for CRITICALADAKV here -- confirm that
        # this method is not expected to work with Llama, or add it.
        if compress:
            scores = apply_postrope_compression(q, k, v, scores, ctx)

        attn_out = self.attn(q, k, v, scores)
        return self.o_proj(attn_out.flatten(1, -1))
+
+
class LlamaMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul, down."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        mlp_bias: bool,
    ) -> None:
        """Create the projections.

        Args:
            hidden_size: Input and output width.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            mlp_bias: Whether both projections carry a bias.
        """
        super().__init__()
        gate_and_up_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_and_up_sizes,
            bias=mlp_bias,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=mlp_bias,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then the down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: norm -> attention -> norm -> MLP.

    The residual stream is threaded through the RMSNorm calls: each norm is
    called with ``(hidden_states, residual)`` and returns the updated pair.
    """

    def __init__(
        self,
        config: LlamaConfig,
    ) -> None:
        """Build the layer from a HF ``LlamaConfig``.

        Optional config fields fall back to: ``attention_bias=False``,
        ``head_dim=None``, ``rope_theta=500000.0``, ``rope_scaling=None``.
        """
        super().__init__()
        hidden = config.hidden_size
        eps = config.rms_norm_eps

        self.self_attn = LlamaAttention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=getattr(config, "rope_theta", 500000.0),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = LlamaMLP(
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            mlp_bias=config.mlp_bias,
        )
        self.input_layernorm = RMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = RMSNorm(hidden, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden_states, residual)`` after this layer.

        ``residual`` is ``None`` only for the first layer, in which case the
        incoming ``hidden_states`` becomes the residual.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return self.mlp(hidden_states), residual
+
+
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, decoder stack and final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
    ) -> None:
        """Create ``config.num_hidden_layers`` decoder layers plus embed/norm."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_stack = []
        for _ in range(config.num_hidden_layers):
            decoder_stack.append(LlamaDecoderLayer(config))
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Returns the hidden states after the final norm; the residual returned
        by that norm is discarded.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
        normed, _ = self.norm(hidden_states, residual)
        return normed
+
+
class LlamaForCausalLM(nn.Module):
    """Llama backbone plus LM head, with a safetensors checkpoint loader."""

    # Checkpoint sub-name -> (fused parameter sub-name, shard id). A checkpoint
    # weight whose name contains the key is loaded into the fused parameter,
    # with the shard id forwarded to that parameter's ``weight_loader``.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: LlamaConfig) -> None:
        """Build the model; share embedding data with the head when tied."""
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (no logits)."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states through the LM head."""
        return self.lm_head(hidden_states)

    @staticmethod
    def _default_weight_loader(param, loaded_weight) -> None:
        """Fallback for parameters without a ``weight_loader``: plain copy."""
        param.data.copy_(loaded_weight)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Args:
            path: Directory containing the checkpoint shards.
            use_tqdm: Show a progress bar over the shards.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.safetensors`` file.
                Previously this case returned silently and left the model with
                uninitialised weights.
            AttributeError: if a checkpoint name has no matching parameter
                (raised by ``get_parameter``) or a fused parameter has no
                ``weight_loader``.
        """
        # glob returns files in arbitrary order; sort for a reproducible load.
        shard_paths = sorted(glob(os.path.join(path, "*.safetensors")))
        if not shard_paths:
            raise FileNotFoundError(f"No *.safetensors shards found in {path!r}")

        shard_iter = (
            tqdm.tqdm(shard_paths, desc="Loading model") if use_tqdm else shard_paths
        )
        for shard_path in shard_iter:
            with safe_open(shard_path, "pt", "cpu") as shard:
                for weight_name in shard.keys():
                    weight_tensor = shard.get_tensor(weight_name)

                    # Fused (packed) modules: first mapping key contained in
                    # the checkpoint name wins.
                    for hf_name, (fused_name, shard_id) in (
                        self.packed_modules_mapping.items()
                    ):
                        if hf_name not in weight_name:
                            continue
                        param = self.get_parameter(
                            weight_name.replace(hf_name, fused_name)
                        )
                        weight_loader = getattr(param, "weight_loader")
                        weight_loader(param, weight_tensor, shard_id)
                        break
                    else:
                        # Everything else maps one-to-one onto a parameter.
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param, "weight_loader", self._default_weight_loader
                        )
                        weight_loader(param, weight_tensor)
diff --git a/vllm/kvprune/models/qwen3.py b/vllm/kvprune/models/qwen3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fd75eaa36053ddcf2d2f0c769a1f06d5c4e6d48
--- /dev/null
+++ b/vllm/kvprune/models/qwen3.py
@@ -0,0 +1,286 @@
+import os
+from glob import glob
+
+import torch
+import torch.distributed as dist
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import Qwen3Config
+
+from vllm.kvprune.compression import (
+ CompressionMethod,
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from vllm.kvprune.layers.activation import SiluAndMul
+from vllm.kvprune.layers.attention import Attention
+from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from vllm.kvprune.layers.layernorm import RMSNorm
+from vllm.kvprune.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.kvprune.layers.rotary_embedding import get_rope
+from vllm.kvprune.utils.context import get_context
+
+
class Qwen3Attention(nn.Module):
    """Qwen3 self-attention (per-head Q/K RMSNorm) with KV-compression hooks.

    Query and KV heads are divided evenly by ``dist.get_world_size()``; the
    per-rank counts are stored in ``num_heads`` / ``num_kv_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        """Build projections, rotary embedding, attention and Q/K norms.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads (across all ranks).
            num_kv_heads: Total number of key/value heads (across all ranks).
            max_position: Maximum position passed to ``get_rope``.
            head_dim: Per-head dimension; when falsy it is derived as
                ``hidden_size // num_heads``.
            rms_norm_eps: Epsilon of the Q/K RMSNorm layers.
            qkv_bias: Whether the fused QKV projection has a bias.
            rope_theta: RoPE base passed to ``get_rope``.
            rope_scaling: Forwarded unchanged to ``get_rope``.
                NOTE(review): the decoder layer passes ``config.rope_scaling``
                straight through, which is presumably a dict rather than a
                tuple -- confirm ``get_rope`` accepts that.
        """
        super().__init__()
        world_size = dist.get_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        if head_dim:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        During a prefill step with compression enabled, the pre-RoPE and
        post-RoPE compression hooks are invoked and the resulting ``scores``
        are handed to the attention kernel; otherwise ``scores`` is ``None``.
        """
        ctx = get_context()
        compress = ctx.is_prefill and ctx.do_compression

        fused_qkv = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))

        # NOTE(review): ``v`` is still the flat split here, whereas
        # LlamaAttention reshapes it to (tokens, kv_heads, head_dim) before
        # this hook -- confirm which layout the hook expects.
        scores = None
        if compress:
            scores = apply_prerope_compression(q, k, v, ctx)

        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)

        if compress:
            comp_ctx = ctx.compression_context
            if (
                comp_ctx is not None
                and comp_ctx.compression_method == CompressionMethod.CRITICALADAKV
            ):
                # Key step: inject the output-projection weight into the
                # compression context as a float32 tensor viewed as
                # (num_heads, head_dim, out_features).
                o_weight = self.o_proj.weight
                out_features, _ = o_weight.shape
                per_head_shape = (self.num_heads, self.head_dim, out_features)
                transposed = o_weight.transpose(0, 1).contiguous()
                comp_ctx.wo_weight = transposed.view(*per_head_shape).to(
                    dtype=torch.float32
                )

            scores = apply_postrope_compression(q, k, v, scores, ctx)

        attn_out = self.attn(q, k, v, scores)
        return self.o_proj(attn_out.flatten(1, -1))
+
+
class Qwen3MLP(nn.Module):
    """Bias-free gated feed-forward block: gate/up, SiLU-and-mul, down."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        """Create the projections.

        Args:
            hidden_size: Input and output width.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
        """
        super().__init__()
        gate_and_up_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_and_up_sizes,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then the down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
class Qwen3DecoderLayer(nn.Module):
    """One Qwen3 transformer layer: norm -> attention -> norm -> MLP.

    The residual stream is threaded through the RMSNorm calls: each norm is
    called with ``(hidden_states, residual)`` and returns the updated pair.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Build the layer from a HF ``Qwen3Config``.

        Optional config fields fall back to: ``attention_bias=False``,
        ``head_dim=None``, ``rope_theta=1000000``, ``rope_scaling=None``.
        """
        super().__init__()
        hidden = config.hidden_size
        eps = config.rms_norm_eps

        self.self_attn = Qwen3Attention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen3MLP(
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = RMSNorm(hidden, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden_states, residual)`` after this layer.

        ``residual`` is ``None`` only for the first layer, in which case the
        incoming ``hidden_states`` becomes the residual.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return self.mlp(hidden_states), residual
+
+
class Qwen3Model(nn.Module):
    """Qwen3 backbone: token embedding, decoder stack and final RMSNorm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Create ``config.num_hidden_layers`` decoder layers plus embed/norm."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_stack = []
        for _ in range(config.num_hidden_layers):
            decoder_stack.append(Qwen3DecoderLayer(config))
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Returns the hidden states after the final norm; the residual returned
        by that norm is discarded.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
        normed, _ = self.norm(hidden_states, residual)
        return normed
+
+
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 backbone plus LM head, with a safetensors checkpoint loader."""

    # Checkpoint sub-name -> (fused parameter sub-name, shard id). A checkpoint
    # weight whose name contains the key is loaded into the fused parameter,
    # with the shard id forwarded to that parameter's ``weight_loader``.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: Qwen3Config) -> None:
        """Build the model; share embedding data with the head when tied."""
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (no logits)."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states through the LM head."""
        return self.lm_head(hidden_states)

    @staticmethod
    def _default_weight_loader(param, loaded_weight) -> None:
        """Fallback for parameters without a ``weight_loader``: plain copy."""
        param.data.copy_(loaded_weight)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Args:
            path: Directory containing the checkpoint shards.
            use_tqdm: Show a progress bar over the shards.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.safetensors`` file.
                Previously this case returned silently and left the model with
                uninitialised weights.
            AttributeError: if a checkpoint name has no matching parameter
                (raised by ``get_parameter``) or a fused parameter has no
                ``weight_loader``.
        """
        # glob returns files in arbitrary order; sort for a reproducible load.
        shard_paths = sorted(glob(os.path.join(path, "*.safetensors")))
        if not shard_paths:
            raise FileNotFoundError(f"No *.safetensors shards found in {path!r}")

        shard_iter = (
            tqdm.tqdm(shard_paths, desc="Loading model") if use_tqdm else shard_paths
        )
        for shard_path in shard_iter:
            with safe_open(shard_path, "pt", "cpu") as shard:
                for weight_name in shard.keys():
                    weight_tensor = shard.get_tensor(weight_name)

                    # Fused (packed) modules: first mapping key contained in
                    # the checkpoint name wins.
                    for hf_name, (fused_name, shard_id) in (
                        self.packed_modules_mapping.items()
                    ):
                        if hf_name not in weight_name:
                            continue
                        param = self.get_parameter(
                            weight_name.replace(hf_name, fused_name)
                        )
                        weight_loader = getattr(param, "weight_loader")
                        weight_loader(param, weight_tensor, shard_id)
                        break
                    else:
                        # Everything else maps one-to-one onto a parameter.
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param, "weight_loader", self._default_weight_loader
                        )
                        weight_loader(param, weight_tensor)
diff --git a/vllm/kvprune/models/qwen3_moe.py b/vllm/kvprune/models/qwen3_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1e9fcdd8c8991a74a9353088cbf9d8a99c1428
--- /dev/null
+++ b/vllm/kvprune/models/qwen3_moe.py
@@ -0,0 +1,378 @@
+import os
+from glob import glob
+
+import torch
+import torch.distributed as dist
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import Qwen3MoeConfig
+
+from vllm.kvprune.compression import (
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from vllm.kvprune.layers.activation import SiluAndMul
+from vllm.kvprune.layers.attention import Attention
+from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from vllm.kvprune.layers.layernorm import RMSNorm
+from vllm.kvprune.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.kvprune.layers.moe import (
+ MergedColumnParallelTritonFusedMoeLinear,
+ RowParallelTritonFusedMoeLinear,
+)
+from vllm.kvprune.layers.rotary_embedding import get_rope
+from vllm.kvprune.triton_kernels.routing import routing
+from vllm.kvprune.utils.context import get_context
+
+
class Qwen3MoeAttention(nn.Module):
    """Qwen3-MoE self-attention (per-head Q/K RMSNorm) with compression hooks.

    Query and KV heads are divided evenly by ``dist.get_world_size()``; the
    per-rank counts are stored in ``num_heads`` / ``num_kv_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
        sliding_window: int | None = None,
    ) -> None:
        """Build projections, rotary embedding, attention and Q/K norms.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads (across all ranks).
            num_kv_heads: Total number of key/value heads (across all ranks).
            max_position: Maximum position passed to ``get_rope``.
            head_dim: Per-head dimension; when falsy it is derived as
                ``hidden_size // num_heads``.
            rms_norm_eps: Epsilon of the Q/K RMSNorm layers.
            qkv_bias: Whether the fused QKV projection has a bias.
            rope_theta: RoPE base passed to ``get_rope``.
            rope_scaling: Forwarded unchanged to ``get_rope``.
            sliding_window: Stored on the module; not read anywhere in this
                class.
        """
        super().__init__()
        world_size = dist.get_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        if head_dim:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        During a prefill step with compression enabled, the pre-RoPE and
        post-RoPE compression hooks are invoked and the resulting ``scores``
        are handed to the attention kernel; otherwise ``scores`` is ``None``.
        """
        ctx = get_context()
        compress = ctx.is_prefill and ctx.do_compression

        fused_qkv = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))

        # ``v`` is passed to the pre-RoPE hook as the flat split and only
        # reshaped per head afterwards.
        scores = None
        if compress:
            scores = apply_prerope_compression(q, k, v, ctx)

        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)

        if compress:
            scores = apply_postrope_compression(q, k, v, scores, ctx)

        attn_out = self.attn(q, k, v, scores)
        return self.o_proj(attn_out.flatten(1, -1))
+
+
class Qwen3MoeMLP(nn.Module):
    """Dense (non-expert) gated feed-forward block used by non-MoE layers."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        """Create the bias-free projections.

        Args:
            hidden_size: Input and output width.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
        """
        super().__init__()
        gate_and_up_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_and_up_sizes,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then the down projection."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
class Qwen3MoeTritonSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block backed by fused kernels.

    A replicated linear gate produces per-expert logits, ``routing`` turns
    them into routing data, and the fused gate/up and down expert projections
    are applied with that routing.
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        num_experts_per_tok: int,
        norm_topk_prob: bool,
        hidden_act: str,
    ) -> None:
        """Create the gate and the fused expert projections.

        Args:
            num_experts: Number of experts.
            hidden_size: Input and output width.
            intermediate_size: Per-expert width of the gate and up projections.
            num_experts_per_tok: Experts selected per token (passed to
                ``routing``).
            norm_topk_prob: Stored on the module; not read in ``forward``.
            hidden_act: Accepted but not inspected here; ``SiluAndMul`` is
                always used.
        """
        super().__init__()
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.hidden_size = hidden_size
        self.moe_intermediate_size = intermediate_size

        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False)
        gate_and_up_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelTritonFusedMoeLinear(
            hidden_size, gate_and_up_sizes, num_experts
        )
        self.down_proj = RowParallelTritonFusedMoeLinear(
            intermediate_size, hidden_size, num_experts
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and return the combined expert output.

        An empty input is returned unchanged without touching any kernel.
        """
        if hidden_states.numel() == 0:
            return hidden_states

        router_logits = self.gate(hidden_states)
        routing_data, gather_indx, scatter_indx = routing(
            router_logits,
            self.num_experts_per_tok,
            simulated_ep=1,  # single device, replicated experts
        )
        expert_in = self.gate_up_proj(
            hidden_states, routing_data=routing_data, gather_indx=gather_indx
        )
        activated = self.act_fn(expert_in)
        return self.down_proj(
            activated,
            routing_data=routing_data,
            scatter_indx=scatter_indx,
            gammas=routing_data.gate_scal,
        )
+
+
class Qwen3MoeBlock(Qwen3MoeTritonSparseMoeBlock):
    """MoE block used by the decoder layer; adds nothing to its base class."""
+
+
class Qwen3MoeRMSNorm(RMSNorm):
    """RMSNorm under a Qwen3-MoE specific name; adds nothing to ``RMSNorm``."""
+
+
class Qwen3MoeDecoderLayer(nn.Module):
    """One Qwen3-MoE layer with explicit residual additions.

    Unlike the dense Llama/Qwen3 layers in this package, the residual is added
    here with plain ``+`` and the norms are called with a single argument.
    """

    def __init__(
        self,
        config: Qwen3MoeConfig,
        layer_idx: int,
    ) -> None:
        """Build attention, the (sparse or dense) MLP and both norms.

        The sparse MoE block is used when ``layer_idx`` is not listed in
        ``config.mlp_only_layers``, ``config.num_experts > 0`` and
        ``(layer_idx + 1)`` is a multiple of ``config.decoder_sparse_step``;
        otherwise a dense ``Qwen3MoeMLP`` is used.
        """
        super().__init__()
        hidden = config.hidden_size
        eps = config.rms_norm_eps

        self.self_attn = Qwen3MoeAttention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            head_dim=getattr(config, "head_dim", None),
            rms_norm_eps=eps,
            qkv_bias=getattr(config, "attention_bias", False),
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            sliding_window=config.sliding_window,
        )

        use_sparse_moe = (
            layer_idx not in config.mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3MoeBlock(
                num_experts=config.num_experts,
                hidden_size=hidden,
                intermediate_size=config.moe_intermediate_size,
                num_experts_per_tok=config.num_experts_per_tok,
                norm_topk_prob=config.norm_topk_prob,
                hidden_act=config.hidden_act,
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=hidden,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
            )

        self.input_layernorm = Qwen3MoeRMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(hidden, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the hidden states after attention and MLP sub-blocks.

        Note the argument order: ``hidden_states`` first, then ``positions``.
        """
        # Attention sub-block with residual connection.
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(positions, attn_input)
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with residual connection.
        mlp_input = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(mlp_input)
        return hidden_states + mlp_output
+
+
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE backbone: token embedding, decoder stack and final norm."""

    def __init__(
        self,
        config: Qwen3MoeConfig,
    ) -> None:
        """Create one decoder layer per index in ``config.num_hidden_layers``."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_stack = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_stack.append(Qwen3MoeDecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids``, apply every layer, and return the normed output."""
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return self.norm(hidden_states)
+
+
class Qwen3MoeForCausalLM(nn.Module):
    """Qwen3-MoE backbone plus LM head, with a safetensors checkpoint loader."""

    # Checkpoint sub-name -> (fused parameter sub-name, shard id). A checkpoint
    # weight whose name contains the key is loaded into the fused parameter,
    # with the shard id forwarded to that parameter's ``weight_loader``.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3MoeConfig,
    ) -> None:
        """Build the model; share embedding data with the head when tied."""
        super().__init__()
        self.model = Qwen3MoeModel(config)
        self.num_experts = config.num_experts
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (no logits)."""
        return self.model(input_ids, position_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states through the LM head."""
        return self.lm_head(hidden_states)

    @staticmethod
    def _default_weight_loader(param, loaded_weight) -> None:
        """Fallback for parameters without a ``weight_loader``: async copy."""
        param.data.copy_(loaded_weight, non_blocking=True)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Tensors are opened directly on ``cuda:{rank}``, with ``rank`` taken
        from ``dist.get_rank()``. Per-expert weights
        (``<mlp>.experts.<idx>.<proj>...``) are mapped onto the fused
        parameter ``<mlp>.<proj>...`` and the expert index is forwarded to
        the parameter's ``weight_loader``.

        Args:
            path: Directory containing the checkpoint shards.
            use_tqdm: Show a progress bar over the shards.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.safetensors`` file.
                Previously this case returned silently and left the model with
                uninitialised weights.
            AttributeError: if a checkpoint name has no matching parameter
                (raised by ``get_parameter``) or a fused parameter has no
                ``weight_loader``.
        """
        # glob returns files in arbitrary order; sort for a reproducible load.
        shard_paths = sorted(glob(os.path.join(path, "*.safetensors")))
        if not shard_paths:
            raise FileNotFoundError(f"No *.safetensors shards found in {path!r}")

        device = f"cuda:{dist.get_rank()}"
        shard_iter = (
            tqdm.tqdm(shard_paths, desc="Loading model") if use_tqdm else shard_paths
        )
        for shard_path in shard_iter:
            with safe_open(shard_path, "pt", device) as shard:
                for weight_name in shard.keys():
                    weight_tensor = shard.get_tensor(weight_name)
                    is_expert = "mlp.experts" in weight_name
                    expert_idx = None

                    if is_expert:
                        # "<mlp>.experts.<idx>.<proj>..." -> "<mlp>.<proj>...".
                        # Split off only the leading index so the remainder of
                        # the name is never rewritten by accident.
                        mlp_module_name, expert_module_name = weight_name.split(
                            ".experts."
                        )
                        expert_token, proj_name = expert_module_name.split(".", 1)
                        expert_idx = int(expert_token)
                        weight_name = f"{mlp_module_name}.{proj_name}"

                    # Fused (packed) modules: first mapping key contained in
                    # the (rewritten) name wins.
                    for hf_name, (fused_name, shard_id) in (
                        self.packed_modules_mapping.items()
                    ):
                        if hf_name not in weight_name:
                            continue
                        param = self.get_parameter(
                            weight_name.replace(hf_name, fused_name)
                        )
                        weight_loader = getattr(param, "weight_loader")
                        if is_expert:
                            weight_loader(param, weight_tensor, expert_idx, shard_id)
                        else:
                            weight_loader(param, weight_tensor, shard_id)
                        break
                    else:
                        # Everything else maps one-to-one onto a parameter.
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param, "weight_loader", self._default_weight_loader
                        )
                        if is_expert:
                            weight_loader(param, weight_tensor, expert_idx)
                        else:
                            weight_loader(param, weight_tensor)
diff --git a/vllm/kvprune/triton_kernels/__init__.py b/vllm/kvprune/triton_kernels/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d446c6d0a024357f4659d254e327eb4a00e23a1
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Triton kernel utilities (matmul_ogs, MoE, topk, …) plus KV-facing entrypoints.
+
+For KV pruning attention/store, see also ``vllm.kvprune.attention`` and
+``vllm.kvprune.kv_cache``.
+"""
+
+from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
+from vllm.kvprune.kv_cache.store_kv_cache import (
+ decode_store_kv,
+ prefill_store_all_kv,
+ prefill_store_topk_kv,
+)
+
# Public names re-exported by this package: exactly the four functions
# imported above from the attention and kv_cache subpackages.
__all__ = [
    "causal_sparse_varlen_with_cache",
    "decode_store_kv",
    "prefill_store_all_kv",
    "prefill_store_topk_kv",
]
diff --git a/vllm/kvprune/triton_kernels/compaction.py b/vllm/kvprune/triton_kernels/compaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d471befd0d710f96f01882fa9e8b8a84059bd9
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/compaction.py
@@ -0,0 +1,76 @@
+import torch
+from .compaction_details._masked_compaction import _masked_compaction
+from .tensor import Bitmatrix
+
+
def compaction(yv, yi, bitmask, sentinel=-1):
    """
    Return compacted copies of *yv* and *yi* based on a per-row bitmask.

    An element is kept when bit ``yi % 32`` of word ``yi // 32`` in its row of
    *bitmask* is set. Kept elements are moved to the front of the row in
    their original left-to-right order; every other position of the outputs
    holds *sentinel*.

    Parameters
    ----------
    yv : torch.Tensor, shape (B, K)
        Values tensor.
    yi : torch.Tensor, shape (B, K)
        Integer indices associated with *yv*; each selects one bit of the
        row's bitmask.
    bitmask : torch.Tensor of shape (B,) or (B, W), or Bitmatrix
        Per-row packed bitmask, 32 bits per word. A 1-D tensor is treated as
        one word per row. A ``Bitmatrix`` is unwrapped to its storage data.
    sentinel : int, default -1
        Value written into dropped positions of the returned tensors.

    Returns
    -------
    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
        New contiguous tensors with the same dtype/device as the inputs.

    Notes
    -----
    ``K`` is passed to the kernel as a compile-time constant and used in
    ``tl.arange(0, K)``.
    """
    n_rows, n_cols = yi.shape
    if isinstance(bitmask, Bitmatrix):
        bitmask = bitmask.storage.data
    if bitmask.ndim == 1:
        # One 32-bit word per row: give the kernel the column stride it reads.
        bitmask = bitmask.unsqueeze(1)

    # The kernel addresses rows as ``row * K + col``, so inputs and outputs
    # must be row-major contiguous. ``contiguous()`` is a no-op when they
    # already are.
    yv = yv.contiguous()
    yi = yi.contiguous()
    ret_yv = torch.empty_like(yv)
    ret_yi = torch.empty_like(yi)

    _masked_compaction[(n_rows,)](
        yv,
        yi,
        bitmask,
        bitmask.stride(0),
        bitmask.stride(1),  # inputs
        ret_yv,
        ret_yi,  # outputs
        sentinel,  # sentinel
        K=n_cols,  # constants
    )
    return ret_yv, ret_yi
+
+
def compaction_torch(
    yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1
):
    """
    Pure-PyTorch reference implementation of `compaction`.

    An element is kept when bit ``yi`` of its row's packed bitmask is set
    (32 bits per word, words concatenated along the last axis). Kept
    elements are moved to the front of each row in their original order and
    all remaining positions are filled with *sentinel*.

    Parameters
    ----------
    yv : torch.Tensor, shape (B, K)
        Values tensor.
    yi : torch.Tensor, shape (B, K)
        Integer bit indices associated with *yv*.
    bitmask : torch.Tensor, shape (B,) or (B, W), integer dtype
        Per-row packed bitmask. A 1-D tensor is treated as one word per row.
    sentinel : int, default -1
        Value written into dropped positions.

    Returns
    -------
    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
    """
    if bitmask.ndim == 1:
        # One word per row -> (B, 1) so the bit expansion below yields (B, 32).
        bitmask = bitmask.unsqueeze(-1)
    # Expand the packed words into a boolean matrix of active bits (B, W * 32).
    bit_weights = 1 << torch.arange(32, device=yi.device, dtype=bitmask.dtype)
    bits = (bitmask.unsqueeze(-1) & bit_weights) != 0
    mask = bits.flatten(start_dim=-2)
    # For every yi element decide whether it should be kept.
    keep = mask.gather(1, yi.long())
    # Stable sort on "dropped" (kept == 0, dropped == 1) brings kept items to
    # the front without changing their relative order.
    order = (~keep).to(torch.int).argsort(dim=1, stable=True)
    yi_sorted = yi.gather(1, order)
    yv_sorted = yv.gather(1, order)
    # Overwrite the dropped tail with the sentinel.
    dropped_sorted = ~keep.gather(1, order)
    yi_sorted[dropped_sorted] = sentinel
    yv_sorted[dropped_sorted] = sentinel
    return yv_sorted, yi_sorted
diff --git a/vllm/kvprune/triton_kernels/compaction_details/__init__.py b/vllm/kvprune/triton_kernels/compaction_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/compaction_details/_masked_compaction.py b/vllm/kvprune/triton_kernels/compaction_details/_masked_compaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..58fe2412cf19386dbbe73bea1a5daf75d464ffb2
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/compaction_details/_masked_compaction.py
@@ -0,0 +1,22 @@
+import triton
+import triton.language as tl
+
+
@triton.jit
def _masked_compaction(
    Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr
):
    # One program per row. Rows of Yv / Yi / RetYv / RetYi are addressed as
    # ``row * K + col``.
    row = tl.program_id(0)
    cols = tl.arange(0, K)
    row_base = row * K
    vals = tl.load(Yv + row_base + cols)
    idxs = tl.load(Yi + row_base + cols)

    # Each index selects bit ``idx % 32`` of word ``idx // 32`` in this row of
    # the bitmask.
    word = idxs // 32
    bit = idxs % 32
    is_set = (tl.load(BitMask + row * stride_bm + word * stride_bn) >> bit) & 1

    # Kept elements go to slot "number of kept elements before me".
    kept_before = tl.cumsum(is_set, 0) - is_set
    keep = is_set.to(tl.int1)
    # Dropped elements get ``kept_before + (K - 1 - col)``, which equals
    # ``K - 1 - (number of dropped elements before me)``: they fill the row
    # from the back, so no two lanes write the same slot.
    tail_offset = tl.where(keep, 0, K - 1 - cols)
    dest = kept_before + tail_offset

    vals = tl.where(keep, vals, sentinel)
    idxs = tl.where(keep, idxs, sentinel)
    tl.store(RetYv + row_base + dest, vals)
    tl.store(RetYi + row_base + dest, idxs)
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs.py b/vllm/kvprune/triton_kernels/matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..284c9fb4359b1e05f2fdcfdde7b24fe9cb20e303
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs.py
@@ -0,0 +1,609 @@
+# isort: off
+# fmt: off
+from dataclasses import dataclass
+import itertools
+import sys
+import torch
+import triton
+from enum import Enum, auto
+import math
+# utilities
+from vllm.kvprune.triton_kernels import target_info
+from vllm.kvprune.triton_kernels.numerics import InFlexData, OutFlexData
+from vllm.kvprune.triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from vllm.kvprune.triton_kernels.target_info import is_cuda
+# details
+from .matmul_ogs_details._matmul_ogs import _matmul_ogs
+from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
+from .matmul_ogs_details._reduce_grouped import _reduce_grouped
+from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
+from .specialize import specialize
+from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
+
+
+ @dataclass(frozen=True)
+ class FnSpecs:
+ # Immutable description of a function that `get_kernels` injects into the
+ # specialized kernels (as ACTIVATION_FN or EPILOGUE_FN).
+ # `name` is one half of the kernel-cache key built in `get_kernels`
+ name: str
+ # the function itself; `None` for the default spec
+ fn: "triton.runtime.jit.JITFunction"
+ # names of the extra kernel arguments consumed by `fn`
+ fn_arg_names: tuple[str]
+ # argument names passed on as `do_not_specialize` when specializing the kernels
+ fn_arg_do_not_specialize: tuple[str] = tuple()
+
+ @staticmethod
+ def default():
+ # spec with no function attached; named "dflt"
+ return FnSpecs("dflt", None, tuple())
+
+
+ @dataclass(frozen=True)
+ class FusedActivation:
+ # Activation fused into either the matmul kernel or the grouped-reduction kernel.
+ # function spec; `get_kernels` binds `specs.fn` as ACTIVATION_FN
+ specs: FnSpecs = FnSpecs.default()
+ # extra positional arguments forwarded to the kernel launch (`*fn_args`)
+ fn_args: tuple[object] = tuple()
+ # the output's last dimension is N // reduction_n (see `init_allocation`)
+ reduction_n: int = 1
+
+
+ @dataclass(frozen=True)
+ class Epilogue:
+ # Epilogue attached to the kernels; `get_kernels` binds `specs.fn` as EPILOGUE_FN.
+ specs: FnSpecs = FnSpecs.default()
+ # extra positional arguments appended to the matmul kernel launch
+ fn_arg_values_matmul: tuple[object] = tuple()
+ # extra positional arguments appended to the grouped-reduction kernel launch
+ fn_arg_values_finalize: tuple[object] = tuple()
+ # forwarded unchanged to `make_opt_flags`; its meaning is defined there
+ effective_itemsize: float = None
+
+ class FnName(Enum):
+ # Well-known epilogue names. `matmul_ogs` compares `epilogue.specs.name` with
+ # `QUANTIZE_MXFP8.name` to set the IS_EPILOGUE_QUANT_MXFP8 kernel flag.
+ QUANTIZE_MXFP8 = auto()
+
+
+EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated
+
+_kernels = dict()
+
+
+ def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
+     """Return the module holding the kernels specialized for this spec pair.
+
+     Modules are cached in `_kernels` under `(fused_activation.name, epilogue.name)`.
+     On a cache miss a fresh module is created, registered in `sys.modules`, and
+     populated with specialized `_matmul_ogs`, `_p_matmul_ogs` and
+     `_reduce_grouped` kernels.
+     """
+     key = (fused_activation.name, epilogue.name)
+     try:
+         return _kernels[key]
+     except KeyError:
+         pass
+     import types
+
+     spec_constants = {"ACTIVATION_FN": fused_activation.fn, "EPILOGUE_FN": epilogue.fn}
+     spec_tuples = {
+         "activation_fn_args": fused_activation.fn_arg_names,
+         "epilogue_fn_args": epilogue.fn_arg_names,
+     }
+     do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
+     module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
+     # the module is registered before specialization, exactly as before
+     sys.modules[module.__name__] = module
+     templates = (
+         ("_matmul_ogs", _matmul_ogs),
+         ("_p_matmul_ogs", _p_matmul_ogs),
+         ("_reduce_grouped", _reduce_grouped),
+     )
+     for attr_name, template in templates:
+         specialized = specialize(template, module, spec_constants, spec_tuples,
+                                  do_not_specialize=do_not_specialize)
+         setattr(module, attr_name, specialized)
+     _kernels[key] = module
+     return module
+
+
+# -----------------------------------------------------------------------------
+# Matrix Multiplication + Outer Gather/Scatter
+# -----------------------------------------------------------------------------
+
+
+ def can_overflow_int32(tensor: torch.Tensor):
+     """Return True if addressing `tensor` can exceed the int32 range.
+
+     The largest element offset is reached when every dimension is at its
+     last index, i.e. sum((size - 1) * stride) over all dimensions. Offsets
+     are counted in elements, not bytes.
+     """
+     max_int32 = (1 << 31) - 1
+     max_offset = sum((size - 1) * stride for size, stride in zip(tensor.shape, tensor.stride()))
+     return max_offset > max_int32
+
+
+ def should_upcast_indices(*args):
+     """Return True if any non-None tensor in `args` can overflow int32 offsets."""
+     for tensor in args:
+         if tensor is not None and can_overflow_int32(tensor):
+             return True
+     return False
+
+
+# ---------------------
+# Numerics
+# ---------------------
+
+# fmt: off
+
+ @dataclass(frozen=True)
+ class FlexCtx:
+ # Groups the flex-data descriptors for the two matmul operands and the output.
+ # applied to `x` (left-hand side) in `matmul_ogs`
+ lhs_data: InFlexData = InFlexData()
+ # applied to `w` (right-hand side) in `matmul_ogs`
+ rhs_data: InFlexData = InFlexData()
+ # applied to the output in `matmul_ogs`
+ out_data: OutFlexData = OutFlexData()
+
+ @dataclass
+ class PrecisionConfig:
+ # Numerics options for `matmul_ogs`. Mutable: `matmul_ogs` writes `out_scale` back.
+ # passed straight through to the matmul kernel launch
+ max_num_imprecise_acc: int = None
+ # passed straight through to the matmul kernel launch
+ allow_tf32: bool = True
+ # flex-data descriptors for lhs / rhs / output
+ flex_ctx: FlexCtx = FlexCtx()
+ # NOTE(review): annotated `int` but the default is the float 1.0
+ acc_scale: int = 1.0
+ # forwarded to both the matmul and the grouped-reduction kernels
+ flexpoint_saturate_inf: bool = False
+ report_quantization_err_fn: callable = None
+ # mx scale of the activations `x`; when set, `x` must be row-major
+ act_scale: Tensor | None = None
+ # mx scale of the weights `w`
+ weight_scale: Tensor| None = None
+ # mx scale of the output; replaced by `matmul_ogs` with the final scale when one is produced
+ out_scale: Tensor | None = None
+ # output dtype; `matmul_ogs` falls back to `x.dtype` when unset
+ out_dtype: torch.dtype = None
+ enforce_bitwise_invariance: bool = False
+
+
+# TODO: merge in opt_flags
+ def get_swap_xw(precision_config, opt_flags):
+     """Decide the SWAP_XW kernel flag.
+
+     Only considered on CUDA capability >= 10.0, and then only when a weight
+     scale is present, `block_m <= 64`, and the persistent kernel is used.
+     """
+     if not target_info.cuda_capability_geq(10, 0):
+         return False
+     has_weight_scale = precision_config.weight_scale is not None
+     return has_weight_scale and opt_flags.block_m <= 64 and opt_flags.is_persistent
+
+# ---------------------
+# Allocation
+# ---------------------
+
+ @dataclass
+ class MatmulAllocation:
+ # Allocation plan produced by `init_allocation` and materialized by `apply_allocation`.
+ # device on which the tensors are created
+ device: str
+ # (shape, dtype) of the final output tensor
+ output: tuple[tuple[int], torch.dtype]
+ # scratchpad name -> (shape, dtype)
+ scratchpads: dict[str, tuple]
+
+ def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags):
+ # Computes the shapes/dtypes of the output and of any scratchpad buffers, without
+ # allocating anything. Returns a `MatmulAllocation`.
+ # ---- output ------
+ N = w.shape[-1]
+ # by default - M is number of rows in the activations
+ M = x.shape[-2]
+ # if the activations are gathered, then M is number of gather indices
+ if gather_indx is not None:
+ M = gather_indx.src_indx.shape[0]
+ # final output
+ if routing_data.n_expts_act == 1 or scatter_indx is None:
+ y_rows = M
+ else:
+ Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
+ y_rows = Mc
+ batch_dim = x.shape[0] if x.ndim == 3 else 1
+ # a fused activation shrinks the last dimension by `reduction_n`
+ out_shape = (batch_dim, y_rows, N // fused_activation.reduction_n)
+ out_dtype = precision_config.out_dtype or x.dtype
+ output = (out_shape, out_dtype)
+ # ---- scratchpad -----#
+ scratchpad = dict()
+ # an intermediate buffer is needed for split-k or for a scatter that is not fused
+ if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
+ # split-k partial results are kept in float32
+ scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
+ scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype)
+ # one uint8 scale per MXFP_BLOCK_SIZE elements along N
+ if "matmul" in scratchpad and precision_config.out_scale is not None:
+ scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
+ return MatmulAllocation(x.device, output, scratchpad)
+
+ def apply_allocation(allocation: MatmulAllocation, output):
+ # Materializes an allocation plan. Returns a dict with:
+ #   "output":     the output tensor with an extra leading dimension of size 1
+ #   "scratchpad": dict of freshly allocated (uninitialized) scratch tensors
+ # A caller-provided `output` is reused after its shape is checked against the plan.
+ ret = dict()
+ if output is None:
+ output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
+ else:
+ assert output.shape == allocation.output[0]
+ ret["output"] = output[None, :, :]
+ ret["scratchpad"] = {
+ k: torch.empty(v[0], device=allocation.device, dtype=v[1])
+ for k, v in allocation.scratchpads.items()
+ }
+ return ret
+
+# -----------------------------------------------------------------------------
+# Canonicalize
+# -----------------------------------------------------------------------------
+# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
+# we can canonicalize storages to make the implementation more uniform
+
+ def _canonicalize_storage(storage, out_ndim, flex_data):
+ # Returns a new Storage whose data has `out_ndim` dimensions, obtained by
+ # prepending size-1 dimensions; the original strides of the existing dimensions
+ # are preserved. When `flex_data` is given, the data is passed through
+ # `flex_data.reinterpret`. The layout is carried over unchanged.
+ assert out_ndim >= storage.data.ndim
+ # Need to use as_strided instead of view because a tensor with
+ # shape[-2] == 1 can be ambiguous as to whether it is col-wise. For example,
+ # > t = torch.randn(2, 5, 1).mT
+ # > t_view = t.view(t.shape)
+ # > t.stride(), t_view.stride()
+ # ((5, 1, 1), (5, 5, 1))
+ # Our check t_view is col-wise fails since t_view.stride(-2) != 1
+ # This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
+ new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
+ # the view is only used to obtain a stride for the prepended dimensions
+ new_storage_view = storage.data.view(new_storage_shape)
+ new_storage_stride = [new_storage_view.stride(0)] * (out_ndim - storage.data.ndim) + list(storage.data.stride())
+ new_storage_data = storage.data.as_strided(new_storage_shape, new_storage_stride)
+ if flex_data is not None:
+ new_storage_data = flex_data.reinterpret(new_storage_data)
+ return Storage(new_storage_data, storage.layout)
+
+#
+
+ def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor,
+ fused_activation, epilogue,
+ x_flex: InFlexData | None = None,
+ out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None,
+ out_dtype: bool = None, flexpoint_saturate_inf: bool = False):
+ """
+ In-place grouped row reduction.
+
+ Arguments
+ - x: Tensor[AnyFloat] of shape [(num_groups * K), N]
+ - indx: Tensor[Int] of shape [num_groups, K]
+
+ NOTE(review): the launch below reads `x.stride(0)`, `x.stride(2)` and
+ `x.stride(3)`, so `x` must have at least four dimensions when the kernel is
+ launched; the 2-D shape above looks outdated -- confirm against the kernel.
+ NOTE(review): `out_dtype` is annotated `bool` but is used as a dtype
+ (it defaults to `x.dtype`).
+
+ Description
+ For each group g in [0, num_groups), this routine sums the K rows of `x`
+ specified by `indx[g, :]` and overwrites the row corresponding to the first
+ valid (non-negative) index with the per-group sum. Accumulation is performed
+ in float32 for numerical stability, and the result is written back in the
+ dtype of `x`.
+
+ Behavior and edge cases
+ - Invalid (-1) entries are skipped during accumulation and do not generate
+ memory traffic. If a group has no valid entries, nothing is written for
+ that group.
+ - Reduction is performed tile-by-tile along the N dimension within a single
+ kernel launch (persistent along N) to minimize launch overhead.
+ - If `indx` is None and `x.shape[0] == 1`, no kernel is launched.
+
+ Performance notes
+ - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x),
+ plus index reads. With no invalid entries, this becomes (K + 1) reads/writes
+ of length N per group.
+
+ Returns
+ - A 2-tuple. On the no-launch path: `(x.squeeze(0), None)`. Otherwise:
+ `(out, out_mx_scale)`, i.e. the output arguments that were passed in.
+ """
+ # fast path: nothing to reduce
+ if indx is None and x.shape[0] == 1:
+ return x.squeeze(0), None
+ # one program instance per group
+ if indx is not None:
+ num_groups = indx.shape[0]
+ else:
+ num_groups = x.shape[-2]
+ if x_flex is None:
+ x_flex = InFlexData()
+ if out_flex is None:
+ out_flex = OutFlexData()
+ K = 1 if indx is None else indx.shape[1]
+ out_dtype = x.dtype if out_dtype is None else out_dtype
+ assert x.shape[-1] % fused_activation.reduction_n == 0
+ BLOCK_N = 512
+ # Resolve scalar flex scales (may be None)
+ x_expected_scale = None if x_flex is None else x_flex.scale
+ out_expected_scale = None if out_flex is None else out_flex.expected_scale
+ out_actual_scale = None if out_flex is None else out_flex.actual_scale
+ out_checksum_scale = None if out_flex is None else out_flex.checksum_scale
+ # Resolve MXFP output scale row stride
+ stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0)
+ stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1)
+ stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0)
+ kernels = get_kernels(epilogue.specs, fused_activation.specs)
+ kernels._reduce_grouped[(num_groups, )](
+ x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), #
+ x_expected_scale, # scalar input scale
+ out_flex.reinterpret(out), out.stride(1), out.stride(2), #
+ out_expected_scale, out_actual_scale, out_checksum_scale, indx, #
+ x.shape[0], x.shape[-1], #
+ x_mx_scale, stride_mxb, stride_mxs, #
+ out_mx_scale, stride_omxs, #
+ *fused_activation.fn_args, fused_activation.reduction_n,
+ *epilogue.fn_arg_values_finalize,
+ HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None,
+ FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf, #
+ BLOCK_N=BLOCK_N, K=K, #
+ num_warps=1, #
+ )
+ return out, out_mx_scale
+
+# -----------------------------------------------------------------------------
+# Triton Implementation
+# -----------------------------------------------------------------------------
+
+ def matmul_ogs_set_idle_sms(num_idle_sms):
+     """
+     Record the `idle_sms` opt-flag constraint: persistent kernels will leave
+     `num_idle_sms` SMs idle.
+     """
+     constraints = {"idle_sms": num_idle_sms}
+     update_opt_flags_constraints(constraints)
+
+ def matmul_ogs(x, w, bias,
+ routing_data: RoutingData | None = None,
+ gather_indx: GatherIndx | None = None,
+ scatter_indx: ScatterIndx | None = None,
+ precision_config: PrecisionConfig | None = None,
+ betas: torch.Tensor | None = None,
+ gammas: torch.Tensor | None = None,
+ out_alpha: float | None = None,
+ y: torch.Tensor | None = None,
+ fused_activation: FusedActivation | None = None,
+ epilogue: Epilogue | None = None,
+ ):
+ """
+ Y[:, :] = 0.
+ for e in num_experts:
+ Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
+
+ Launches the (persistent or non-persistent) matmul kernel and then
+ `reduce_grouped` as a post-processing step. When `y` is given it is used as
+ the output buffer (its shape is asserted in `apply_allocation`).
+
+ Side effect: if the reduction step returns an output mx scale,
+ `precision_config.out_scale` is overwritten with it.
+
+ Returns the output tensor; for non-batched (2-D) `x` the leading batch
+ dimension is squeezed away.
+ """
+ is_input_batched = x.ndim == 3
+ if is_input_batched:
+ assert gather_indx is None, "gather not supported in batched mode"
+ assert scatter_indx is None, "scatter not supported in batched mode"
+ assert routing_data is None, "routing not supported in batched mode"
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
+ # canonicalize inputs
+ if precision_config is None:
+ precision_config = PrecisionConfig()
+ if fused_activation is None:
+ fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
+ if epilogue is None:
+ epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
+ if routing_data is None:
+ routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
+ # unpack scales
+ w_scale = precision_config.weight_scale
+ w_has_mx = w_scale is not None
+ is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
+ if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
+ if not isinstance(w, Tensor):
+ # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
+ dtype = FP4 if w.dtype == torch.uint8 else w.dtype
+ w = wrap_torch_tensor(w, dtype=dtype)
+ if w_scale is not None and not isinstance(w_scale, Tensor):
+ w_scale = Tensor(w_scale)
+ if w_scale is not None:
+ # the weight scale is reinterpreted as raw uint8
+ w_scale.storage.data = w_scale.data.view(torch.uint8)
+ w_scale.dtype = torch.uint8
+ x_scale = precision_config.act_scale
+ x_has_mx = x_scale is not None
+ if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
+ if x_scale is not None and not isinstance(x_scale, Tensor):
+ x_scale = Tensor(x_scale)
+ if not isinstance(x, Tensor):
+ x = Tensor(x, dtype=x.dtype)
+ # determine shapes
+ has_gather = gather_indx is not None
+ has_scatter = scatter_indx is not None
+ is_ragged = routing_data.expt_hist is not None
+ M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
+ batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
+ K, N = w.shape[-2:]
+ assert K == x.shape[-1]
+ if x.ndim == 3 and w.ndim == 3:
+ assert x.shape[0] == w.shape[0]
+ # compute optimization flags
+ out_dtype = precision_config.out_dtype or x.dtype
+ can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
+ w.numel() > 0 and w.storage.is_tma_compliant() and \
+ (w_scale is None or w_scale.storage.is_tma_compliant())
+ # hopper w/ mxfp4 doesn't support TMA
+ # NOTE(review): this queries the current CUDA device via torch.cuda, not `x.device`
+ can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+ can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
+ opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
+ M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+ )
+ if not can_use_fused_scatter and opt_flags.fused_scatter:
+ raise InapplicableConstraint("Fused scatter is not supported")
+ if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
+ raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
+ if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
+ raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
+ # fused activation
+ matmul_fused_activation = fused_activation
+ reduce_fused_activation = FusedActivation()
+ # when the matmul writes to a scratchpad (split-k or non-fused scatter), the
+ # activation is moved from the matmul kernel to the reduction kernel
+ if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
+ matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation
+ # allocate output/scratchpad memory
+ allocation = init_allocation(x, w, precision_config, fused_activation,
+ routing_data, gather_indx, scatter_indx, opt_flags)
+ memory = apply_allocation(allocation, y)
+ # early exit
+ if batch_size * M * N == 0:
+ ret = memory["output"].squeeze(0)
+ if not is_input_batched:
+ ret = ret.squeeze(0)
+ return ret
+ # TMA descriptors require a global memory allocation
+ if opt_flags.is_persistent:
+ triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
+ # Intermediate tensors and postprocess kernels for each situation
+ has_scratchpad = "matmul" in memory["scratchpad"]
+ # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
+ out_matmul = memory["scratchpad"].get("matmul", memory["output"])
+ out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data
+ # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
+ out_matmul_scale = precision_config.out_scale
+ if out_matmul_scale is not None:
+ out_matmul_scale = out_matmul_scale.data.view(torch.uint8)
+ if has_scratchpad and "mx_out_scale" in memory["scratchpad"]:
+ out_matmul_scale = memory["scratchpad"]["mx_out_scale"]
+ # an mx output scale is only used when the matmul output has 1-byte elements
+ out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1
+ # matrix multiplication
+ flex = precision_config.flex_ctx
+ bias_stride = None if bias is None else bias.stride(0)
+ num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
+ # moe metadata
+ expt_data = routing_data.expt_data
+ block_m = opt_flags.block_m
+ expt_hist = None if expt_data is None else expt_data.hist
+ expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
+ expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
+ expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
+ # spmd grid
+ grid_m = triton.cdiv(M, opt_flags.block_m)
+ if expt_block_pid_map is not None:
+ grid_m = routing_data.n_blocks(M, opt_flags.block_m)
+ grid_n = triton.cdiv(N, opt_flags.block_n)
+ max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
+ # persistent kernels launch at most one program per non-idle SM
+ grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
+ # canonicalize storage
+ has_gather_tma = has_gather and target_info.has_tma_gather()
+ has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
+ y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
+ x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
+ w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
+ y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
+ # create tma descriptor for x
+ x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
+ x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
+ x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
+ x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
+ # create tma descriptor for y
+ y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter)
+ block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n
+ y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
+ y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
+ y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
+ # create tma descriptor for w
+ w_has_tma = opt_flags.is_persistent
+ w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
+ # create tma descriptor for w_scale
+ w_scale_tensor_or_tma = w_scale
+ w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
+ w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
+ # canonicalize strides
+ # each stride tuple is left-padded with zeros to the length the kernel expects
+ x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
+ x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
+ x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
+ w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
+ w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
+ out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
+ out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
+ # launch kernel
+ kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
+ # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
+ # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
+ # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
+ # w_transpose = w_storage.data.stride()[-1] != 1
+ w_transpose = w_storage.data.stride()[-2] == 1
+ (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
+ y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
+ *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
+ *out_matmul_scale_strides[-3:],
+ x_tensor_or_tma, x_storage.data, *x_strides,
+ flex.lhs_data.scale,
+ None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
+ w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
+ flex.rhs_data.scale,
+ w_scale_tensor_or_tma, *w_scale_strides,
+ bias, bias_stride,
+ x.shape[-2],
+ x.shape[-2] if routing_data.expt_hist is None else None,
+ N, K,
+ betas, gammas,
+ None if gather_indx is None else gather_indx.src_indx,
+ None if scatter_indx is None else scatter_indx.src_indx,
+ num_indx,
+ None if not opt_flags.fused_scatter else scatter_indx.dst_indx,
+ None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0],
+ expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
+ batch_size, grid_m, grid_n,
+ out_alpha,
+ *matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n,
+ *epilogue.fn_arg_values_matmul,
+ routing_data.n_expts_tot, routing_data.n_expts_act,
+ precision_config.max_num_imprecise_acc,
+ precision_config.allow_tf32,
+ precision_config.flexpoint_saturate_inf,
+ flex.rhs_data.is_per_batch,
+ opt_flags.block_m,
+ opt_flags.block_n,
+ opt_flags.block_k,
+ opt_flags.group_m,
+ XCD_SWIZZLE=opt_flags.xcd_swizzle,
+ SWIZZLE_MX_VALUE=w.storage.layout.name,
+ SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
+ EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
+ SPLIT_K=opt_flags.split_k,
+ EVEN_K=K % opt_flags.block_k == 0,
+ W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
+ num_warps=opt_flags.num_warps,
+ num_stages=opt_flags.num_stages,
+ arch=opt_flags.arch,
+ UPCAST_INDICES=should_upcast_indices(x, w, out_matmul),
+ X_TMA_MODE=x_tma_mode,
+ Y_TMA_MODE=y_tma_mode,
+ SWAP_XW=get_swap_xw(precision_config, opt_flags),
+ IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name,
+ NUM_SMS = grid if opt_flags.is_persistent else 0,
+ **opt_flags.target_kernel_kwargs)
+ # Build grouped reduction inputs in a uniform way
+ # one row of `group_indx` per output row, `n_expts_act` source rows each
+ group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act)
+ out_final, out_final_mx_scale = reduce_grouped(
+ out_matmul,
+ group_indx,
+ memory["output"].squeeze(0),
+ precision_config.out_scale,
+ reduce_fused_activation,
+ epilogue,
+ x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale),
+ out_flex=precision_config.flex_ctx.out_data,
+ x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None,
+ out_dtype=memory["output"].dtype,
+ flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
+ )
+ if not is_input_batched:
+ out_final = out_final.squeeze(0)
+ # write the final output scale back into the caller's precision config
+ if out_final_mx_scale is not None:
+ precision_config.out_scale = out_final_mx_scale
+ return out_final
+
+# -----------------------------------------------------------------------------
+# Reference Implementation
+# -----------------------------------------------------------------------------
+
+ def matmul_ogs_torch(x, w, bias,
+                      routing_data: RoutingData = None,
+                      gather_indx: GatherIndx = None,
+                      scatter_indx: ScatterIndx = None,
+                      precision_config: PrecisionConfig = None,
+                      betas = None,
+                      gammas = None,
+                      round_x = None, round_y = None,
+                      ):
+     """Pure-PyTorch reference for `matmul_ogs`.
+
+     Arguments
+     - x: activations, 2-D `(M, K)` or batched 3-D `(B, M, K)`; itemsize must be > 1.
+     - w: weights, 2-D `(K, N)` or 3-D `(E, K, N)`; itemsize must be > 1.
+     - bias: optional, 1-D or one row per expert.
+     - routing_data / gather_indx / scatter_indx: optional routing; not
+       allowed in batched mode.
+     - precision_config: accepted for signature parity; not used here.
+     - betas / gammas: optional per-row scales for the bias / the output.
+     - round_x: optional `f(rows, row_indices) -> rows`, applied to the gathered
+       activations before the matmul.
+     - round_y: optional `f(out) -> out`, applied before the result is stored.
+
+     Returns
+     - Without `scatter_indx`: a tensor with the dtype of `x`.
+     - With `scatter_indx`: a float32 tensor accumulating the rows of every
+       expert into their destination rows (entries mapped to -1 are dropped).
+     """
+     is_input_batched = x.ndim == 3
+     assert x.dtype.itemsize > 1
+     assert w.dtype.itemsize > 1
+     if is_input_batched:
+         assert gather_indx is None, "gather not supported in batched mode"
+         assert scatter_indx is None, "scatter not supported in batched mode"
+         assert routing_data is None, "routing not supported in batched mode"
+         assert w.ndim == 3 and w.shape[0] == x.shape[0]
+     if round_x is None:
+         round_x = lambda x, idx: x
+     if round_y is None:
+         round_y = lambda x: x
+     # promote everything to the (expert/batch, rows, cols) form
+     if bias is not None and bias.ndim == 1:
+         bias = bias.view(1, *bias.shape)
+     if w.ndim == 2:
+         w = w.view(1, *w.shape)
+     if x.ndim == 2:
+         x = x.view(1, *x.shape)
+     if routing_data is None:
+         routing_data = RoutingData(None, None, w.shape[0], 1)
+     n_expts_act = routing_data.n_expts_act
+     # memory offsets
+     if routing_data.n_expts_tot > 1 and not is_input_batched:
+         # each expert owns a contiguous [lo, hi) row range given by the histogram
+         sizes = routing_data.expt_hist
+         off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
+         off[1:] = torch.cumsum(sizes, 0)
+         offs = list(itertools.pairwise(off))
+     else:
+         offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
+     # compute
+     n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
+     y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
+     for i, (lo, hi) in enumerate(offs):
+         # The row ids handed to `round_x` live on the same device as `x`.
+         # This used to be hard-coded to "cuda", which failed on hosts without
+         # CUDA and mismatched `x` whenever it lived on any other device.
+         row_ids = torch.arange(lo, hi, device=x.device)
+         if gather_indx is None:
+             idx = row_ids
+         else:
+             idx = gather_indx.src_indx[lo:hi] // n_expts_act
+         batch = i if is_input_batched else 0
+         out = torch.matmul(round_x(x[batch, idx, :], row_ids).float(),
+                            w[i].float())
+         if bias is not None:
+             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
+         if gammas is not None:
+             out *= gammas[lo:hi, None]
+         y[batch, lo:hi, :] = round_y(out)
+     if not is_input_batched:
+         y = y.view(y.shape[1], y.shape[2])
+     if scatter_indx is None:
+         return y
+     # accumulate output from all experts
+     n_rows = y.shape[0] // n_expts_act
+     out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
+     for i, (lo, hi) in enumerate(offs):
+         dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
+         msk = dst_idx != -1
+         out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
+     return out
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/__init__.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/_common.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d5c99493872d779643aff2a9f7293685d8c4f2b
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/_common.py
@@ -0,0 +1,179 @@
+import torch
+
+import triton
+import triton.language as tl
+
+# -----------------------------------------------------------------------------
+# Utilities
+# -----------------------------------------------------------------------------
+
+
+ @triton.constexpr_function
+ def get_scaled_dot_format_string(dtype: tl.dtype):
+     """Return the short format name for `dtype`.
+
+     Raises KeyError for any dtype outside the table below.
+     """
+     return {
+         tl.float16: "fp16",
+         tl.bfloat16: "bf16",
+         tl.uint8: "e2m1",
+         tl.float8e4nv: "e4m3",
+         tl.float8e5: "e5m2",
+     }[dtype]
+
+
+ @triton.jit
+ def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
+ """
+ Swizzle the program id based on integer XCD_SWIZZLE.
+ This is useful for reordering how blocks are ordered. A scheduler may, for example,
+ assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
+ This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
+ becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
+ the same hardware unit.
+ """
+ # Number of pids per group in the new arrangement
+ pids_per_group = domain_size // XCD_SWIZZLE
+ # remainder: the first `extra_pid_groups` groups hold one extra pid each
+ extra_pid_groups = domain_size % XCD_SWIZZLE
+
+ # Compute the current group and the local pid within the group
+ group = pid % XCD_SWIZZLE
+ local_pid = pid // XCD_SWIZZLE
+
+ # Calculate new pid based on the new grouping
+ new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
+ return new_pid
+
+
+ @triton.jit
+ def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
+ # Maps a linear program id to (pid_m, pid_n) on a grid_m x grid_n grid.
+ # Programs are taken in groups of GROUP_M rows; inside a group, consecutive
+ # pids advance along m first and then along n.
+ # number of pids covered by one full group of rows
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ # the last group may have fewer than GROUP_M rows
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ tl.assume(group_size >= 0)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ return pid_m, pid_n
+
+
+ def make_matmul_repr(base_name, order):
+     """Build the `repr` callback that names a specialized matmul kernel.
+
+     The returned function maps a specialization (with `.signature` and
+     `.constants`) to `"{base_name}_{layouts}_{dtypes}_{blocks}"`, where the
+     Y/X/W entries are permuted by `order`.
+     """
+     tensor_args = ["Y", "X", "W"]
+     stride_args = ["stride_y_n", "stride_x_k", "stride_w_n"]
+     block_args = ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]
+
+     def _dtype_label(dtype):
+         # tensor descriptors wrap the element type: take the text after "<" up to "["
+         if "tensordesc" in dtype:
+             return _dtype_label(dtype.split("<")[1].split("[")[0])
+         if "u8" in dtype:
+             return "mxfp4"
+         # pointer types carry a leading "*"
+         if dtype[0] == "*":
+             return dtype[1:]
+         return dtype
+
+     def matmul_repr(specialization):
+         signature = specialization.signature
+         constants = specialization.constants
+         dtype_parts = [_dtype_label(f"{signature[tensor_args[i]]}") for i in order]
+         # a stride that was specialized to a constant is reported as "N", otherwise "T"
+         layout_parts = ["N" if stride_args[i] in constants else "T" for i in order]
+         block_parts = [f"{constants[name]}" for name in block_args]
+         dtypes = "x".join(dtype_parts)
+         layouts = "".join(layout_parts)
+         blocks = "x".join(block_parts)
+         return f"{base_name}_{layouts}_{dtypes}_{blocks}"
+
+     return matmul_repr
+
+
+ def matmul_launch_metadata(grid, kernel, args):
+ # Builds the launch-metadata dict for a matmul kernel launch. The result always
+ # has "name"; "flops<nbits>" and "bytes" are added unless an expert histogram is
+ # present and the token count cannot be determined.
+ # `grid` is accepted but not used here.
+ from ..proton_opts import launch_metadata_allow_sync
+
+ ret = dict()
+ M, N, K = args["M"], args["N"], args["K"]
+ Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"]
+ tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+ hist = args["ExptHist"]
+ if hist is not None:
+ # If annotation is given, use that to generate name for profiling.
+ if tokens_per_expt is not None:
+ n_rows = f"{tokens_per_expt}*"
+ elif launch_metadata_allow_sync():
+ n_rows = int(hist.float().mean())
+ else:
+ n_rows = "unknown"
+
+ # exact numbers are computed from `hist` only when
+ # launch_metadata_allow_sync() permits it
+ if launch_metadata_allow_sync():
+ n_tokens = float(hist.sum())
+ n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (
+ hist > 0
+ ).sum()
+ elif tokens_per_expt is not None:
+ n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+ # This may not be totally correct (e.g., we might not be using all experts)
+ # but it's better than nothing.
+ n_w_bytes = W.numel() * W.element_size()
+ else:
+ n_tokens = None
+ n_w_bytes = 0
+
+ # If annotation is given, use that to generate name for profiling.
+ tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+ n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
+ else:
+ n_tokens = None
+ n_w_bytes = W.numel() * W.element_size()
+ # formats one dimension; a None dimension is printed in the per-expert form
+ # NOTE(review): this local name shadows the builtin `repr`
+ repr = (
+ lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
+ )
+ nbits = X.dtype.itemsize * 8
+ batch_repr = ""
+ if "batch_size" in args and args["batch_size"] > 1:
+ batch_repr = repr("B", args["batch_size"]) + ", "
+ ret["name"] = (
+ f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
+ )
+ ep_subtile = args["EPILOGUE_SUBTILE"]
+ if ep_subtile is not None and ep_subtile > 1:
+ ret["name"] += f" ep/{ep_subtile}"
+
+ if hist is not None and n_tokens is None:
+ return ret # Don't fill metadata because we can't compute them properly.
+
+ # a None M or K is replaced by the token count
+ fM = M if M is not None else n_tokens
+ fK = K if K is not None else n_tokens
+ ret[f"flops{nbits}"] = 2.0 * fM * N * fK
+
+ gindx = args.get("GatherIndx", None)
+ # sindx = args.get("WriteBackIndx", None)
+ n_x_bytes = X.numel() * X.element_size()
+ n_y_bytes = Y.numel() * Y.element_size()
+ if hist is not None:
+ assert n_tokens is not None
+ n_expts_act = args["N_EXPTS_ACT"]
+
+ if (gindx is not None) and launch_metadata_allow_sync():
+ # recreate inverse GatherIndx.
+ dst = torch.full_like(gindx, -1)
+ idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
+ mask = gindx != -1
+ dst[gindx[mask]] = idx[mask]
+ # count groups of `n_expts_act` entries with at least one valid entry
+ n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
+ else:
+ n_read_rows = n_tokens
+ n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
+ n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
+ ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
+
+ return ret
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/_matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..45528d5ed7379be143e32329fc8b2009cd135dc5
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/_matmul_ogs.py
@@ -0,0 +1,429 @@
+# isort: off
+# fmt: off
+import triton
+import triton.language as tl
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
+from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+
+
+@triton.jit
+def _zero_masked_rows(
+ pid_m, pid_n,
+ Y, stride_y_m, stride_y_n,
+ N,
+ ScatterSrcIndx, num_idxs,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+ offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
+ offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
+ YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+ mask_n = offs_n < N
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
+ tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask)
+
+
+# Callback passed as `repr=` to triton.jit below; built by make_matmul_repr (from ._common).
+_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
+# _matmul_ogs: each program accumulates one (BLOCK_M, BLOCK_N) float32 tile of X @ W for
+# one expert / batch entry and one split-K slice, then applies the flexpoint scales, bias
+# (scaled by betas), out_alpha, the optional fused activation and gammas, and stores the
+# tile into Y -- optionally scattered through WriteBackIndx and/or post-processed by
+# EPILOGUE_FN.
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+def _matmul_ogs(
+ # output: pointer, strides, flexpoint scale pointers, strides of the output mx scales
+ Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
+ YExpectedScale, YActualScale, YChecksumScale,
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+ # activations: pointer, strides, flexpoint scale, optional microscaling (mx) scales
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+ XScale,
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
+ # weights: pointer, strides, flexpoint scale, optional microscaling (mx) scales
+ W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
+ WScale,
+ WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
+ B, stride_b_e, # Bias
+ NRows, M, N, K, # shapes
+ # expt data
+ Betas, Gammas,
+ GatherIndx,
+ ScatterSrcIndx, num_idxs,
+ WriteBackIndx, writeback_size,
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,
+ # true grid size
+ batch_size, grid_m, grid_n,
+ # Out scale
+ out_alpha,
+ # fused activation function
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
+ # epilogue transform
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
+ # MoE config
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+ # precision config
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
+ PER_BATCH_SCALE: tl.constexpr,
+ # optimization config
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+ # One of ["HOPPER_VALUE", None] (enforced by the static_asserts below)
+ SWIZZLE_MX_VALUE: tl.constexpr,
+ # One of ["BLACKWELL_SCALE", "HOPPER_SCALE", "CDNA4_SCALE", None]
+ SWIZZLE_MX_SCALE: tl.constexpr,
+ EPILOGUE_SUBTILE: tl.constexpr,
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
+ W_CACHE_MODIFIER: tl.constexpr,
+ NUM_SMS: tl.constexpr,
+ X_TMA_MODE: tl.constexpr,
+ Y_TMA_MODE: tl.constexpr,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
+ UPCAST_INDICES: tl.constexpr = False,
+ SWAP_XW: tl.constexpr = False,
+ IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
+ # NOTE(review): YPtr, XPtr, WPtr, N_EXPTS_TOT, NUM_SMS, X_TMA_MODE, Y_TMA_MODE and SWAP_XW
+ # are not referenced in this kernel body; presumably they keep the signature parallel
+ # to _p_matmul_ogs -- confirm against the launcher.
+
+ # Let the compiler assume strides and grid sizes are non-negative.
+ tl.assume(stride_y_k >= 0)
+ tl.assume(stride_y_z >= 0)
+ tl.assume(stride_y_m >= 0)
+ tl.assume(stride_y_n >= 0)
+ tl.assume(stride_x_z >= 0)
+ tl.assume(stride_x_m >= 0)
+ tl.assume(stride_x_k >= 0)
+ tl.assume(stride_w_e >= 0)
+ tl.assume(stride_w_k >= 0)
+ tl.assume(stride_w_n >= 0)
+ if stride_w_mx_e is not None:
+ tl.assume(stride_w_mx_e >= 0)
+ if stride_w_mx_k is not None:
+ tl.assume(stride_w_mx_k >= 0)
+ if stride_w_mx_n is not None:
+ tl.assume(stride_w_mx_n >= 0)
+ if B is not None:
+ tl.assume(stride_b_e >= 0)
+ tl.assume(batch_size >= 0)
+ tl.assume(grid_m >= 0)
+ tl.assume(grid_n >= 0)
+
+ # Compile-time validation of the microscaling configuration.
+ # Weights are microscaled when WMxScale is given; uint8 weights are treated as mxfp4.
+ is_w_microscaled: tl.constexpr = WMxScale is not None
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
+ if is_w_microscaled:
+ w_type: tl.constexpr = W.dtype.element_ty
+ is_mxfp4: tl.constexpr = w_type == tl.uint8
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
+ "mx_weight_ptr must be uint8 or fp8")
+ tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
+ else:
+ tl.static_assert(SWIZZLE_MX_VALUE is None)
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
+ # Microscaled activations require microscaled weights and float8e4nv inputs.
+ is_x_microscaled: tl.constexpr = XMxScale is not None
+ if is_x_microscaled:
+ x_type: tl.constexpr = X.dtype.element_ty
+ tl.static_assert(is_w_microscaled)
+ tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
+ tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ # The output is microscaled when strides for its mx scales are provided.
+ is_out_microscaled: tl.constexpr = stride_y_mx_z is not None
+
+ # The fused activation divides the N extent by ACTIVATION_REDUCTION_N, so the output
+ # tile width and output N are scaled down accordingly.
+ OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
+ yN = N // ACTIVATION_REDUCTION_N
+
+ pid = tl.program_id(0)
+ if ExptOffsSum is not None and XCD_SWIZZLE > 1:
+ # Determine how much padding there is on the expert data. This allows us to
+ # know the true grid size and avoid processing padding tiles.
+ padding_m = grid_m - tl.load(ExptOffsSum)
+ else:
+ padding_m: tl.constexpr = 0
+
+ HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
+ # Index arithmetic is done in int64 only when UPCAST_INDICES is set.
+ index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
+
+ unpadded_m = grid_m - padding_m
+ tl.assume(unpadded_m >= 0)
+ total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K
+ # Programs past the real tiles (only reachable when padding_m > 0) do no matmul work:
+ # at most they zero masked rows of a padding tile, then exit.
+ if padding_m > 0 and pid >= total_actual_tiles:
+ # NOTE(review): asserts batch_size == 0 on this path -- confirm the intended invariant.
+ tl.device_assert(batch_size == 0)
+ pid_mn = pid - total_actual_tiles
+ if pid_mn < padding_m * grid_n:
+ pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
+
+ # set masked out rows to 0
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+ _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
+ return
+
+ # swizzle program ids
+ pid_emnk = pid
+ if XCD_SWIZZLE != 1:
+ pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
+ # Decompose the linear id into (expert/batch, m-n tile, split-K slice).
+ pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K)
+ pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M)
+ # For split-k, advance to the output k slice
+ if SPLIT_K > 1:
+ Y += pid_k.to( index_type) * stride_y_k
+ if is_out_microscaled:
+ # NOTE(review): the Y scale pointer is advanced with stride_x_mx_k (an X-scale
+ # stride); the signature has no stride_y_mx_k -- confirm this is intended.
+ YActualScale += pid_k.to(index_type) * stride_x_mx_k
+ # set masked out rows to 0
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+ _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
+ # unpack expert data
+ if ExptData is None:
+ # No per-tile expert table: pid_e serves as both the expert id and the z (batch)
+ # index, rows start at 0 and M must be passed in.
+ tl.static_assert(M is not None)
+ expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m
+ else:
+ # ExptData[pid_m] packs the expert id in the low 16 bits and the block id in the
+ # remaining high bits; -1 means there is nothing to compute for this tile.
+ # M and start_m are then read from ExptHist / ExptOffs for that expert.
+ tl.static_assert(M is None)
+ expt_data = tl.load(ExptData + pid_m)
+ if expt_data == -1:
+ return
+ expt_id = expt_data & 0x0000FFFF
+ block_id = expt_data >> 16
+ M = tl.load(ExptHist + expt_id)
+ start_m = tl.load(ExptOffs + expt_id)
+ start_z = 0
+ expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
+ start_m, start_z = start_m.to(index_type), start_z.to(index_type)
+ pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
+ # A pointers
+ offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
+ # Row offsets wrap modulo M so the loads below stay within the M rows of this tile's
+ # expert/batch; rows >= M are masked out at write-back via mask_m.
+ offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
+ X += start_z * stride_x_z
+ if GatherIndx is None:
+ X += start_m * stride_x_m
+ else:
+ GatherIndx += start_m
+ # no need to bounds-check here because `offs_x_m` wraps around M dim
+ offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
+ offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
+ XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
+
+ # TODO: refactor if/else when triton front end improves
+ # Derive the packed weight-tile shape and the mx-scale tile layout from the
+ # weight dtype and the swizzle modes.
+ if is_w_microscaled:
+ if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+ tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
+ tl.static_assert(not is_x_microscaled)
+ # We pack 2 fp4 values in a byte but we divide the dimension by 2
+ # when swizzling
+ W_K_DIVISOR: tl.constexpr = 1
+ W_K_MULTIPLIER: tl.constexpr = 2
+ W_N_DIVISOR: tl.constexpr = 4
+ else:
+ # We pack 2 fp4 values in a byte
+ W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
+ W_K_MULTIPLIER: tl.constexpr = 1
+ W_N_DIVISOR: tl.constexpr = 1
+
+ if W_TRANSPOSE:
+ # When weight is transposed, 2 fp4 values are packed per Byte along
+ # the contiguous dimension, K.
+ PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
+ else:
+ # When weight is not transposed, fp4 values are *not* packed along
+ # the contiguous dimension, N.
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
+ # Number of mx scale entries covering one BLOCK_K slice.
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+
+ WMxScale += expt_id * stride_w_mx_e
+
+ # Tile shape (SCALE_BLOCK_N x PACKED_MX_BLOCK) and K stride of the scale tile
+ # for each supported scale layout.
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ # TODO: support non W_TRANSPOSE with blackwell swizzling
+ tl.static_assert(W_TRANSPOSE)
+ tl.static_assert(BLOCK_N % 128 == 0)
+ tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
+ PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
+ stride_scale_k: tl.constexpr = 1
+ elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+ # TODO: support non W_TRANSPOSE with Hopper swizzling
+ tl.static_assert(W_TRANSPOSE)
+ n_warps: tl.constexpr = tl.extra.cuda.num_warps()
+ tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
+ tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
+ stride_scale_k = stride_w_mx_k
+ elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+ tl.static_assert(stride_w_mx_k is not None)
+ tl.static_assert(stride_w_mx_n is not None)
+ NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
+ stride_scale_k = stride_w_mx_k
+ else:
+ # Unswizzled scales: one scale entry per (n, MX block along K).
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N
+ stride_scale_k = stride_w_mx_k
+ # Scale column offsets wrap modulo N.
+ offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
+ offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
+ # K dimension must be the last dimension for the scales
+ offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
+ WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
+ else:
+ # Plain (non-microscaled) weights: no scale pointers, no packing.
+ WMxScalePtrs = None
+ offs_k_scale = None
+ W_K_DIVISOR: tl.constexpr = 1
+ W_K_MULTIPLIER: tl.constexpr = 1
+ W_N_DIVISOR: tl.constexpr = 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
+
+ # B pointers
+ offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
+ # Weight column offsets wrap modulo (N // W_N_DIVISOR); columns >= N are masked at
+ # write-back via mask_n.
+ offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
+
+ # Pointers to the activation mx scales; they reuse the (possibly gathered) row offsets.
+ if is_x_microscaled:
+ XMxScale += start_z.to(index_type) * stride_x_mx_z
+ if GatherIndx is None:
+ XMxScale += start_m * stride_x_mx_m
+ offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
+ XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
+ else:
+ XMxScalePtrs = None
+
+ offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
+ W += expt_id * stride_w_e
+ WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
+ # compute output
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ # `k` counts down from K in steps of BLOCK_K * SPLIT_K. The pointers (not the offsets)
+ # are advanced each iteration, so the masks compare the fixed offsets against the
+ # shrinking `k`.
+ for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):
+ if EVEN_K:
+ # K divides evenly: all-true masks.
+ mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
+ mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
+ if is_w_microscaled and SWIZZLE_MX_SCALE is None:
+ mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
+ if is_x_microscaled:
+ mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
+ else:
+ mask_k = offs_k < k
+ # The weight mask is expressed in packed-K units.
+ mask_k_w = offs_w_k < ((k // (W_K_DIVISOR if W_TRANSPOSE else 1)) * W_K_MULTIPLIER)
+ if is_w_microscaled and SWIZZLE_MX_SCALE is None:
+ mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
+ if is_x_microscaled:
+ mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k
+
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+ w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
+ if is_w_microscaled:
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
+ w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+
+ # Activation scales: loaded when X is microscaled, None for fp16/bf16 inputs,
+ # otherwise a constant scale.
+ if is_x_microscaled:
+ x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
+ elif x_format == "fp16" or x_format == "bf16":
+ x_scales: tl.constexpr = None
+ else:
+ # Scale of 1 in E8M0 format
+ x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
+
+ # Weight scales: swizzled layouts are loaded unmasked and unswizzled by the
+ # matching helper; only the unswizzled layout uses the K mask.
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
+ elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+ # Handshake with the swizzling code
+ num_warps: tl.constexpr = tl.extra.cuda.num_warps()
+ w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
+ elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+ w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K)
+ else:
+ w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])
+
+ if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+ # Handshake with the swizzling code
+ # Weights are converted to bf16 first and the product is computed in
+ # transposed form (w @ x^T), then transposed back into acc.
+ tl.static_assert(x_format == "bf16")
+ tl.static_assert(w_format == "e2m1")
+ w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
+ tl.static_assert(w.dtype == tl.bfloat16)
+ acc = acc.trans()
+ x = x.trans()
+ # w = w.trans()
+ acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ acc = acc.trans()
+ else:
+ rhs_k_pack: tl.constexpr = W_TRANSPOSE or not is_w_microscaled or W_K_DIVISOR != 2
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True, rhs_k_pack=rhs_k_pack)
+ # Advance the scale pointers to the next K step of this split-K slice.
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
+ else:
+ WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
+ if is_x_microscaled:
+ XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
+ else:
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ # Advance the activation and weight pointers to the next K step.
+ XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
+ WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
+ # bias + scale
+ offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
+ offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
+ mask_m = offs_m < M
+ mask_n = offs_y_n < N
+ # Only the pid_k == 0 split-K slice loads the bias; the other slices use zeros.
+ if B is not None:
+ BPtrs = B + expt_id * stride_b_e + offs_y_n
+ if pid_k == 0:
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ # Per-row multipliers: betas scale the bias, gammas scale the final output.
+ # Both default to 1 when not provided.
+ if Betas is not None:
+ betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
+ else:
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ if Gammas is not None:
+ gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
+ else:
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ # flexpoint
+ x_scale = load_scale(XScale)
+ # With PER_BATCH_SCALE the weight scale is indexed by expert id.
+ if PER_BATCH_SCALE:
+ w_scale = load_scale(WScale + expt_id)
+ else:
+ w_scale = load_scale(WScale)
+ acc *= x_scale * w_scale
+ acc = acc + bias[None, :] * betas[:, None]
+ if out_alpha is not None:
+ acc *= out_alpha
+ if ACTIVATION_FN is not None:
+ # The activation must produce OUT_BLOCK_N columns; the column offsets and mask
+ # are recomputed for the reduced output width.
+ out = ACTIVATION_FN(acc, *activation_fn_args)
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
+ offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
+ mask_n = offs_y_n < yN
+ else:
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
+ out = acc
+ out *= gammas[:, None]
+ # write-back
+ Y += start_z.to(index_type) * stride_y_z
+ if WriteBackIndx is not None:
+ # Fused scatter: destination rows come from WriteBackIndx; entries equal to -1
+ # (or beyond writeback_size, which load as -1) are not written.
+ WriteBackIndx += start_m
+ dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
+ mask_m = mask_m & (dst_idx != -1)
+ offs_y_m = dst_idx
+ else:
+ Y += start_m * stride_y_m
+ offs_y_m = offs_m
+
+ YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
+ mask = mask_m[:, None] & mask_n[None, :]
+ if is_out_microscaled:
+ # Microscaled output: EPILOGUE_FN returns the quantized tile together with its
+ # scales, and the scales are stored through YActualScale.
+ MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
+ N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
+ tl.static_assert(EPILOGUE_FN is not None)
+ out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
+ tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
+ offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
+ mask_n_scale = offs_y_n_scale < N_MX_BLOCK
+ YActualScale += start_z.to(index_type) * stride_y_mx_z
+ if WriteBackIndx is None:
+ YActualScale += start_m * stride_y_mx_m
+ YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+ else:
+ # NOTE(review): with a fused scatter the scale rows are offset by -NRows; the
+ # reason is not visible in this file -- confirm against the launcher.
+ YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+ tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
+ else:
+ # Flexpoint output conversion via float_to_flex.
+ out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
+ # A non-MXFP8 epilogue function is applied to the tile with the output element type.
+ if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
+ tl.store(YPtrs, out, mask=mask)
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dcaf2d88970e301c797268cbb7a7ba2ceb6521d
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
@@ -0,0 +1,471 @@
+# isort: off
+# fmt: off
+import torch
+import triton
+import triton.language as tl
+from triton.tools.ragged_tma import load_ragged, store_ragged
+from vllm.kvprune.triton_kernels import target_info
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
+ float_to_flex,
+ load_scale,
+ nan_propagating_absmax_reduce,
+ compute_scale,
+)
+from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+
+
+@triton.constexpr_function
+def cuda_capability_geq(major, minor):
+ return target_info.cuda_capability_geq(major, minor)
+
+@triton.constexpr_function
+def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
+ if isinstance(tensor_or_desc, tl.tensor):
+ return tensor_or_desc.dtype.element_ty
+ elif isinstance(tensor_or_desc, tl.tensor_descriptor):
+ return tensor_or_desc.dtype
+ else:
+ raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
+
+@triton.jit
+def _load_tile_attrs(
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
+ # unpack and swizzle program ids
+ pid_emnk = tile_id
+ if XCD_SWIZZLE != 1:
+ pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
+ pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
+ pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
+ if SPLIT_K > 1:
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ else:
+ pid_k: tl.constexpr = 0
+ pid_mn = pid_mnk
+ pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
+
+ # unpack expert data
+ if ExptData is None:
+ tl.static_assert(M is not None)
+ expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
+ else:
+ tl.static_assert(M is None)
+ expt_data = tl.load(ExptData + pid_m)
+ expt_id = expt_data & 0x0000FFFF
+ block_id = expt_data >> 16
+ eM = tl.load(ExptHist + expt_id)
+ start_m = tl.load(ExptOffs + expt_id)
+ start_z = 0
+
+ off_m = BLOCK_M * block_id
+ off_n = BLOCK_N * pid_n
+
+ return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
+
+@triton.jit
+def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
+ mask = mask & (offs < writeback_size)
+ offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1)
+ mask = offs != -1
+ return (offs, mask)
+
+
+_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+def _p_matmul_ogs(
+ Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
+ YExpectedScale, YActualScale, YChecksumScale,
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+ XScale,
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
+ W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
+ WScale,
+ MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
+ B, stride_b_e, # Bias
+ NRows, M, N, K, # shapes
+ # expt data
+ Betas, Gammas,
+ GatherIndx,
+ ScatterSrcIndx, num_idxs,
+ WriteBackIndx, writeback_size,
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,
+ # true grid size
+ batch_size, grid_m, grid_n,
+ # Out scale
+ out_alpha,
+ # fused activation function
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
+ # epilogue transform
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
+ # MoE config
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+ # precision config
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
+ PER_BATCH_SCALE: tl.constexpr,
+ # optimization config
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+ # NYI: Must be None
+ SWIZZLE_MX_VALUE: tl.constexpr,
+ # One of ["BLACKWELL", None]
+ SWIZZLE_MX_SCALE: tl.constexpr,
+ EPILOGUE_SUBTILE: tl.constexpr,
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
+ W_CACHE_MODIFIER: tl.constexpr,
+ NUM_SMS: tl.constexpr,
+ X_TMA_MODE: tl.constexpr,
+ Y_TMA_MODE: tl.constexpr,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
+ UPCAST_INDICES:tl.constexpr=False,
+ SWAP_XW: tl.constexpr = False,
+ IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
+ # tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
+
+ # why is this faster than using host-side tensor descriptor?!
+ if Y_TMA_MODE is not None:
+ Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)
+
+ is_microscaled_format: tl.constexpr = MxScale is not None
+ tl.static_assert(not is_microscaled_format or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
+ if is_microscaled_format:
+ w_type: tl.constexpr = get_dtype(W)
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
+ "mx_weight_ptr must be uint8")
+ tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
+
+ # We have pack 2 fp4 values in a byte
+ W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+ else:
+ W_PACK_DIVISOR: tl.constexpr = 1
+ MX_SCALE_BLOCK_K: tl.constexpr = 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
+
+ if ExptOffsSum is not None:
+ # Determine how much padding there is on the expert data. This allows us to
+ # know the true grid size and avoid processing padding tiles.
+ padding_m = grid_m - tl.load(ExptOffsSum)
+ else:
+ padding_m: tl.constexpr = 0
+
+ index_type: tl.constexpr = tl.int64
+
+ USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
+ HAS_SCATTER: tl.constexpr = WriteBackIndx is not None
+ HAS_GATHER: tl.constexpr = GatherIndx is not None
+ USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense"
+ USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense"
+
+ if EPILOGUE_SUBTILE is None:
+ SUBTILE_FACTOR: tl.constexpr = 1
+ else:
+ SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
+ EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
+ OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
+ yN = N // ACTIVATION_REDUCTION_N
+
+ # set masked out rows to 0
+ if HAS_SCATTER and N_EXPTS_ACT == 1:
+ # Iterate with reversed pids so that later pids will get more tiles if the number of
+ # tiles isn't evenly divisible by the number of SMs.
+ # The main loop after this iterates in the forward direction such that earlier
+ # pids get more tiles if the number of tiles isn't evenly divisible.
+ # This helps balance the work across the SMs.
+ for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
+
+ z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
+ offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
+ offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
+ YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+ mask_n = offs_n < yN
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
+ tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
+
+
+ k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
+ num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
+
+ # If true, do not share loop-carried variables between the prologue and the
+ # epilogue to enable better pipelining with mmav5
+ INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
+
+ # start negative; will be incremented at the top of the loop
+ if INDEPENDENT_EPILOGUE:
+ tile_id1 = tl.program_id(0) - NUM_SMS
+
+ # Keep track of local max for updating flexpoint scales.
+ THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
+ local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
+
+ DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
+
+ for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True):
+ expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M, BLOCK_N, SPLIT_K,
+ GROUP_M, XCD_SWIZZLE)
+
+ # Base pointers and offsets.
+ if X_TMA_MODE is None:
+ XBase = X + start_z.to(index_type) * stride_x_z
+ offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
+ if SPLIT_K > 1:
+ offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
+
+ if USE_GATHER_TMA:
+ offs_m = off_m + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < (M if M is not None else eM)
+ if ExptData is None:
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
+ # Bump rows to account for the Z offset.
+ offs_x_m += start_z * (stride_x_z // stride_x_m)
+ offs_x_m = tl.where(mask_m, offs_x_m, -1)
+ else:
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
+ mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
+ elif X_TMA_MODE is None:
+ tl.static_assert(HAS_GATHER)
+ offs_m = off_m + tl.arange(0, BLOCK_M)
+ if M is not None:
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
+ else:
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
+ # no needs to bounds-check here because `offs_m` wraps around M dim
+ offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
+ offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
+
+
+ acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
+ off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
+ off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
+ off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
+
+ # --- load x ---
+ if USE_GATHER_TMA:
+ x = X.gather(offs_x_m, off_k)
+ elif X_TMA_MODE == "dense":
+ x = X.load([start_z, start_m + off_m, off_k])
+ x = x.reshape(BLOCK_M, BLOCK_K)
+ elif X_TMA_MODE == "ragged":
+ x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
+ x = x.reshape(BLOCK_M, BLOCK_K)
+ else:
+ tl.static_assert(X_TMA_MODE is None)
+ XPtrs = XBase + offs_x_m + offs_x_k
+ XBase += BLOCK_K * SPLIT_K * stride_x_k
+ mask_k = tl.arange(0, BLOCK_K) < K - off_k
+ if EVEN_K:
+ if SPLIT_K > 1:
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+ else:
+ x = tl.load(XPtrs)
+ else:
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+
+ # --- load w ---
+ if W_TRANSPOSE:
+ w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T
+ else:
+ w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])
+
+ # --- load w_scale ---
+ if is_microscaled_format:
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
+ mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+ if x_format == "fp16" or x_format == "bf16":
+ x_scales: tl.constexpr = None
+ else:
+ x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
+ w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
+ w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
+ w_scales = unswizzle_mx_scale_bw(w_scales)
+ else:
+ w_scales = MxScale.load([expt_id, off_k_mx, off_n])
+ w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
+
+ # --- update accumulator ---
+ if is_microscaled_format:
+ if SWAP_XW:
+ acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+ else:
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
+ else:
+ if SWAP_XW:
+ acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ else:
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+
+ if INDEPENDENT_EPILOGUE:
+ tile_id1 += NUM_SMS
+ expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
+ tile_id1, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M, BLOCK_N, SPLIT_K,
+ GROUP_M, XCD_SWIZZLE)
+ else:
+ tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
+ off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
+
+ offs_m = off_m1 + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < (M if M is not None else eM1)
+ if USE_SCATTER_TMA:
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
+ if SPLIT_K > 1:
+ # Compute the split k offset in number of rows, and add it to offs_y_m.
+ # This allows us to write to the correct slice in the output tensor while using
+ # a 2D TMA scatter.
+ tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
+ split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
+ offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
+ elif Y_TMA_MODE is None:
+ tl.static_assert(HAS_SCATTER)
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
+ else:
+ offs_y_m = start_m1 + offs_m
+ MASK_ACC = False if USE_GATHER_TMA else USE_FLEXPOINT_SCALE
+
+ # bias + scale
+ offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
+ mask_n = offs_y_n < N
+ if B is not None:
+ BPtrs = B + expt_id1 * stride_b_e + offs_y_n
+ if pid_k1 == 0:
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ if Betas is not None:
+ betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
+ else:
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ if Gammas is not None:
+ gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
+ else:
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ x_scale = load_scale(XScale)
+ if PER_BATCH_SCALE:
+ w_scale = load_scale(WScale + expt_id1)
+ else:
+ w_scale = load_scale(WScale)
+
+ accs = (acc,)
+ biases = (bias,)
+
+ if SUBTILE_FACTOR >= 2:
+ acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
+ accs = (acc0, acc1)
+ bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
+ biases = (bias0, bias1)
+
+ if SUBTILE_FACTOR >= 4:
+ acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+ acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+ accs = (acc00, acc01, acc10, acc11)
+ bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+ bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+ biases = (bias00, bias01, bias10, bias11)
+
+ tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
+ tl.static_assert(len(accs) == SUBTILE_FACTOR)
+
+ for a_i in tl.static_range(len(accs)):
+ acc_tile = accs[a_i]
+ acc_tile *= x_scale * w_scale
+
+ if SWAP_XW:
+ acc_tile = acc_tile.T
+
+ acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
+ if out_alpha is not None:
+ acc_tile *= out_alpha
+
+ if ACTIVATION_FN is not None:
+ out = ACTIVATION_FN(acc_tile, *activation_fn_args)
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
+ else:
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
+ out = acc_tile
+
+ out *= gammas[:, None]
+
+ if MASK_ACC:
+ out = tl.where(mask_m[:, None], out, 0.0)
+ # Flexpoint
+ out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
+ local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
+ out = float_to_flex(
+ out, YExpectedScale,
+ None, # ActualScale: local absmax is tracked and updated after the loop
+ YChecksumScale,
+ None, # mask: out is manually masked to 0
+ YPtr, FLEXPOINT_SATURATE_INF
+ )
+ if EPILOGUE_FN is not None:
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
+
+ out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
+ out = out.to(YPtr.dtype.element_ty)
+ if USE_SCATTER_TMA:
+ # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
+ # there shouldn't be any other negative values.
+ offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
+ Y.scatter(out, offs_y_m, out_off_n)
+ elif Y_TMA_MODE == "dense":
+ out = tl.reshape(out, [1] + out.shape)
+ off_kz = pid_k * batch_size + start_z1
+ Y.store([off_kz, off_m1, out_off_n], out)
+ elif Y_TMA_MODE == "ragged":
+ out = tl.reshape(out, [1] + out.shape)
+ store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1)
+ else:
+ tl.static_assert(Y_TMA_MODE is None)
+ offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
+ mask_n = offs_y_n < yN
+
+ YPtrs = YPtr + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
+ mask = mask_m[:, None] & mask_n[None, :]
+ tl.store(YPtrs, out, mask=mask)
+
+
+ # Update the flexpoint scales
+ if YActualScale is not None:
+ tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")
+
+
+_per_device_alloc_fns = {}
+def get_per_device_per_stream_alloc_fn(device):
+    """Return the allocator callback cached for ``device``, building it on first use.
+
+    The callback has the signature ``alloc_fn(size, alignment, stream)`` and
+    hands back an int8 tensor on ``device`` holding at least ``size``
+    elements. One tensor is kept per ``stream`` and reused until a larger
+    size is requested, at which point it is replaced.
+    """
+    cached_fn = _per_device_alloc_fns.get(device)
+    if cached_fn is not None:
+        return cached_fn
+
+    buffers_by_stream = {}
+
+    def alloc_fn(size: int, alignment: int, stream):
+        # Only 128-byte alignment requests are accepted.
+        assert alignment == 128
+        buffer = buffers_by_stream.get(stream)
+        if buffer is None or buffer.numel() < size:
+            buffer = torch.empty(size, device=device, dtype=torch.int8)
+            buffers_by_stream[stream] = buffer
+            # Marker attribute read by external tooling -- presumably tells a
+            # checkpoint/hibernate mechanism to skip this buffer; TODO confirm.
+            buffer.__hibernate__ = {"type": "ignore"}
+        return buffer
+
+    _per_device_alloc_fns[device] = alloc_fn
+    return alloc_fn
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/_reduce_grouped.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/_reduce_grouped.py
new file mode 100644
index 0000000000000000000000000000000000000000..2972a6240856eb11db1cf54b6adb734b7bc57b6e
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/_reduce_grouped.py
@@ -0,0 +1,126 @@
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
+ float_to_flex,
+ load_scale,
+)
+from vllm.kvprune.triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _reduce_grouped(
+    X,
+    stride_xb: tl.uint64,
+    stride_xm: tl.uint64,
+    stride_xn,  #
+    XScale,  # input scalar flex scale
+    Out,
+    stride_om: tl.uint64,
+    stride_on,  # output tensor
+    OutExpectedScale,
+    OutActualScale,
+    OutChecksumScale,  # output scalar flex scales
+    InIndx,
+    B,
+    N,  #
+    XMxScale,
+    stride_mxb: tl.uint64,
+    stride_mxs: tl.uint64,  # optional per-32-col input MXFP scales (uint8)
+    OutMxScale,
+    stride_omxs: tl.uint64,  # optional per-32-col output MXFP scales (uint8)
+    # fused activation function
+    ACTIVATION_FN: tl.constexpr,
+    activation_fn_args,
+    ACTIVATION_REDUCTION_N: tl.constexpr,
+    # epilogue transform
+    EPILOGUE_FN: tl.constexpr,
+    epilogue_fn_args,
+    #
+    HAS_IN_MX_SCALE: tl.constexpr,
+    HAS_OUT_MX_SCALE: tl.constexpr,
+    FLEXPOINT_SATURATE_INF: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Sum selected rows of ``X`` into one row of ``Out`` per program.
+
+    Program ``pid_t`` reads the ``K`` row indices ``InIndx[pid_t * K + i]``
+    (an index of -1 marks a row to skip) and, for each of them, adds up the
+    ``B`` slices found ``stride_xb`` apart. When ``InIndx`` is None the only
+    row is ``pid_t`` itself; the index tuple then has a single entry, so this
+    path assumes ``K == 1``. The columns are walked in tiles of ``BLOCK_N``.
+
+    Per tile and per selected row:
+      * values are loaded and promoted to float32;
+      * if ``HAS_IN_MX_SCALE``, each group of 32 values is multiplied by a
+        scale rebuilt from a uint8 shifted into the float32 exponent field;
+      * if ``ACTIVATION_FN`` is given it is applied to the summed row before
+        it is added to the tile accumulator.
+    The accumulator is then multiplied by the scalar loaded from ``XScale``,
+    optionally quantized with ``quantize_mxfp8_fn`` (``HAS_OUT_MX_SCALE``),
+    passed through ``float_to_flex`` and stored to row ``pid_t`` of ``Out``.
+
+    ``EPILOGUE_FN`` / ``epilogue_fn_args`` are accepted but not read here.
+    """
+    pid_t = tl.program_id(0)
+    BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
+    # persistent along N: single program on N, iterate tiles of size BLOCK_N
+    start = pid_t * K
+    # load indices into a tuple
+    if InIndx is None:
+        indxs = (pid_t,)
+    else:
+        indxs = ()
+        for i in tl.static_range(0, K):
+            indxs = indxs + (tl.load(InIndx + start + i),)
+    # determine first valid topk row
+    # NOTE(review): `fi` is computed but never read below; kept as-is.
+    fi = indxs[(K - 1)]
+    for i in tl.static_range(K - 2, -1, -1):
+        fi = tl.where(indxs[i] != -1, indxs[i], fi)
+    # column pointers for the first tile; advanced at the end of each iteration
+    XPtrs = X + tl.arange(0, BLOCK_N) * stride_xn
+    OutPtrs = Out + tl.arange(0, BLOCK_N_OUT) * stride_on
+    if HAS_IN_MX_SCALE:
+        XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn
+    if HAS_OUT_MX_SCALE:
+        OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on
+    x_scale = load_scale(XScale)
+    for n_curr in tl.range(0, N, BLOCK_N, num_stages=4):
+        acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
+        x_n_mask = tl.arange(0, BLOCK_N) < N - n_curr
+        x_n_mask_scale = tl.arange(0, BLOCK_N // 32) < tl.cdiv(N - n_curr, 32)
+        # accumulate contributions for this tile
+        for i in tl.static_range(0, K):
+            curr = tl.zeros([BLOCK_N], dtype=tl.float32)
+            # iterate over split_k partial values
+            for b in tl.range(0, B):
+                is_valid = indxs[i] != -1
+                x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
+                vals = tl.load(x_row_ptr, mask=x_n_mask & is_valid, other=0.0)
+                vals = vals.to(tl.float32)
+                if HAS_IN_MX_SCALE:
+                    scale_row_ptr = XScalePtrs + indxs[i] * stride_mxs + b * stride_mxb
+                    scale = tl.load(
+                        scale_row_ptr, mask=x_n_mask_scale & is_valid, other=0.0
+                    )
+                    # place the uint8 in the float32 exponent field -> power-of-two scale
+                    scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
+                    vals = vals.reshape([BLOCK_N // 32, 32])
+                    vals = (scale[:, None] * vals).reshape([BLOCK_N])
+                curr += vals
+            # apply nonlinearity to split-k output
+            if ACTIVATION_FN is not None:
+                curr = ACTIVATION_FN(curr[None, :], *activation_fn_args)
+            curr = tl.reshape(curr, [curr.shape[-1]])
+            # update final accumulator
+            acc += curr
+        acc *= x_scale
+        # Compute per-32-col MXFP scales for this tile if requested
+        Nrem = (N - n_curr) // ACTIVATION_REDUCTION_N
+        out_n_mask = tl.arange(0, BLOCK_N_OUT) < Nrem
+        out_n_mask_scale = tl.arange(0, BLOCK_N_OUT // 32) < tl.cdiv(Nrem, 32)
+        if HAS_OUT_MX_SCALE:
+            acc, acc_scale = quantize_mxfp8_fn(acc[None, :], out_n_mask[None, :])
+            acc = tl.reshape(acc, [acc.shape[-1]])
+            acc_scale = tl.reshape(acc_scale, [acc_scale.shape[-1]])
+        # Convert to flexpoint output if configured (scalar scales)
+        acc = float_to_flex(
+            acc,
+            OutExpectedScale,
+            OutActualScale,
+            OutChecksumScale,
+            None,
+            Out,
+            FLEXPOINT_SATURATE_INF,
+        )
+        # write-back for this tile
+        out_ptr = OutPtrs + pid_t * stride_om
+        tl.store(out_ptr, acc, mask=out_n_mask)
+        if HAS_OUT_MX_SCALE:
+            out_scale_ptr = OutScalePtrs + pid_t * stride_omxs
+            tl.store(out_scale_ptr, acc_scale, mask=out_n_mask_scale)
+        # advance every column pointer to the next tile
+        XPtrs += BLOCK_N * stride_xn
+        OutPtrs += BLOCK_N_OUT * stride_on
+        if HAS_IN_MX_SCALE:
+            XScalePtrs += BLOCK_N // 32 * stride_xn
+        if HAS_OUT_MX_SCALE:
+            # Step with the same stride the pointers were initialised with
+            # (stride_on); the input stride here only matched when both were equal.
+            OutScalePtrs += BLOCK_N_OUT // 32 * stride_on
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..68bfd0b75f35db339bb5f84386110dd1946dce24
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -0,0 +1,303 @@
+# isort: off
+# fmt: off
+from dataclasses import dataclass
+import triton
+from vllm.kvprune.triton_kernels.target_info import get_cdna_version
+import torch
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+
+
+@dataclass
+class OptFlags:
+    """Tuning flags selected for one matmul launch.
+
+    Instances are produced by the per-backend ``make_default_opt_flags_*``
+    builders in this module. Construction raises ``ValueError`` when
+    ``fused_scatter`` is requested together with ``split_k != 1``.
+    """
+    block_m: int
+    block_n: int
+    block_k: int
+    num_warps: int
+    num_stages: int
+    group_m: int
+    xcd_swizzle: int
+    # None when no cache modifier is applied (both builders may pass None).
+    w_cache_modifier: str | None
+    split_k: int
+    is_persistent: bool
+    fused_scatter: bool
+    idle_sms: int
+    epilogue_subtile: int | None
+    # None when no architecture override is set (both builders pass None).
+    arch: str | None
+    target_kernel_kwargs: dict
+
+    def __post_init__(self):
+        # The fused scatter path cannot be combined with a split-K launch.
+        if self.fused_scatter and self.split_k != 1:
+            raise ValueError("Not supported")
+
+
+def make_default_opt_flags_amd(
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ precision_config,
+ m,
+ n,
+ k,
+ routing_data,
+ can_use_persistent_tma,
+ can_use_fused_scatter,
+ enforce_bitwise_invariance,
+ epilogue_effective_itemsize,
+ constraints,
+):
+ """Build the default ``OptFlags`` used when the active backend is "hip".
+
+ Every choice starts from a heuristic on the problem size; entries in
+ ``constraints`` override the corresponding heuristic. The function asserts
+ that all constraint keys are supported and that every non-None constraint
+ value ended up in the returned flags.
+
+ ``out_dtype``, ``k``, ``can_use_persistent_tma``, ``can_use_fused_scatter``
+ and ``epilogue_effective_itemsize`` are accepted but not read here.
+ """
+ constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+ # Reject constraint keys this builder does not know how to honor.
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+ # tokens per expert
+ if routing_data is None:
+ tokens_per_expt = m
+ elif routing_data.expected_tokens_per_expt is None:
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+ else:
+ tokens_per_expt = routing_data.expected_tokens_per_expt
+
+ is_cdna4 = get_cdna_version() == 4
+ # block_m
+ # NOTE: truthiness test, so a constraint of 0 is treated like "not set".
+ if constraints.get("block_m", None):
+ block_m = constraints["block_m"]
+ elif enforce_bitwise_invariance:
+ block_m = 256 if is_cdna4 else 128
+ elif tokens_per_expt >= 512 and n >= 2048:
+ block_m = 256 if is_cdna4 else 128
+ elif is_cdna4 and m >= 512:
+ block_m = 128
+ else:
+ block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
+
+ if routing_data is not None:
+ grid_m = routing_data.n_blocks(m, block_m)
+ else:
+ grid_m = triton.cdiv(m, block_m)
+ # group_m:
+ group_m = 4
+ # number of xcds
+ num_xcds = 8
+ xcd_swizzle = num_xcds
+ # block_nk:
+ block_n, block_k = opt_flags_amd.compute_block_nk(
+ n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
+ )
+ # Replace block_k if provided in constraints.
+ # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
+ if constraints.get("block_k", None) is not None:
+ block_k = constraints["block_k"]
+ if constraints.get("block_n", None) is not None:
+ block_n = constraints["block_n"]
+ is_persistent = constraints.get("is_persistent", False)
+ # split_k:
+ if constraints.get("split_k", None) is not None:
+ split_k = constraints["split_k"]
+ elif is_persistent or enforce_bitwise_invariance:
+ split_k = 1
+ else:
+ # Split K by how many times the tile grid fits into the CU count.
+ grid_size = grid_m * ((n + block_n - 1) // block_n)
+ # NOTE(review): device index 0 is hard-coded; confirm this is intended
+ # when the current device is not device 0.
+ n_cu = torch.cuda.get_device_properties(0).multi_processor_count
+ split_k = max(1, n_cu // grid_size)
+ # w_cache_modifier:
+ w_cache_modifier = ".cg" if block_m <= 32 else None
+ # num_warps, num_stages
+ # NOTE(review): `m` was already used arithmetically above, so the
+ # `m is not None` guard here cannot be the first place a None `m` is seen.
+ num_warps = 2 if (m is not None and m <= 16) else 8
+ num_stages = 2
+ # AMD-specific
+ target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
+ epilogue_subtile = constraints.get('epilogue_subtile', None)
+ if epilogue_subtile is None:
+ epilogue_subtile = 1
+ ret = OptFlags(
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ group_m=group_m,
+ xcd_swizzle=xcd_swizzle,
+ w_cache_modifier=w_cache_modifier,
+ split_k=split_k,
+ is_persistent=is_persistent,
+ fused_scatter=constraints.get('fused_scatter', False),
+ idle_sms=0,
+ epilogue_subtile=epilogue_subtile,
+ arch=None,
+ target_kernel_kwargs=target_kernel_kwargs,
+ )
+ # check constraints
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+ return ret
+
+def make_default_opt_flags_nvidia(
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ precision_config,
+ m,
+ n,
+ k,
+ routing_data,
+ can_use_persistent_tma,
+ can_use_fused_scatter,
+ enforce_bitwise_invariance,
+ epilogue_effective_itemsize,
+ constraints,
+):
+ """Build the default ``OptFlags`` used when the active backend is "cuda".
+
+ Each flag is taken from ``constraints`` when present, otherwise from the
+ heuristics in ``opt_flags_nvidia``. The function asserts that all
+ constraint keys are supported and that every non-None constraint value
+ ended up in the returned flags.
+ """
+ # NOTE: unlike the AMD builder, "block_n" is not an accepted constraint here.
+ constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+ # tokens per expert
+ if routing_data is None:
+ tokens_per_expt = m
+ elif routing_data.expected_tokens_per_expt is None:
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+ else:
+ tokens_per_expt = routing_data.expected_tokens_per_expt
+ # pid swizzling
+ group_m = 8
+ xcd_swizzle = 1
+ # block_m
+ # NOTE: truthiness test, so a constraint of 0 is treated like "not set".
+ if constraints.get("block_m", None):
+ block_m = constraints["block_m"]
+ elif enforce_bitwise_invariance:
+ block_m = 128
+ else:
+ block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+ # block n
+ # `arch` is always None here, so the `int(arch[2:-1])` check below is
+ # short-circuited and never evaluated.
+ arch = None
+ block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
+ # is_persistent
+ grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
+ # NOTE(review): device index 0 is hard-coded; confirm this is intended
+ # when the current device is not device 0.
+ n_sms = torch.cuda.get_device_properties(0).multi_processor_count
+ tiles_per_sm = grid_size / n_sms
+ supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
+ if constraints.get("is_persistent", None) is not None:
+ is_persistent = constraints["is_persistent"]
+ else:
+ has_simple_epilogue = precision_config.max_num_imprecise_acc is None
+ is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
+ # TEMP CHANGE
+ if precision_config.act_scale is not None or precision_config.out_scale is not None:
+ is_persistent = False
+ # block k
+ if constraints.get("block_k", None) is not None:
+ block_k = constraints["block_k"]
+ else:
+ block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
+ # split_k
+ if constraints.get("split_k", None) is not None:
+ split_k = constraints["split_k"]
+ elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+ split_k = 1
+ else:
+ # routing_data is deliberately not passed: the grid is estimated from m alone.
+ estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
+ split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
+ if split_k > 1:
+ # With split_k, results are written in f32. Use that for the following computations.
+ out_dtype = torch.float32
+ compute_num_stages_args = (
+ precision_config,
+ is_persistent,
+
+ block_m,
+ block_n,
+ block_k,
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ )
+
+ if constraints.get("epilogue_subtile", None) is not None:
+ subtiles_to_check = [constraints["epilogue_subtile"]]
+ else:
+ subtiles_to_check = [1, 2, 4]
+ # Keep the first subtile factor that yields the largest stage count.
+ # `epilogue_subtile` is only bound inside this loop.
+ num_stages = -1
+ for ep in subtiles_to_check:
+ ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
+ if ns > num_stages:
+ epilogue_subtile, num_stages = ep, ns
+ assert num_stages >= 1
+ # NOTE: truthiness test, so a num_stages constraint of 0 is ignored.
+ if constraints.get("num_stages", None):
+ num_stages = constraints["num_stages"]
+ # fused scatter scratchpad
+ if constraints.get("fused_scatter", None) is not None:
+ fused_scatter = constraints["fused_scatter"]
+ else:
+ fused_scatter = can_use_fused_scatter and split_k == 1
+ # Handshake with the HBM swizzling
+ num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
+ ret = OptFlags(
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ fused_scatter=fused_scatter,
+ group_m=group_m,
+ xcd_swizzle=xcd_swizzle,
+ w_cache_modifier=None,
+ split_k=split_k,
+ is_persistent=is_persistent,
+ epilogue_subtile=epilogue_subtile,
+ arch=arch,
+ target_kernel_kwargs=dict(),
+ idle_sms=constraints.get("idle_sms", 0),
+ )
+ # check constraints
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+ return ret
+
+# --------------
+# User Interface
+# --------------
+
+# Constraints merged into every subsequent make_opt_flags() call.
+_opt_flags_constraints: dict = dict()
+# Manual override returned verbatim by make_opt_flags() when not None.
+_opt_flags: "OptFlags | None" = None
+
+def update_opt_flags_constraints(constraints: dict[str, int]):
+    """Merge ``constraints`` into the module-wide constraint set."""
+    # In-place update of the shared dict; the name is not rebound, so no
+    # ``global`` declaration is required.
+    _opt_flags_constraints.update(constraints)
+
+def reset_opt_flags_constraints():
+    """Drop every constraint previously registered."""
+    global _opt_flags_constraints
+    _opt_flags_constraints = dict()
+
+def set_opt_flags(opt_flags: "OptFlags | None"):
+    """Install ``opt_flags`` as a manual override, or clear it with ``None``.
+
+    Passing ``None`` removes the current override. This is the reset path the
+    "please reset to None first" assertion message refers to; previously it
+    tripped that same assertion and the override could not be cleared.
+
+    Raises:
+        AssertionError: when installing an override while constraints are
+            registered, or while another override is still installed.
+    """
+    global _opt_flags
+    if opt_flags is None:
+        _opt_flags = None
+        return
+    assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
+    assert not _opt_flags, "opt_flags already set; please reset to None first"
+    _opt_flags = opt_flags
+
+class InapplicableConstraint(Exception):
+    """Raised when a requested opt-flags constraint cannot be enforced."""
+
+def make_opt_flags(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    epilogue_effective_itemsize,
+):
+    """Return the ``OptFlags`` to use for one matmul launch.
+
+    Resolution order:
+      1. registered ``is_persistent`` / ``fused_scatter`` constraints that the
+         caller reports as unusable raise ``InapplicableConstraint``;
+      2. a manual override installed with ``set_opt_flags`` is returned as-is;
+      3. otherwise the builder matching the active Triton backend
+         ("hip" or "cuda") is called with the registered constraints.
+
+    Raises:
+        InapplicableConstraint: see step 1.
+        AssertionError: if an override and constraints are both set, or the
+            active backend is neither "hip" nor "cuda".
+    """
+    if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
+        raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
+    if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
+        raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+    enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
+    if _opt_flags is not None:
+        assert not _opt_flags_constraints
+        return _opt_flags
+    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
+            routing_data, can_use_persistent_tma, can_use_fused_scatter,
+            enforce_bitwise_invariance, epilogue_effective_itemsize,
+            _opt_flags_constraints]
+    backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "hip":
+        return make_default_opt_flags_amd(*args)
+    if backend == "cuda":
+        return make_default_opt_flags_nvidia(*args)
+    # Raise explicitly: a bare `assert False` is stripped under `python -O`,
+    # which would make this function silently return None.
+    raise AssertionError(f"unsupported backend: {backend!r}")
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f91ea25b993c4ed3ea0adbc5d531dbed0313640
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
@@ -0,0 +1,37 @@
+import torch
+import triton
+from vllm.kvprune.triton_kernels.target_info import get_cdna_version
+from vllm.kvprune.triton_kernels.tensor import bitwidth
+
+
+def compute_block_nk(
+    n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
+):
+    """Return the ``(block_n, block_k)`` tile sizes chosen for AMD targets.
+
+    ``block_n`` is derived from ``n`` (or from ``block_m`` when ``n`` is None)
+    and forced to 512 on CDNA4 with ``block_m == 128``. ``block_k`` is sized so
+    that the narrower operand spans 128 bytes, and pinned to 128 when a weight
+    scale is present on non-CDNA4 hardware.
+    """
+    lhs_bytes = bitwidth(lhs_dtype) / 8
+    rhs_bytes = bitwidth(rhs_dtype) / 8
+
+    # block_n:
+    cu_count = torch.cuda.get_device_properties(0).multi_processor_count
+    if n is None:
+        block_n = 256 if block_m > 64 else 128
+    elif n <= 128 and (n & (n - 1)) == 0:
+        # small power-of-two n: cover it with a single tile
+        block_n = n
+    else:
+        scaled = triton.next_power_of_2(grid_m * n * num_xcds // cu_count)
+        block_n = max(32, min(256, scaled))
+
+    if get_cdna_version() == 4 and block_m == 128:
+        block_n = 512
+
+    # block_k needs to match the cacheline size (128B)
+    narrowest_bytes = min(lhs_bytes, rhs_bytes)
+    block_k = int(128 // narrowest_bytes)
+
+    # TODO: block_k = 128 seems to work better for now.
+    # perhaps due to increased number of k loops to pipeline
+    if precision_config.weight_scale is not None and get_cdna_version() != 4:
+        block_k = 128
+    return block_n, block_k
diff --git a/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
new file mode 100644
index 0000000000000000000000000000000000000000..08bc8985c7d5a377d7648b22aba248693c18ca5b
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
@@ -0,0 +1,119 @@
+import torch
+import triton
+from vllm.kvprune.triton_kernels import target_info
+from vllm.kvprune.triton_kernels.tensor import get_layout, bitwidth, FP4
+from vllm.kvprune.triton_kernels.tensor_details.layout import HopperMXScaleLayout
+from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import (
+ MXFP_BLOCK_SIZE,
+)
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    """Return the number of output tiles: row blocks times column blocks.
+
+    Row blocks come from ``routing_data.n_blocks`` when routing data is given,
+    otherwise from a ceiling division of ``m`` by ``block_m``.
+    """
+    grid_m = (
+        triton.cdiv(m, block_m)
+        if routing_data is None
+        else routing_data.n_blocks(m, block_m)
+    )
+    # ceiling division of n by block_n
+    return grid_m * ((n + block_n - 1) // block_n)
+
+
+def compute_block_n(n: int, arch, precision_config):
+    """Return the N tile size for NVIDIA targets (``arch`` is not read)."""
+    # block_n:
+    scale_layout = get_layout(precision_config.weight_scale)
+    if isinstance(scale_layout, HopperMXScaleLayout) and scale_layout.num_warps == 4:
+        return 128
+    if precision_config.max_num_imprecise_acc is None and n > 128:
+        return 256
+    # otherwise: next power of two of n, clamped to [16, 128]
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(
+    m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config
+):
+    """Return the K tile size for NVIDIA targets (``m`` is not read)."""
+    lhs_bits = bitwidth(lhs_dtype)
+    rhs_bits = bitwidth(rhs_dtype)
+    # block_k needs to match the cacheline size (1024 bits)
+    cacheline_block_k = int(1024 // min(lhs_bits, rhs_bits))
+    native_mxfp = target_info.cuda_capability_geq(10, 0)
+    if rhs_bits == 4 and not native_mxfp:
+        block_k = 128
+    elif k is not None:
+        # never exceed the cacheline-derived size, never go below 32
+        block_k = max(32, min(triton.next_power_of_2(k), cacheline_block_k))
+    else:
+        block_k = cacheline_block_k
+    mx_weight_scale = (
+        precision_config is not None and precision_config.weight_scale is not None
+    )
+    if native_mxfp and is_persistent and mx_weight_scale:
+        return min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+ """Return the split-K factor for a launch of ``grid_size`` tiles.
+
+ The starting point is the SM count of device 0 divided by ``grid_size``;
+ when ``k`` is known the factor is also capped at a quarter of the number
+ of ``block_k``-sized K blocks.
+ """
+ # NOTE(review): device index 0 is hard-coded; confirm this is intended
+ # when the current device is not device 0.
+ device_props = torch.cuda.get_device_properties(0)
+ n_sms = device_props.multi_processor_count
+ split_k = n_sms // grid_size
+ if k is not None:
+ # avoid split_k for small k
+ num_block_k = triton.cdiv(k, block_k)
+ split_k = min(split_k, num_block_k // 4)
+ # clamp so the returned factor is at least 1
+ split_k = max(split_k, 1)
+ return split_k
+
+
+def compute_num_warps(block_m, block_n, precision_config):
+    """Return the warp count: dictated by a Hopper MX scale layout when the
+    weight scale uses one, otherwise one warp per 4096 tile elements with a
+    floor of 4."""
+    scale_layout = get_layout(precision_config.weight_scale)
+    if not isinstance(scale_layout, HopperMXScaleLayout):
+        return max(block_m * block_n // 4096, 4)
+    return scale_layout.num_warps
+
+
+def compute_num_stages(
+ precision_config,
+ is_persistent,
+ block_m,
+ block_n,
+ block_k,
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ epilogue_subtile,
+ epilogue_effective_itemsize,
+):
+ """Return how many pipeline stages fit in shared memory, capped at 4.
+
+ Returns 3 unconditionally when ``precision_config.max_num_imprecise_acc``
+ is set. Otherwise the per-stage footprint (one LHS tile plus one RHS tile,
+ with extras for MX scales, FP4 padding and a barrier) is divided into the
+ opt-in shared-memory capacity of device 0, after reserving room for the
+ accumulator tile on the persistent path. The result can be 0 when a
+ single stage does not fit.
+ """
+ if precision_config.max_num_imprecise_acc is not None:
+ return 3
+ # bytes per weight element (may be fractional for sub-byte dtypes)
+ weight_size = bitwidth(rhs_dtype) / 8
+ stage_size = (
+ block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
+ )
+ # NOTE(review): device index 0 is hard-coded; confirm this is intended
+ # when the current device is not device 0.
+ device_props = torch.cuda.get_device_properties(0)
+ smem_capacity = device_props.shared_memory_per_block_optin
+ has_native_mxfp = target_info.cuda_capability_geq(10, 0)
+ if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None:
+ if rhs_dtype == FP4:
+ # 4-bit e2m1 weights are padded 2x
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
+ stage_size += block_k * block_n * weight_size
+
+ if is_persistent:
+ # Per-stage wait barrier
+ stage_size += 8
+ if target_info.cuda_capability_geq(10, 0):
+ acc_size = epilogue_effective_itemsize or out_dtype.itemsize
+ else:
+ acc_size = out_dtype.itemsize
+ if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
+ acc_block_n = block_n // epilogue_subtile
+ else:
+ acc_block_n = block_n
+ # pipelined TMA store local to global, or
+ # pipelined layout conversion before store of the accumulator
+ # note: layout conversion has some padding
+ smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
+ if precision_config.weight_scale is not None:
+ # mx scales
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
+ elif has_native_mxfp:
+ # mx scales
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
+ num_stages = min(4, smem_capacity // int(stage_size))
+ return num_stages
diff --git a/vllm/kvprune/triton_kernels/numerics.py b/vllm/kvprune/triton_kernels/numerics.py
new file mode 100644
index 0000000000000000000000000000000000000000..024d3fcf0b819646a485596070b14c7a0a2e17ed
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/numerics.py
@@ -0,0 +1,42 @@
+import torch
+from dataclasses import dataclass
+
+# Largest finite magnitude representable in each 8-bit float format
+# (the suffix names the format: e5, e4nv, e4b8).
+MAX_FINITE_FLOAT8E5 = 57344.0
+MAX_FINITE_FLOAT8E4NV = 448.0
+MAX_FINITE_FLOAT8E4B8 = 240.0
+
+
+@dataclass(frozen=True)
+class BaseFlexData:
+    """Optional dtype tag used to reinterpret tensor storage without copying."""
+    dtype: torch.dtype | None = None
+
+    def view(self, x: torch.Tensor):
+        """Return ``x`` viewed as ``self.dtype``, or ``x`` itself when no dtype is set."""
+        return x if self.dtype is None else x.view(self.dtype)
+
+    def reinterpret(self, x):
+        """Like ``view``, but only applied to tensors whose element size is at
+        most one byte; wider tensors are returned unchanged."""
+        if self.dtype is not None and x.dtype.itemsize <= 1:
+            return x.view(self.dtype)
+        return x
+
+
+@dataclass(frozen=True)
+class InFlexData(BaseFlexData):
+    """Flex data for an input: the dtype tag plus an optional scale tensor."""
+    scale: torch.Tensor | None = None
+
+    @property
+    def is_per_batch(self):
+        """True when a scale is present and its length is greater than one."""
+        if self.scale is None:
+            return False
+        return len(self.scale) > 1
+
+
+@dataclass(frozen=True)
+class OutFlexData(BaseFlexData):
+    """Flex data for an output: the dtype tag plus three optional scale tensors."""
+    expected_scale: torch.Tensor | None = None
+    actual_scale: torch.Tensor | None = None
+    checksum_scale: torch.Tensor | None = None
+
+    def __iter__(self):
+        """Yield the expected, actual and checksum scales, in that order."""
+        yield from (self.expected_scale, self.actual_scale, self.checksum_scale)
diff --git a/vllm/kvprune/triton_kernels/numerics_details/__init__.py b/vllm/kvprune/triton_kernels/numerics_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/numerics_details/flexpoint.py b/vllm/kvprune/triton_kernels/numerics_details/flexpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..baa828f4f95392f2ce70ca3e03d1fdded29ff145
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/numerics_details/flexpoint.py
@@ -0,0 +1,204 @@
+from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+import triton
+import triton.language as tl
+from vllm.kvprune.triton_kernels.target_info import cuda_capability_geq
+
+# -------------------------------
+# Kernels stuff
+# -------------------------------
+
+# Per-format MAX_FINITE limits as Triton compile-time constants; `max_finite` selects among them.
+TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5)
+TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV)
+TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8)
+TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750)
+# NOTE(review): 65472.0 is below the IEEE binary16 maximum (65504) -- confirm this is intentional.
+TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0)
+
+# Reciprocals of the limits above, stored as the integer bit pattern of an fp32 value
+# (the trailing comment on each line gives the same number as a hex float).
+# `compute_scale` bitcasts the selected constant to float32 before using it.
+TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925)  # 0x1.24924Ap-16
+TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925)  # 0x1.24924Ap-9
+TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889)  # 0x1.111112p-8
+TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925)  # 0x1.24924Ap-1
+TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008)  # 0x1.004010p-16
+
+
+@triton.jit
+def max_finite(dtype):
+    """Return the TL_MAX_FINITE_* constant that matches `dtype`.
+
+    Handles float8e5, float8e4nv, float8e4b8, float8e4b15 and float16; any
+    other dtype trips a static assertion.
+    """
+    if dtype == tl.constexpr(tl.float8e5):
+        return TL_MAX_FINITE_FLOAT8E5
+    elif dtype == tl.constexpr(tl.float8e4nv):
+        return TL_MAX_FINITE_FLOAT8E4NV
+    elif dtype == tl.constexpr(tl.float8e4b8):
+        return TL_MAX_FINITE_FLOAT8E4B8
+    elif dtype == tl.constexpr(tl.float8e4b15):
+        return TL_MAX_FINITE_FLOAT8E4B15
+    elif dtype == tl.constexpr(tl.float16):
+        return TL_MAX_FINITE_FLOAT16
+    else:
+        # Unsupported dtype: fail when the kernel is compiled.
+        tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
+
+
+@triton.jit
+def rcp_max_finite(dtype):
+    """Return the TL_RCP_MAX_FINITE_* constant that matches `dtype`.
+
+    The returned constant is an integer holding the bits of an fp32 value, so
+    callers bitcast it to float32 before using it as a number. Handles the same
+    dtypes as `max_finite`; any other dtype trips a static assertion.
+    """
+    if dtype == tl.constexpr(tl.float8e5):
+        return TL_RCP_MAX_FINITE_FLOAT8E5
+    elif dtype == tl.constexpr(tl.float8e4nv):
+        return TL_RCP_MAX_FINITE_FLOAT8E4NV
+    elif dtype == tl.constexpr(tl.float8e4b8):
+        return TL_RCP_MAX_FINITE_FLOAT8E4B8
+    elif dtype == tl.constexpr(tl.float8e4b15):
+        return TL_RCP_MAX_FINITE_FLOAT8E4B15
+    elif dtype == tl.constexpr(tl.float16):
+        return TL_RCP_MAX_FINITE_FLOAT16
+    else:
+        # Unsupported dtype: fail when the kernel is compiled.
+        tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
+
+
+@triton.jit
+def sm86_min_nan_xorsign_abs_f32(a, b):
+    """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
+
+    Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
+    NaN inputs are propagated to the output.
+
+    Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
+
+    Both `a` and `b` must be float32 (checked statically); the result is float32.
+    """
+    tl.static_assert(
+        cuda_capability_geq(8, 6),
+        "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+",
+    )
+    tl.static_assert(
+        a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs"
+    )
+    tl.static_assert(
+        b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs"
+    )
+
+    # Constraint string "=r,r,r": one register output ($0) and two register inputs ($1, $2).
+    # pack=1 hands the asm one element per invocation; is_pure marks it side-effect free.
+    return tl.inline_asm_elementwise(
+        """{
+            min.NaN.xorsign.abs.f32 $0, $1, $2;
+        }""",
+        "=r,r,r",
+        [a, b],
+        dtype=tl.float32,
+        is_pure=True,
+        pack=1,
+    )
+
+
+@triton.jit
+def sm86_max_nan_xorsign_abs_f32(a, b):
+    """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
+
+    Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
+    NaN inputs are propagated to the output.
+
+    Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
+
+    Both `a` and `b` must be float32 (checked statically); the result is float32.
+    """
+    tl.static_assert(
+        cuda_capability_geq(8, 6),
+        "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+",
+    )
+    tl.static_assert(
+        a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs"
+    )
+    tl.static_assert(
+        b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs"
+    )
+
+    # Constraint string "=r,r,r": one register output ($0) and two register inputs ($1, $2).
+    # pack=1 hands the asm one element per invocation; is_pure marks it side-effect free.
+    return tl.inline_asm_elementwise(
+        """{
+            max.NaN.xorsign.abs.f32 $0, $1, $2;
+        }""",
+        "=r,r,r",
+        [a, b],
+        dtype=tl.float32,
+        is_pure=True,
+        pack=1,
+    )
+
+
+@triton.jit
+def load_scale(scale_ptr):
+    """Return the value loaded from `scale_ptr`, or 1.0 when `scale_ptr` is None."""
+    return 1.0 if scale_ptr is None else tl.load(scale_ptr)
+
+
+@triton.jit
+def flex_to_float(x, scale_ptr):
+    """Convert `x` to float32 and multiply it by the scale read via `load_scale`."""
+    x_f32 = x.to(tl.float32)
+    return x_f32 * load_scale(scale_ptr)
+
+
+@triton.jit
+def clip(x, limit):
+    """Clamp `x` to the closed interval [-limit, limit]."""
+    upper_bounded = tl.minimum(x, limit)
+    return tl.maximum(-limit, upper_bounded)
+
+
+@triton.jit
+def nan_propagating_absmax_reduce(x, axis=None):
+    """Reduce `x` along `axis` to its maximum absolute value, keeping NaNs.
+
+    The result is returned as the uint32 bit pattern of the float32 value (sign
+    bit cleared), not as a float; callers bitcast it back.
+    """
+    if cuda_capability_geq(8, 6):
+        # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
+        x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32)
+        # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
+        x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
+    else:
+        # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
+        masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
+        x_absmax = tl.max(masked_abs_x, axis)
+
+    return x_absmax
+
+
+@triton.jit
+def compute_scale(x, Out):
+    """Compute a float32 scale for `x`: absmax(x) * (1 / max_finite) + 1e-30.
+
+    The reciprocal is chosen from the element type of the `Out` pointer. A NaN
+    absmax is replaced by +inf before the multiply.
+    """
+    # Flatten first so the reduction covers every element of the tile.
+    x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True))
+
+    # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
+    # We use integer minimum because NaNs are above +inf in integer representation.
+    x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True)
+    RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty)
+    # The 1.0e-30 addend presumably keeps the scale non-zero for an all-zero tile -- TODO confirm.
+    return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30)
+
+
+@triton.jit
+def update_scale(x, scale_ptr, Out) -> None:
+    """Atomically raise the value at `scale_ptr` to at least the scale computed for `x`.
+
+    Does nothing when `scale_ptr` is None.
+    """
+    if scale_ptr is not None:
+        scale = compute_scale(x, Out)
+        # Relaxed atomic max: concurrent programs each contribute their tile's scale.
+        tl.atomic_max(scale_ptr, scale, sem="relaxed")
+
+
+@triton.jit
+def float_to_flex(
+    x,
+    expected_scale_ptr_or_val,
+    actual_scale_ptr,
+    checksum_scale_ptr,
+    mask,
+    Out,
+    saturate_infs: tl.constexpr,
+):
+    """Scale `x` for storing through `Out` and update the optional scale/checksum outputs.
+
+    Args:
+        x: values to convert.
+        expected_scale_ptr_or_val: None, a pointer to the scale, or the scale value
+            itself; `x` is divided by it when given.
+        actual_scale_ptr: optional pointer updated (atomic max) with the scale
+            computed from `x`.
+        checksum_scale_ptr: optional pointer that accumulates (atomic add) the XOR
+            of the int32 bit patterns of `x`.
+        mask: optional mask; masked-out elements are zeroed for the checksum and,
+            when `actual_scale_ptr` is given, for the scale computation.
+        Out: output pointer; only its element type is read here.
+        saturate_infs: when set (and a scale was applied), clip the result to
+            +/- max_finite of the output element type.
+
+    Returns:
+        `x` multiplied by the inverse expected scale, optionally clipped.
+    """
+    if expected_scale_ptr_or_val is not None:
+        if expected_scale_ptr_or_val.dtype.is_ptr():
+            invscale = 1.0 / tl.load(expected_scale_ptr_or_val)
+        else:
+            invscale = 1.0 / expected_scale_ptr_or_val
+    else:
+        invscale = 1.0
+    if checksum_scale_ptr is not None:
+        # XOR-fold the raw bits of the tile, then add this tile's fold to the global checksum.
+        x_int32 = x.to(tl.int32, bitcast=True)
+        zero = tl.cast(0.0, tl.int32)
+        if mask is not None:
+            x_int32 = tl.where(mask, x_int32, zero)
+        checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0)
+        tl.atomic_add(checksum_scale_ptr, checksum_local)
+    if mask is not None:
+        if actual_scale_ptr is not None:
+            # Zero the masked-out lanes so they cannot influence the abs-max.
+            x = tl.where(mask, x, 0.0)
+    update_scale(x, actual_scale_ptr, Out)
+    x = x * invscale
+    # if expected_scale_ptr_or_val is not None, we applied flexpoint scale. We only want to clip in this case.
+    if expected_scale_ptr_or_val is not None:
+        if saturate_infs:
+            CLIP_VALUE = max_finite(Out.dtype.element_ty)
+            x = clip(x, CLIP_VALUE)
+    return x
diff --git a/vllm/kvprune/triton_kernels/numerics_details/mxfp.py b/vllm/kvprune/triton_kernels/numerics_details/mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c69c83c1dd77668ae80cbee0f21bafc5767815
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/numerics_details/mxfp.py
@@ -0,0 +1,303 @@
+# isort: off
+# fmt: off
+from enum import Enum
+import triton
+import torch
+import torch.nn.functional as F
+from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
+from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, MXFP_BLOCK_SIZE, _quantize_mxfp8_fn
+
+# -----------------------------------------------------------------------------
+# Dequantization / Quantization Utilities
+# -----------------------------------------------------------------------------
+
+
+class DequantScaleRoundingMode(Enum):
+    """How the per-block dequantization scale is rounded to a power of two.
+
+    The integer `.value` is what `downcast_to_mxfp` passes to the Triton kernel.
+    """
+    ROUND_UP = 0
+    ROUND_DOWN = 1
+
+
+def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
+                     DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
+    """
+    Convert the src weights to mx format. The src weight is quantized along the axis dimension.
+
+    If out_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
+    Note that this means the k_dim of the tensor will be half of the logical k_dim.
+
+    If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored
+    in their respective formats.
+
+    Returns a tuple (out_quant_tensor, out_scale). out_scale is uint8 and has
+    ceil(L / MXFP_BLOCK_SIZE) entries along `axis`, where L is the source length
+    along that axis. Both outputs are transposed back so `axis` is in its original position.
+    """
+    ndim = src_tensor.ndim
+    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+    axis = axis if axis >= 0 else axis + ndim
+    # downcast
+    # Move the quantization axis last; the kernel works on a 2D (outer, quant) view.
+    src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
+    is_fp4 = out_quant_type == torch.uint8
+    is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
+    assert is_fp4 or is_fp8
+    divisor = 2 if is_fp4 else 1
+    L = src_tensor.shape[-1]
+    if is_fp4:
+        assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
+    out_shape = src_tensor.shape[:-1] + (L // divisor, )
+    out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
+
+    out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
+    out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
+
+    # Skip the launch for empty inputs; the (empty) outputs are still returned.
+    if src_tensor.numel() > 0:
+        kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
+        kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
+        kernel_scale = out_scale.view(-1, out_scale.shape[-1])
+
+        BLOCK_OUT_DIM = 128
+        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+        grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
+        grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
+
+        # Argument order must match the `_downcast_to_mxfp` kernel signature:
+        # output + strides, scale + strides, source + strides, source shape, block sizes, rounding mode.
+        _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
+                                                  *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
+                                                  *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
+                                                  DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
+
+    out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
+    out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
+    return out_quant_tensor, out_scale
+
+
+def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
+    """
+    Upcasts an mxfp (packed) weight tensor back to float16, bfloat16 or float32.
+
+    The function assumes that the tensors were quantized along the given axis.
+    It permutes the tensor so that the quantized axis is last, reshapes to 2D,
+    launches the Triton upcast kernel, and then unpermutes back to the original order.
+
+    `tensor` must be uint8 (two packed e2m1 values per byte), float8_e5m2 or
+    float8_e4m3fn; `scale` must be uint8 with the same number of dimensions.
+    For uint8 input the output is twice as long as `tensor` along `axis`.
+    """
+    ndim = tensor.ndim
+    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+    axis = axis if axis >= 0 else axis + ndim
+    assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
+                                       f"Got {tensor.ndim=} and {scale.ndim=}")
+    # dtype checks
+    assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
+        f"Invalid tensor dtype {tensor.dtype=}"
+    assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
+    assert target_dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {target_dtype=}"
+    # upcast
+    # Each uint8 byte carries two fp4 values, so the logical length doubles.
+    logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
+    tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
+    scale = scale.transpose(axis, scale.ndim - 1).contiguous()
+    out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=target_dtype, device=tensor.device)
+    reshaped_out = out.view(-1, out.shape[-1])
+    reshaped_tensor = tensor.view(-1, tensor.shape[-1])
+    reshaped_scale = scale.view(-1, scale.shape[-1])
+    BLOCK_OUT_DIM = 128
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+    blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
+    blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
+    # Argument order must match the `_upcast_from_mxfp` kernel signature:
+    # output + strides, scale + strides, packed tensor + strides, output shape, block sizes.
+    _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
+                                                          *reshaped_scale.stride(), reshaped_tensor,
+                                                          *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
+                                                          BLOCK_QUANT_DIM, num_warps=8)
+    out = out.transpose(axis, scale.ndim - 1).contiguous()
+    return out
+
+
+# ------------
+
+
+def right_shift_unsigned(x, shift):
+    """Shift a 32-bit value right by `shift` bits, filling the vacated high bits with zeros."""
+    # CUDA torch does not support bit ops on uint32, so shift the signed value and
+    # mask away the sign-extended high bits to emulate an unsigned shift.
+    kept_bits = 32 - shift
+    low_mask = (1 << kept_bits) - 1
+    return (x >> shift) & low_mask
+
+
+def get_max_quant_val(dtype: torch.dtype):
+    """Return the largest quantized magnitude for `dtype`.
+
+    Supported: torch.uint8 (6.0), torch.float8_e5m2 (57344.0) and
+    torch.float8_e4m3fn (448.0). Any other dtype fails the assertion.
+    """
+    max_by_dtype = {
+        torch.float8_e4m3fn: 448.0,
+        torch.float8_e5m2: 57344.0,
+        torch.uint8: 6.0,
+    }
+    assert dtype in max_by_dtype
+    return max_by_dtype[dtype]
+
+
+def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
+                           DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
+    """
+    Converts the src tensor to the output format specified by out_quant_type.
+    axis: The axis along which the tensors are contiguous and quantization is applied.
+    DEQUANT_SCALE_ROUNDING_MODE: a DequantScaleRoundingMode member. It is compared against
+    DequantScaleRoundingMode.ROUND_UP; anything else is treated as ROUND_DOWN.
+
+    Returns:
+        out_quant_tensor: Quantized tensor in mx format.
+        • For mxfp8, the output has the same shape as src_tensor.
+        • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
+        scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
+        Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
+        where L is the original length along that axis.
+    """
+    # This should probably be packed into its own tiny class
+    ndim = src_tensor.ndim
+    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+    assert src_tensor.dtype in {torch.float32, torch.bfloat16,
+                                torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"
+
+    axis = axis if axis >= 0 else axis + ndim
+    is_fp4 = out_quant_type == torch.uint8
+    is_fp8 = "float8" in str(out_quant_type)
+    assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
+
+    device = src_tensor.device
+
+    # For mxfp4 conversion, we assume the contiguous axis length is even.
+    if is_fp4:
+        axis_shape = src_tensor.size(axis)
+        assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."
+
+    # Permute the tensor so that the contiguous axis becomes the last dimension.
+    src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
+    axis_shape = src.shape[-1]
+
+    # Pad the axis to be divisible by 32, in case it is not.
+    next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
+    pad_amount = next_multiple - axis_shape
+    padded_src = F.pad(src, (0, pad_amount))
+    # True for real elements, False for the padding added above.
+    valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
+    padded_axis_shape = padded_src.size(-1)  # now divisible by 32
+
+    # --- Compute per-group maximums for scale ---
+    # Set padded entries to -1 so they don't affect the max.
+    abs_f = torch.abs(padded_src)
+    abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
+    # Reshape the last dimension into groups of 32.
+    new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
+    abs_groups = abs_f.view(*new_shape)
+    # Compute maximum along the group dimension (of size 32).
+    max_val, _ = abs_groups.max(dim=-1, keepdim=True)
+
+    # Choose a max quantization value depending on type.
+    max_quant_val = get_max_quant_val(out_quant_type)
+    dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)
+
+    # Convert to int to round the FP32 scale, prior to quantization!
+    # Masking with 0x7F800000 keeps only the exponent bits, i.e. a power of two;
+    # adding 0x007FFFFF first bumps the exponent whenever the mantissa is non-zero.
+    ds_int = dequant_scale.view(torch.int32)
+    if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
+        ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
+    else:
+        ds_int_rounded = ds_int & 0x7F800000
+    # Reinterpret back as float32.
+    dequant_scale_rounded = ds_int_rounded.view(torch.float32)
+
+    # Compute the quantization scale.
+    # A zero dequant scale maps to a zero quant scale instead of dividing by zero.
+    quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)
+
+    # Quantize the tensor
+    orig_padded_shape = padded_src.shape
+    padded_src_groups = padded_src.view(*new_shape)
+    quant_tensor = padded_src_groups * quant_scale
+    # Reshape back to the original shape and trim padding
+    quant_tensor = quant_tensor.view(orig_padded_shape)
+    quant_tensor = quant_tensor[..., :axis_shape]
+
+    # Finally, convert the quantized tensor to the target format
+    if is_fp8:
+        # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
+        quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
+        out_weight = quant_tensor.to(out_quant_type)
+    else:
+        assert is_fp4, f"Invalid output quantization type {out_quant_type}"
+        # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
+        # First, reinterpret the quantized tensor bits.
+        q_int = quant_tensor.contiguous().view(torch.int32)
+        # Extract sign, exponent, and mantissa.
+        signs = q_int & 0x80000000
+        exponents = right_shift_unsigned(q_int, 23) & 0xFF
+        mantissas = q_int & 0x7FFFFF
+
+        E8_BIAS = 127
+        E2_BIAS = 1
+        # Adjust mantissas for subnormals.
+        mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
+                                (E8_BIAS - exponents - 1), mantissas)
+        # Re-bias the exponent from 127 to 1, clamping smaller exponents to 0.
+        exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
+        # Keep one extra mantissa bit, add 1 and shift it out to round; then saturate at 0x7.
+        e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
+        e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
+        e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8)  # shape: (..., even_axis_shape)
+
+        # Pack pairs of 4-bit values along the last dimension.
+        # Even-indexed elements go to the low nibble, odd-indexed to the high nibble.
+        e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
+        evens = e2m1_value[..., 0]
+        odds = e2m1_value[..., 1]
+        out_weight = evens | (odds << 4)  # shape: (..., axis_shape//2)
+
+    # --- Process and output the scale ---
+    # Only the 8 exponent bits of the rounded scale are stored.
+    dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8)  # shape: (..., axis_shape//32, 1)
+    dq_scale = dq_scale.squeeze(-1)
+    out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
+    dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
+    return out_weight, dq_scale
+
+
+def cvt_e2m1_to_fp32(input_tensor):
+    """Unpack uint8 bytes holding two e2m1 (fp4) codes each into float32 values.
+
+    For every byte the low nibble is decoded first and the high nibble second,
+    so the last dimension of the result is twice that of `input_tensor`.
+    """
+    assert input_tensor.dtype == torch.uint8
+
+    packed = input_tensor.to(torch.int32)
+    low_nibble = packed & 0xF
+    high_nibble = (packed >> 4) & 0xF
+
+    # Codes 0-7 are the non-negative values; codes 8-15 are their negations.
+    magnitudes = torch.tensor([0.0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32, device=packed.device)
+    lookup = torch.cat([magnitudes, -magnitudes])
+
+    decoded = torch.stack([lookup[low_nibble], lookup[high_nibble]], dim=-1)
+    return decoded.view(*packed.shape[:-1], -1)
+
+
+def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
+    """
+    Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
+    axis: The axis along which dequantization is applied.
+
+    `tensor` must be torch.uint8 (packed fp4), torch.float8_e4m3fn or torch.float8_e5m2.
+    `scale` holds one exponent byte per group of MXFP_BLOCK_SIZE values along `axis`.
+
+    Returns:
+        out_weight: Tensor in the target format.
+    """
+
+    ndim = tensor.ndim
+    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+    is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
+    assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"
+
+    # Permute the tensor and scale so that the quantization axis becomes the last dimension
+    axis = axis if axis >= 0 else axis + ndim
+    scale = scale.transpose(axis, scale.ndim - 1)
+    tensor = tensor.transpose(axis, tensor.ndim - 1)
+
+    dq_scale = (scale.to(torch.int32) << 23).view(torch.float32)  # Shift to the exponent and bitcast to fp32
+    if tensor.dtype == torch.uint8:
+        fp32_tensor = cvt_e2m1_to_fp32(tensor)
+    else:
+        fp32_tensor = tensor.to(torch.float32)
+
+    # Pad the value axis up to a whole number of scale groups so it can be reshaped per group.
+    logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
+    axis_shape = fp32_tensor.size(-1)
+    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
+    pad_size = padded_axis_shape - axis_shape
+    padded_tensor = F.pad(fp32_tensor, (0, pad_size))
+
+    new_axis_shape = padded_tensor.shape[-1]
+    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
+    padded_tensor = padded_tensor.view(*new_shape)
+    dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
+    # Broadcast each group's scale over that group's values.
+    out_padded = padded_tensor * dq_scale_padded
+
+    # Flatten back and remove the padded tail
+    out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
+    out_tensor = out_padded[..., :axis_shape]
+
+    out_tensor = out_tensor.to(target_dtype).contiguous()
+    out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)
+
+    return out_tensor
+
+
+# Public alias for the Triton helper imported above as `_quantize_mxfp8_fn`.
+quantize_mxfp8_fn = _quantize_mxfp8_fn
diff --git a/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/__init__.py b/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py b/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eac6467e2d8d49385106574ec073cf677c622e0
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
@@ -0,0 +1,158 @@
+import triton
+import triton.language as tl
+
+# fmt: off
+
+
+# Number of consecutive values along the quantized axis that share one scale
+# (`_compute_quant_and_scale` reshapes into groups of this size and takes a max per group).
+MXFP_BLOCK_SIZE = tl.constexpr(32)
+
+
+@triton.jit
+def _get_max_quant_val(dtype: tl.constexpr):
+    """Return the largest quantized magnitude for `dtype`.
+
+    uint8 (packed fp4) -> 6.0, float8e5 -> 57344.0, float8e4nv -> 448.0; any
+    other dtype trips a static assertion.
+    """
+    if dtype == tl.uint8:
+        return 6.0
+    elif dtype == tl.float8e5:
+        return 57344.0
+    elif dtype == tl.float8e4nv:
+        return 448.0
+    else:
+        tl.static_assert(False, f"Invalid {dtype=}")
+
+@triton.jit
+def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
+                             DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
+    """Quantize a 2D tile to mx format and compute its per-group scale exponents.
+
+    Args:
+        src_tensor: tile of shape [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]; it is
+            converted to float32 internally.
+        valid_src_mask: mask of in-bounds elements; masked-out elements are ignored
+            for the scale and written as 0.
+        mx_tensor_dtype: tl.float8e4nv / tl.float8e5 for mxfp8, otherwise the fp4
+            path is taken and two e2m1 values are packed per uint8.
+        DEQUANT_SCALE_ROUNDING_MODE: 0 rounds the scale up to a power of two,
+            1 rounds it down.
+
+    Returns:
+        (out_tensor, dequant_scale_exponent). The scale is uint8 with shape
+        [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE]. For fp4 the
+        quant dimension of out_tensor is halved.
+    """
+    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
+    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
+
+    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
+    f32_tensor = src_tensor.to(tl.float32)
+    abs_tensor = tl.abs(f32_tensor)
+    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
+    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
+    dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
+    if DEQUANT_SCALE_ROUNDING_MODE == 0:
+        # DequantScaleRoundingMode.ROUND_UP
+        # compute 2 ** ceil(log2(dequant_scale))
+        # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
+        # A corner case: exponent is 0xFF that will overflow but that's already
+        # NaN so assume we don't care.
+        dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
+    else:
+        # DequantScaleRoundingMode.ROUND_DOWN
+        # compute 2 ** floor(log2(dequant_scale))
+        assert DEQUANT_SCALE_ROUNDING_MODE == 1
+        dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
+    dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
+    # A zero dequant scale maps to a zero quant scale instead of dividing by zero.
+    quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
+
+    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+    quant_tensor = f32_tensor * quant_scale
+
+    # Reshape the tensors after scaling
+    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
+    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
+    dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
+
+    # First, we simply extract the exponent part of the scales and store the result
+    dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
+    # Now we must convert the tensors to the mx format.
+    if is_fp8:
+        out_tensor = quant_tensor.to(mx_tensor_dtype)
+    else:
+        # fp4 path: split the fp32 bits and rebuild a 4-bit sign/exponent/mantissa code.
+        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
+        signs = quant_tensor & 0x80000000
+        exponents = (quant_tensor >> 23) & 0xFF
+        mantissas = (quant_tensor & 0x7FFFFF)
+
+        # 0.25 <= x < 0.75 maps to 0.5, a denormal number
+        E8_BIAS = 127
+        E2_BIAS = 1
+        # Move implicit bit 1 at the beginning to mantissa for denormals
+        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
+        mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
+
+        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
+        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
+
+        # Combine sign, exponent, and mantissa, while saturating
+        # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
+        e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
+        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
+
+        # Pack neighbouring pairs: even-indexed value in the low nibble, odd-indexed in the high nibble.
+        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
+        evens, odds = tl.split(e2m1_value)
+        out_tensor = evens | (odds << 4)
+
+    return out_tensor, dequant_scale_exponent
+
+@triton.jit
+def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
+                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
+                      src_ptr, stride_src_outer, stride_src_quant,
+                      outer_dim, quant_dim,
+                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
+                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
+    """Kernel: quantize one (BLOCK_SIZE_OUT_DIM x BLOCK_SIZE_QUANT_DIM) tile of a 2D source.
+
+    Program axis 0 indexes tiles along the outer dimension and axis 1 along the
+    quantized dimension. The source must be bfloat16, float16 or float32; the
+    output must be uint8 (packed fp4) or float8e4nv/float8e5 with unit stride
+    along the quantized dimension; the scale output must be uint8.
+    `outer_dim` and `quant_dim` are the source extents used for bounds masking.
+    """
+
+    tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
+    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
+
+    # uint8 signifies two fp4 e2m1 values packed into a single byte
+    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+    tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
+                     f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
+
+    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
+    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
+    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
+    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+
+    # Indices are widened to int64, presumably to keep offset arithmetic from overflowing on large tensors.
+    outer_block = tl.program_id(0).to(tl.int64)
+    quant_block = tl.program_id(1).to(tl.int64)
+
+    # Per-tile extents along the quantized axis for the scale and the (possibly packed) output.
+    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+    start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+    start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+    start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+    # Advance the three base pointers to the start of this program's tile.
+    src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
+    mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
+    mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
+
+    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+    offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+    offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+
+    # Bounds masks for the source, the output and the scale tiles.
+    mask_src_quant = start_src_quant + offs_src_quant < quant_dim
+    mask_n = start_out + offs_outer < outer_dim
+    full_mask_src = mask_src_quant & mask_n
+
+    mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
+    full_mask_mxt = mask_mxt_quant & mask_n
+
+    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+    full_scale_mask = scale_mask_k & mask_n
+
+    src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
+    mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
+    mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
+    src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
+
+    out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
+                                                        DEQUANT_SCALE_ROUNDING_MODE)
+
+    tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
+    tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
+
+
+# NOTE(review): the JIT repr is "_dequantize_mxfp8" although this function quantizes -- confirm the name is intended.
+@triton.jit(repr=lambda _: "_dequantize_mxfp8")
+def _quantize_mxfp8_fn(input, mask, pid=None):
+    """Quantize `input` to mxfp8 (float8e4nv) using the default scale rounding mode.
+
+    Returns whatever `_compute_quant_and_scale` returns: the quantized tile and
+    its uint8 scale exponents. `pid` is accepted but not used.
+    """
+    return _compute_quant_and_scale(input, mask, tl.float8e4nv)
diff --git a/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py b/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5f027fa986c06f402405a4a5047b649b3e1bfe
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
@@ -0,0 +1,125 @@
+import triton
+import triton.language as tl
+
+from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
+
+
+# fmt: off
+@triton.jit
+def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
+                      stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
+                      outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
+    """
+    Upcast one tile of an MX-format tensor (packed values plus uint8 scales) and store it to `out_ptr`.
+
+    Each program handles one [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM] output tile, selected by
+    program_id(0) along the outer dimension and program_id(1) along the quantized dimension.
+
+    Arguments:
+      out_ptr, stride_o_outer, stride_o_quant: destination; the element type must be float16,
+          bfloat16 or float32 and stride_o_quant must be 1 (both enforced by static_assert).
+      mx_scale_ptr, stride_scale_outer, stride_scale_quant: uint8 scales; one scale is read per
+          MXFP_BLOCK_SIZE output elements along the quantized dimension.
+      mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: values to upcast. uint8 is treated as
+          two fp4 (e2m1) values per byte; float8e4nv / float8e5 are treated as fp8.
+      outer_dim, quant_dim: output extents, used only to build the load/store masks.
+      BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM: tile sizes; BLOCK_SIZE_QUANT_DIM must be a
+          multiple of MXFP_BLOCK_SIZE.
+    """
+
+    tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
+    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
+    # uint8 signifies two fp4 e2m1 values packed into a single byte
+    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+    dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
+    tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 or dst_dtype == tl.float32)
+    tl.static_assert(
+        mx_tensor_dtype == tl.uint8
+        or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
+        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
+    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+
+    # Determine whether we are dealing with fp4 (packed in uint8) or fp8 types.
+    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+    # fp4 packs two values per byte, so the packed tensor is half as wide as the output.
+    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+    # Compute starting indices for the quantized (packed) dimension and the outer dimension.
+    # Program ids are widened to int64 so the pointer arithmetic below is done in 64 bits.
+    outer_block = tl.program_id(0).to(tl.int64)
+    quant_block = tl.program_id(1).to(tl.int64)
+
+    start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+    start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+    start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+    # Advance the three base pointers to the first element of this program's tile.
+    mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
+    mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
+    out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
+
+    # Compute offsets and masks.
+    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+    offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+    offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+
+    mask_outer = start_out + offs_outer < outer_dim
+    mask_out_quant = start_out_quant + offs_out_quant < quant_dim
+    full_mask_out = mask_out_quant & mask_outer
+
+    # The packed tensor and the scales are shorter than the output along the quantized
+    # dimension, so their bounds are the rounded-up quotients of quant_dim.
+    mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
+    full_mask_src = mask_src_quant & mask_outer
+
+    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+    full_scale_mask = mask_scale & mask_outer
+
+    tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
+    scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
+    out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
+
+    # Load the packed tensor and scale.
+    tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
+    scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
+
+    # Upcast the scale to the destination type.
+    # The scale byte is shifted into the exponent field of the float (7 mantissa bits below it
+    # for bfloat16, 23 for float32) and the bit pattern is reinterpreted as that float type.
+    if dst_dtype == tl.bfloat16:
+        dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
+    else:
+        dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
+        if dst_dtype == tl.float16:
+            dst_scale = dst_scale.to(tl.float16)
+
+    # Now upcast the tensor.
+    # float32 output goes through bfloat16 first; the final cast to dst_dtype happens below.
+    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
+    if is_fp8:
+        dst_tensor = tensor.to(intermediate_dtype)
+        if tensor.dtype == tl.float8e5:
+            from_e_bits: tl.constexpr = 5
+            from_m_bits: tl.constexpr = 2
+            to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
+            to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
+
+            # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
+            # Wherever the source exponent bits are all ones, force the destination exponent
+            # bits to all ones as well.
+            non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
+            non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
+            dst_tensor = tl.where(
+                (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
+                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
+                dst_tensor,
+            )
+    else:
+        # NOTE(review): the static_assert above also admits mx_tensor_dtype == dst_dtype, but this
+        # branch only handles fp4 -- confirm that case is never instantiated.
+        assert is_fp4
+        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
+        # Bit pattern of 0.5 in the intermediate type (16128 == 0x3F00 for bfloat16).
+        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
+        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
+        # e2m1
+        # em0 / em1: exponent+mantissa bits of the low / high nibble of each byte.
+        em0 = tensor & 0x07
+        em1 = tensor & 0x70
+        # Move the exponent/mantissa bits into position and the nibble's sign bit to bit 15.
+        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
+        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
+        # Three cases:
+        # 1) x is normal and non-zero: Correct bias
+        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
+        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
+        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
+        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
+        # 3) x is zero, do nothing
+        # Interleave so each byte yields its low-nibble value followed by its high-nibble value.
+        dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)
+    dst_tensor = dst_tensor.to(dst_dtype)
+
+    # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping.
+    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
+    scale = scale.reshape(dst_scale.shape)
+
+    out_tensor = dst_tensor * dst_scale
+    # Correct any NaNs encoded via the scale.
+    # A raw scale byte of 0xFF marks the whole group as NaN.
+    out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
+    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+    tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
diff --git a/vllm/kvprune/triton_kernels/proton_opts.py b/vllm/kvprune/triton_kernels/proton_opts.py
new file mode 100644
index 0000000000000000000000000000000000000000..a187eecc2d66659c278be3668e7865ee8a785694
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/proton_opts.py
@@ -0,0 +1,19 @@
+# proton options
+
+import os
+
+# Cached setting: None until first resolved (or explicitly set), then a bool.
+_launch_metadata_allow_sync = None
+
+
+def launch_metadata_allow_sync():
+    """Return whether launch-metadata code may synchronize.
+
+    The value is resolved once and cached in the module-level variable. On the
+    first call (while the cache is still ``None``) it is derived from the
+    ``PROTON_LAUNCH_METADATA_NOSYNC`` environment variable: syncing is allowed
+    unless that variable equals ``"1"``.
+    """
+    global _launch_metadata_allow_sync
+    if _launch_metadata_allow_sync is not None:
+        return _launch_metadata_allow_sync
+    nosync_flag = os.getenv("PROTON_LAUNCH_METADATA_NOSYNC")
+    _launch_metadata_allow_sync = not (nosync_flag == "1")
+    return _launch_metadata_allow_sync
+
+
+def set_launch_metadata_allow_sync(allow_sync: bool) -> None:
+    """Override the cached setting returned by ``launch_metadata_allow_sync``.
+
+    The value is stored as given; once set to a non-``None`` value the
+    environment variable is no longer consulted.
+    """
+    global _launch_metadata_allow_sync
+    _launch_metadata_allow_sync = allow_sync
diff --git a/vllm/kvprune/triton_kernels/reduction_details/__init__.py b/vllm/kvprune/triton_kernels/reduction_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/reduction_details/reduce_bitmatrix.py b/vllm/kvprune/triton_kernels/reduction_details/reduce_bitmatrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..398482c321e119dfeb059fc420341ca58d1cceb1
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/reduction_details/reduce_bitmatrix.py
@@ -0,0 +1,133 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def vpopc(x):
+    """
+    Vertical popcount
+    Input  x : uint32[..., N]
+    Output y : uint32[..., 32]
+    semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
+    credits: @apgoucher
+
+    The reduction over N runs in three stages that keep several partial
+    counts packed in one 32-bit word: first in 4-bit fields, then in 8-bit
+    fields, and finally in full 32-bit values.
+    """
+
+    tl.static_assert(
+        x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers"
+    )
+
+    BLOCK_N: tl.constexpr = x.shape[-1]  # summation axis
+    BATCHES: tl.constexpr = x.numel // BLOCK_N  # number of batches
+    # Stage 1 group size: at most 8 words are added together, so each 4-bit
+    # field (holding a single bit per word) cannot overflow.
+    if BLOCK_N >= 8:
+        sa1: tl.constexpr = 8
+    else:
+        sa1: tl.constexpr = BLOCK_N
+    # create 8-way sums in 4-bit fields:
+    y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
+    # Shifting by 0..3 and masking with 0x11111111 leaves one input bit in the
+    # low bit of every 4-bit field; the new trailing axis indexes the shift.
+    y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
+    y = tl.sum(y, 2)  # [BATCHES, BLOCK_N // sa1, 4]
+    # Stage 2 group size: at most 16 stage-1 results (each field <= 8) are
+    # added, so every 8-bit field stays <= 128.
+    if BLOCK_N >= 128:
+        sa2: tl.constexpr = 16
+    else:
+        sa2: tl.constexpr = BLOCK_N // sa1
+    # create 128-way sums in 8-bit fields:
+    y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
+    # Split each word's 4-bit fields into two words of 8-bit fields (shift 0 / 4).
+    y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0F0F0F0F
+    y = tl.sum(y, 2)  # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
+    # Stage 3 reduces whatever groups remain.
+    sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
+    # create N-way sums in 32-bit fields:
+    y = tl.reshape(y, [BATCHES, 1, sa3, 8])
+    # Extract each of the four 8-bit fields into its own full-width value.
+    y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000FF
+    y = tl.sum(y, 2)  # [BATCHES, 4, 8]
+    y = tl.reshape(y, x.shape[:-1] + [32])
+    return y
+
+
+@triton.jit
+def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
+    # Zero-fill kernel: program `pid` writes 0 to Ret[pid * BLOCK : (pid + 1) * BLOCK].
+    # The store is unmasked, so the launch grid times BLOCK must not exceed the
+    # number of elements allocated behind `Ret`.
+    pid = tl.program_id(0)
+    offs = pid * BLOCK + tl.arange(0, BLOCK)
+    tl.store(Ret + offs, 0)
+
+
+@triton.jit
+def _sum_bitmatrix_rows(
+    B,
+    shape_bm,
+    stride_bm: tl.constexpr,
+    stride_bn: tl.constexpr,  # input bitmatrix
+    Ret,
+    Partials,
+    stride_pm: tl.constexpr,
+    stride_pn,
+    shape_pn,  # outputs
+    BLOCK_MM: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+):
+    # Column-wise bit counts of a bitmatrix stored as 32-bit words.
+    #
+    # Program (pid_m, pid_n) reads BLOCK_MM rows of word-column `pid_n`, splits them
+    # into TILE_SIZE groups of BLOCK_M rows, and computes a 32-wide vertical popcount
+    # per group. It then
+    #   * atomically adds the sum over its groups into Ret[pid_n * 32 : pid_n * 32 + 32], and
+    #   * stores each group's counts into Partials (one row per group).
+    #
+    # `shape_bm` is the number of valid rows; it may be passed as a pointer, in which
+    # case it is loaded here. `shape_pn` is not used in this body.
+    tl.static_assert(BLOCK_MM % BLOCK_M == 0)
+    TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
+    if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
+        shape_bm = tl.load(shape_bm)
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
+    offs_n = pid_n * 32 + tl.arange(0, 32)
+    n_rows = shape_bm
+    # Rows past n_rows read as 0 so they contribute nothing to the counts.
+    bits = tl.load(
+        B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0
+    )
+    bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
+    ret = vpopc(bits)  # [TILE_SIZE, 32]
+
+    offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
+
+    # Ret is accumulated into, not overwritten.
+    tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
+    # Unmasked store: Partials must cover every (offs_t, offs_n) this grid produces.
+    tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
+
+
+def clear_sums(n_cols, device, MEMSET_BLOCK=512):
+    """Allocate an int32 buffer on `device` and zero it with `_sum_bitmatrix_memset`.
+
+    The buffer length is `n_cols` rounded up to a multiple of `MEMSET_BLOCK`,
+    so the returned tensor may be longer than `n_cols`.
+    """
+    n_programs = triton.cdiv(n_cols, MEMSET_BLOCK)
+    padded_len = n_programs * MEMSET_BLOCK
+    zeroed = torch.empty((padded_len,), device=device, dtype=torch.int32)
+    # One program per MEMSET_BLOCK-sized chunk.
+    _sum_bitmatrix_memset[(n_programs,)](zeroed, MEMSET_BLOCK)
+    return zeroed
+
+
+def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
+    """Launch `_sum_bitmatrix_rows` over bitmatrix `x`.
+
+    Args:
+        x: bitmatrix object; this function reads `x.shape`, `x.shape_max`,
+            `x.stride(...)` and `x.storage.data`.
+        out_ret: tensor of shape `(n_cols,)` that the kernel atomically adds the
+            per-column totals into (presumably pre-zeroed by the caller -- verify).
+        partials_block_size: number of rows covered by each partial-sum row.
+            Required despite the `None` default.
+
+    Returns:
+        `(out_ret, out_partials)`, where `out_partials` holds one row of counts
+        per `partials_block_size` rows of `x.shape_max[0]`.
+    """
+    assert partials_block_size is not None
+    n_rows, n_cols = x.shape
+    n_rows_max = x.shape_max[0]
+    assert out_ret.shape == (n_cols,)
+
+    # Each program covers at least 128 rows when the partial block is smaller than that.
+    tile_size = max(1, 128 // partials_block_size)
+    block_mm = partials_block_size * tile_size
+
+    grid_m = triton.cdiv(n_rows_max, block_mm)
+    grid_n = triton.cdiv(n_cols, 32)
+
+    # Allocated column-major, then viewed as [partial rows, columns].
+    partials_storage = torch.empty(
+        (grid_n * 32, grid_m * tile_size), device=out_ret.device, dtype=torch.int32
+    )
+    out_partials = torch.transpose(partials_storage, 0, 1)
+
+    _sum_bitmatrix_rows[(grid_m, grid_n)](
+        # input
+        x.storage.data,
+        n_rows,
+        x.stride(0),
+        x.stride(1),
+        # output [final reduction]
+        out_ret,
+        # output [partial reductions]
+        out_partials,
+        out_partials.stride(0),
+        out_partials.stride(1),
+        out_partials.shape[1],
+        # constants
+        BLOCK_M=partials_block_size,
+        BLOCK_MM=block_mm,
+        num_warps=8,
+    )
+
+    # Drop the partial rows that exist only because of grid rounding.
+    n_partial_rows = triton.cdiv(n_rows_max, partials_block_size)
+    return out_ret, out_partials[:n_partial_rows, :]
diff --git a/vllm/kvprune/triton_kernels/routing.py b/vllm/kvprune/triton_kernels/routing.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd736f6f0867b95c67a3c857b4f0bcc80c79fc0
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/routing.py
@@ -0,0 +1,521 @@
+import torch
+import triton
+from dataclasses import dataclass, field
+from .routing_details._routing_compute import _combined_routing_compute
+from .routing_details._routing_compute import _combined_routing_memset
+from .routing_details._routing_compute import _routing_clear_bitmatrix
+from .routing_details._expt_data import _expt_data_memset
+from .routing_details._expt_data import _expt_data_compute
+from .target_info import is_hip
+
+
+@dataclass
+class GatherIndx:
+    """
+    Indices for an operation that performs:
+    Y = X[src_indx, :]
+    """
+
+    # The two arrays are related by `dst_indx[src_indx] = arange(0, N)`:
+    # src_indx is the gather index used above, dst_indx is the companion array.
+    src_indx: torch.Tensor
+    dst_indx: torch.Tensor
+
+
+@dataclass
+class ScatterIndx:
+    """
+    Indices for an operation that performs:
+    Y[dst_indx, :] = X
+    """
+
+    # The two arrays are related by `dst_indx[src_indx] = arange(0, N)`:
+    # dst_indx is the scatter index used above, src_indx is the companion array.
+    src_indx: torch.Tensor
+    dst_indx: torch.Tensor
+
+
+@dataclass
+class ExptData:
+    """Per-expert bookkeeping tensors; every tensor that is present must be int32."""
+
+    # hist[i]: number of tokens routed to expert i.
+    hist: torch.Tensor
+    # token_offs_raw[i]: offset of the first token routed to expert i
+    # in an expert-sorted array.
+    token_offs_raw: torch.Tensor
+    # token_offs_pad[block][i]: same offset, but computed as if the histogram
+    # were rounded up to the next multiple of `block`.
+    token_offs_pad: dict[int, torch.Tensor]
+    # block_pid_map[block] holds one value per `pid` launched by the matrix
+    # multiplication kernel run with BLOCK_M=block:
+    #  - -1 when the `pid` has no work to do;
+    #  - otherwise two int16 packed into one int32: (1) the expert assigned to
+    #    the tokens processed by this pid and (2) the block assigned to those
+    #    tokens (think `pid_m` in a regular matmul).
+    # See `test_routing.py` for a reference implementation and more details.
+    block_pid_map: dict[int, torch.Tensor]
+
+    def __post_init__(self):
+        """Assert that every tensor held by this instance has dtype int32."""
+
+        def _held_tensors():
+            # Lazily yield the tensors in field order; `None` fields are skipped.
+            if self.hist is not None:
+                yield self.hist
+            if self.token_offs_raw is not None:
+                yield self.token_offs_raw
+            if self.token_offs_pad is not None:
+                yield from self.token_offs_pad.values()
+            if self.block_pid_map is not None:
+                yield from self.block_pid_map.values()
+
+        for tensor in _held_tensors():
+            assert tensor.dtype == torch.int32
+
+
+@dataclass
+class RoutingData:
+    """Routing result handed to downstream kernels: gate scales, histogram and expert counts."""
+
+    gate_scal: torch.Tensor = field()
+    expt_hist: torch.Tensor = field()
+    n_expts_tot: int = field()
+    n_expts_act: int = field()
+    expt_data: ExptData = None
+
+    # Keeps perf annotations clean under expert sharding: the "expected" number
+    # of local tokens per expert, since the actual number varies per input.
+    expected_tokens_per_expt: int = field(default=None)
+
+    def n_blocks(self, n_rows, block_m):
+        """Return the block count for `n_rows` rows tiled with `block_m` rows per block.
+
+        With no more rows than experts the result is `n_rows`. Otherwise it is
+        ceil((n_rows - n_expts_tot + 1) / block_m) + n_expts_tot - 1.
+        """
+        n_expts = self.n_expts_tot
+        if n_rows <= n_expts:
+            return n_rows
+        spill_rows = max(n_rows - n_expts + 1, 0)
+        return triton.cdiv(spill_rows, block_m) + n_expts - 1
+
+
+# --------------------------
+# sort tokens by expert
+# --------------------------
+
+
+class SortTokens(torch.autograd.Function):
+    # Autograd wrapper around the two combined routing kernels
+    # (`_combined_routing_memset`, then `_combined_routing_compute`).
+    # Only `expt_scal` receives a gradient in `backward`.
+
+    @staticmethod
+    def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
+        # Inputs:
+        #   expt_scal: 2-D, unpacked below as (n_tokens_pad, n_expts_act).
+        #   expt_indx: passed straight to the compute kernel.
+        #   n_expts_tot: number of experts; the histogram is truncated to this length.
+        #   bitmatrix: object providing `.shape` and `.sum(partials_block_size=...)`.
+        # Returns (hist, topk_indx, gate_indx, gate_scal, token_offs_raw,
+        #          token_offs_pad, block_pid_map).
+        # NOTE(review): the trailing "# inputs"/"# outputs" labels inside the two
+        # launches are inherited from the original call layout; the kernel
+        # signatures are not visible here, so treat them as approximate.
+        HIST_BLOCK_M = 32
+        INDX_OFFS_BLOCK_M = 512
+        MEMSET_BLOCK = 1024
+        cdiv = triton.cdiv
+
+        device = expt_scal.device
+        dtype = expt_scal.dtype
+        n_tokens_raw, _ = bitmatrix.shape
+        n_tokens_pad, n_expts_act = expt_scal.shape
+        n_gates_pad = n_tokens_pad * n_expts_act
+
+        # Full histogram plus per-HIST_BLOCK_M partial histograms.
+        hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
+        hist = hist[:n_expts_tot]
+        assert hist.dtype == torch.int32
+        # scratchpad
+        expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
+        # topk_indx and gate_indx are two halves of one allocation so that a
+        # single memset launch can cover both.
+        combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
+        # output
+        topk_indx = combined_indx[:n_gates_pad]
+        gate_indx = combined_indx[n_gates_pad:]
+        gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
+
+        # Allocates the expert-data buffers and returns the launch parameters
+        # for the expert-data part of the two combined kernels.
+        (
+            token_offs_combined,
+            token_offs_raw,
+            token_offs_pad,
+            block_pid_map,
+            blocks1a,
+            blocks2a,
+            MEMSET_BLOCK_A,
+            HIST2_BLOCK_M,
+            block_m_log2_start,
+            block_m_num,
+        ) = _compute_expt_data_internal(hist, n_expts_tot, n_gates_pad)
+
+        # Each launch's grid is the expert-data part ("a") plus the routing part ("b").
+        blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
+        blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
+
+        _combined_routing_memset[(blocks1a + blocks1b,)](
+            combined_indx,
+            n_gates_pad * 2,
+            -1,
+            MEMSET_BLOCK,
+            hist,  #
+            expt_offs,
+            hist.shape[0],
+            n_expts_tot,
+            partial_hist,  # inputs
+            partial_hist.shape[0],
+            partial_hist.stride(0),
+            partial_hist.stride(1),  # outputs
+            token_offs_combined,
+            token_offs_combined.stride(0),  #
+            blocks1a,
+            block_pid_map,  #
+            block_m_log2_start,
+            SIZES=block_m_num,
+            BLOCK_A=MEMSET_BLOCK_A,  # optimization parameters
+            BLOCK_N=512,
+            BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
+        )
+
+        # The same buffer is handed to the second kernel under a new name.
+        indx_offs = partial_hist
+
+        _combined_routing_compute[(blocks2a + blocks2b,)](
+            topk_indx,
+            gate_indx,
+            gate_scal,  # outputs
+            expt_scal,
+            expt_indx,
+            indx_offs,
+            indx_offs.stride(0),
+            indx_offs.stride(1),  # inputs
+            expt_offs,
+            n_tokens_raw,  # input shape
+            HIST_BLOCK_M,
+            n_expts_act,  # constants
+            hist,
+            token_offs_pad,
+            token_offs_pad.stride(0),
+            block_pid_map,
+            block_pid_map.stride(0),  # outputs
+            block_m_log2_start,
+            block_m_num,
+            HIST2_BLOCK_M,
+            blocks2a,  # etc.
+        )
+
+        # State needed by backward.
+        ctx.n_tokens_raw = n_tokens_raw
+        ctx.n_tokens_pad = n_tokens_pad
+        ctx.n_expts_act = n_expts_act
+        ctx.save_for_backward(gate_indx)
+        return (
+            hist,
+            topk_indx,
+            gate_indx,
+            gate_scal,
+            token_offs_raw,
+            token_offs_pad,
+            block_pid_map,
+        )
+
+    @staticmethod
+    def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
+        # Only the gradient of `gate_scal` (4th output) is used. It is gathered
+        # through `gate_indx` and reshaped to (n_tokens_pad, n_expts_act) to
+        # form the gradient of `expt_scal`; the other three inputs get None.
+        (gate_indx,) = ctx.saved_tensors
+        dgate_scal = dgate_scal[gate_indx]
+        dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
+        return dgate_scal, None, None, None
+
+
+def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
+    """Apply `SortTokens` and return its 7-tuple of outputs unchanged."""
+    sorted_outputs = SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
+    return sorted_outputs
+
+
+# --------------------------
+# prune routing
+# --------------------------
+
+
+class PruneRouting(torch.autograd.Function):
+    """Restrict routing to the first `n_expts_tot // simulated_ep` experts (forward only)."""
+
+    @staticmethod
+    def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+        """Clear the bitmatrix, compact the routing tensors and shrink the bitmatrix shape.
+
+        `bitmatrix` is modified in place: its storage is passed to
+        `_routing_clear_bitmatrix` and its last shape entry is overwritten
+        with the reduced expert count.
+
+        Returns:
+            `(expt_scal, expt_indx, bitmatrix)` with the first two replaced by
+            the outputs of `compaction`.
+        """
+        from .compaction import compaction
+
+        n_tokens_pad = expt_scal.shape[0]
+        assert n_expts_tot % simulated_ep == 0
+        n_expts_kept = n_expts_tot // simulated_ep
+
+        bits = bitmatrix.storage.data
+        # One program per (padded) token row.
+        _routing_clear_bitmatrix[(n_tokens_pad,)](
+            bits,
+            bits.stride(0),
+            bits.stride(1),
+            bits.shape[1],
+            n_expts_kept,
+            BLOCK_N=512,
+        )
+        # perform compaction to update expt_scal / expt_indx
+        compacted_scal, compacted_indx = compaction(expt_scal, expt_indx, bitmatrix)
+        bitmatrix.shape[-1] = n_expts_kept
+        return compacted_scal, compacted_indx, bitmatrix
+
+
+def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+    """Apply `PruneRouting` and return its `(expt_scal, expt_indx, bitmatrix)` result."""
+    pruned = PruneRouting.apply(
+        expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep
+    )
+    return pruned
+
+
+# --------------------------
+# expt_data
+# --------------------------
+
+
+def log2_power_of_two(x):
+    """Return the exponent `k` such that `x == 2**k`.
+
+    Raises AssertionError when `x` is not a positive power of two.
+    """
+    # A power of two has exactly one set bit, so clearing the lowest set bit gives 0.
+    is_power_of_two = x > 0 and (x & (x - 1)) == 0
+    assert is_power_of_two, "x must be a power of two"
+    # x - 1 is then a run of k one-bits.
+    return (x - 1).bit_length()
+
+
+# log2 of the smallest block_m for which expert data is produced (2**4 == 16).
+block_m_log2_start = 4
+
+
+def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
+    """Allocate the expert-data buffers and compute the matching launch parameters.
+
+    No kernel is launched here; the returned buffers are uninitialized.
+
+    Returns a 10-tuple:
+        token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map,
+        blocks1 (memset grid size), blocks2 (compute grid size),
+        MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num.
+    `token_offs_raw`, `token_offs_pad` and `block_pid_map` are views into
+    allocations padded to a multiple of MEMSET_BLOCK.
+    """
+    MEMSET_BLOCK = 512
+    HIST2_BLOCK_M = 512
+    device = expt_hist.device
+    int_dtype = torch.int32
+
+    def _round_up_to_block(n):
+        # Smallest multiple of MEMSET_BLOCK that is >= n.
+        return triton.cdiv(n, MEMSET_BLOCK) * MEMSET_BLOCK
+
+    # block_ms are the powers of two from 2**block_m_log2_start up to
+    # 2**(block_m_log2_end - 1): 16..128, or 16..256 when is_hip() is true.
+    block_m_log2_end = 9 if is_hip() else 8
+    block_m_num = block_m_log2_end - block_m_log2_start
+
+    # Tile count for the smallest block size; the floor division of a negative
+    # numerator implements a ceiling division.
+    max_n_tiles = (
+        n_gates
+        if n_gates <= n_expts_tot
+        else n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
+    )
+
+    # Row 0 holds the raw token offsets, rows 1.. hold one padded variant per block size.
+    n_offs = n_expts_tot + 1
+    token_offs_combined = torch.empty(
+        (block_m_num + 1, _round_up_to_block(n_offs)), dtype=int_dtype, device=device
+    )
+    token_offs_raw = token_offs_combined[0][:n_offs]
+    token_offs_pad = token_offs_combined[1:, :n_offs]
+
+    pid_map_storage = torch.empty(
+        (block_m_num, _round_up_to_block(max_n_tiles)), dtype=int_dtype, device=device
+    )
+    # Exact division: the storage width is a multiple of MEMSET_BLOCK.
+    memset_grid = torch.numel(pid_map_storage) // MEMSET_BLOCK
+    block_pid_map = pid_map_storage[:, :max_n_tiles]
+
+    blocks1 = memset_grid + block_m_num + 1
+    blocks2 = n_expts_tot * block_m_num
+    return (
+        token_offs_combined,
+        token_offs_raw,
+        token_offs_pad,
+        block_pid_map,
+        blocks1,
+        blocks2,
+        MEMSET_BLOCK,
+        HIST2_BLOCK_M,
+        block_m_log2_start,
+        block_m_num,
+    )
+
+
+def _unpack_into_dict(x):
+    """Split the rows of 2-D `x` into a dict keyed by block size.
+
+    Row `i` is stored under key `2 ** (block_m_log2_start + i)`.
+    """
+    n_block_sizes = x.shape[0]
+    return {
+        2 ** (block_m_log2_start + row): x[row, :] for row in range(n_block_sizes)
+    }
+
+
+def compute_expt_data(expt_hist, n_expts_tot, n_gates):
+    """Build an `ExptData` from an expert histogram using the Triton kernels.
+
+    Returns an all-`None` `ExptData` when `expt_hist` is `None`. Otherwise the
+    buffers are allocated, filled by `_expt_data_memset` followed by
+    `_expt_data_compute`, and the per-block-size rows are unpacked into dicts.
+    """
+    if expt_hist is None:
+        return ExptData(None, None, None, None)
+
+    # this just computes the kernel arguments:
+    (
+        offs_combined,
+        offs_raw,
+        offs_pad,
+        pid_map,
+        n_memset_programs,
+        n_compute_programs,
+        memset_block,
+        hist2_block_m,
+        log2_start,
+        n_block_sizes,
+    ) = _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates)
+
+    memset_launch = _expt_data_memset[(n_memset_programs,)]
+    memset_launch(
+        expt_hist,
+        n_expts_tot,
+        offs_combined,
+        offs_combined.stride(0),
+        pid_map,
+        log2_start,
+        SIZES=n_block_sizes,
+        BLOCK=memset_block,  # optimization parameters
+        num_warps=4,
+    )
+
+    compute_launch = _expt_data_compute[(n_compute_programs,)]
+    compute_launch(
+        expt_hist,
+        offs_pad,
+        offs_pad.stride(0),
+        pid_map,
+        pid_map.stride(0),  # outputs
+        log2_start,
+        SIZES=n_block_sizes,
+        BLOCK=hist2_block_m,  # optimization parameters
+        num_warps=4,
+    )
+
+    offs_pad_by_block = _unpack_into_dict(offs_pad)
+    pid_map_by_block = _unpack_into_dict(pid_map)
+    return ExptData(expt_hist, offs_raw, offs_pad_by_block, pid_map_by_block)
+
+
+# --------------------------
+# routing
+# --------------------------
+
+
+def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
+    """Sort tokens by expert and package the result for the matmul.
+
+    Returns `(RoutingData, GatherIndx, ScatterIndx)`. The gather and scatter
+    index objects hold the same two arrays with their roles swapped.
+    """
+    sorted_outputs = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix)
+    (
+        hist,
+        topk_indx,
+        gate_indx,
+        gate_scal,
+        token_offs_raw,
+        packed_offs_pad,
+        packed_pid_map,
+    ) = sorted_outputs
+
+    # One dict entry per block size.
+    offs_pad_by_block = _unpack_into_dict(packed_offs_pad)
+    pid_map_by_block = _unpack_into_dict(packed_pid_map)
+    expt_data = ExptData(hist, token_offs_raw, offs_pad_by_block, pid_map_by_block)
+
+    # pack the matmul data structure
+    gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
+    scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
+    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
+    return routing_data, gather_indx, scatter_indx
+
+
+def routing(
+    logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None
+):
+    """Compute top-k routing from `logits` and return `(RoutingData, GatherIndx, ScatterIndx)`.
+
+    Args:
+        logits: router scores; the last dimension is the number of experts.
+        n_expts_act: number of experts selected per token.
+        sm_first: when true, softmax is applied to `logits` before top-k;
+            otherwise `topk` is asked to apply it (`apply_softmax=True`).
+        expt_indx: forwarded to `topk` as `y_indx`.
+        simulated_ep: when > 1, routing is pruned to
+            `logits.shape[-1] // simulated_ep` experts via `prune_routing`.
+        n_rows: forwarded to `topk`.
+    """
+    from .topk import topk
+
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)
+    expt_scal, expt_indx, bitmatrix = topk(
+        logits,
+        n_expts_act,
+        apply_softmax=not sm_first,
+        y_indx=expt_indx,
+        n_rows=n_rows,
+    )
+
+    n_expts_global = logits.shape[-1]
+    n_expts_tot = n_expts_global // simulated_ep
+    # mutate bitmatrix
+    if simulated_ep > 1:
+        pruned = prune_routing(
+            expt_scal, expt_indx, bitmatrix, n_expts_global, simulated_ep
+        )
+        expt_scal, expt_indx, bitmatrix = pruned
+
+    return routing_from_bitmatrix(
+        bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+    )
+
+
+# --------------------------
+# torch reference
+# --------------------------
+
+
+def compute_expt_data_torch(hist, n_expts_tot, n_gates):
+    """Pure-torch reference that builds an `ExptData` from the histogram `hist`.
+
+    For every block size (16, 32, 64, 128, plus 256 when `is_hip()`), computes
+    the exclusive prefix sum of per-expert tile counts and a pid map whose
+    used entries are `(block << 16) + expert` and whose unused entries are -1.
+    """
+    device = hist.device
+
+    # offset for each expert: exclusive prefix sum with the total appended
+    token_offs_raw = torch.cat(
+        (torch.zeros(1, device=device), torch.cumsum(hist, dim=0))
+    ).int()
+
+    # maximum number of tiles for all values of `block_m` considered
+    block_ms = [16, 32, 64, 128]
+    if is_hip():
+        block_ms.append(256)
+    if n_gates <= n_expts_tot:
+        max_n_tiles = n_gates
+    else:
+        # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
+        # ceil_div(x, y): -(-x // y)
+        max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
+
+    # fill up tile offset/infos for each block
+    token_offs_pad = {}
+    block_pid_map = {}
+    for block_m in block_ms:
+        # matmul blocks needed per expert
+        n_tiles = (hist + block_m - 1) // block_m
+        tile_offs = torch.cat(
+            (torch.zeros(1, device=device), torch.cumsum(n_tiles, dim=0))
+        ).int()
+
+        # compute data required to drive ragged batch matmul
+        pid_map = -torch.ones(max_n_tiles, dtype=torch.int32, device=device)
+
+        # Vectorized form of:
+        #   for e in range(n_expts_tot):
+        #       offset = tile_offs[e]
+        #       for b in range(n_tiles[e]):
+        #           pid_map[offset + b] = (b << 16) + e
+        col = torch.arange(max_n_tiles, device=device)
+        expert_ids = torch.arange(n_expts_tot, device=device)
+        map_vals = expert_ids[:, None] + (col << 16)[None, :]
+        map_idxs = tile_offs[:-1, None] + col[None, :]
+        in_range = col[None, :] < n_tiles[:, None]
+        pid_map.index_put_((map_idxs[in_range],), map_vals.int()[in_range])
+
+        token_offs_pad[block_m] = tile_offs
+        block_pid_map[block_m] = pid_map
+    return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
+
+def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):
+    """Select `k` entries per row of `vals` and return `(values, int32 indices)`.
+
+    When `has_user_provided_indx` is true, `expt_indx` is used as the
+    selection; otherwise the indices of the `k` largest entries per row are
+    used, in descending order of value with a stable sort.
+    """
+    if not has_user_provided_indx:
+        # topk of experts
+        descending_order = torch.argsort(-vals, dim=1, stable=True)
+        expt_indx = descending_order[:, :k]
+    # take_along_dim needs int64 indices; callers get int32 back.
+    indx_i64 = expt_indx.long()
+    selected_vals = torch.take_along_dim(vals, indx_i64, dim=1)
+    return selected_vals, indx_i64.int()
+
+
+def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
+    """Pure-torch reference for `routing`; returns `(RoutingData, GatherIndx, ScatterIndx)`.
+
+    Args:
+        logits: 2-D router scores, one row per token.
+        n_expts_act: number of experts selected per token.
+        sm_first: apply softmax over all experts before top-k when true,
+            otherwise over the selected scores after top-k.
+        expt_indx: optional user-provided selection; when given it is used
+            as-is instead of computing (and sorting) the top-k indices.
+        n_rows: when given, only the first `n_rows` rows of `logits` are routed.
+            The padded gate count is still taken from the full `logits`.
+    """
+    user_provided = expt_indx is not None
+    n_gates_pad = logits.shape[0] * n_expts_act
+
+    if n_rows is not None:
+        logits = logits[:n_rows, :]
+    _, n_expts_tot = logits.shape
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)
+    expt_scal, expt_indx = topk_torch(
+        logits, n_expts_act, expt_indx, has_user_provided_indx=user_provided
+    )
+    if not sm_first:
+        expt_scal = torch.softmax(expt_scal, dim=-1)
+    # sort each token's selections by expert
+    if not user_provided:
+        expt_indx, per_row_order = torch.sort(expt_indx, dim=1)
+        expt_scal = torch.gather(expt_scal, 1, per_row_order)
+
+    # flatten topk data
+    flat_scal = expt_scal.reshape(-1)
+    flat_indx = expt_indx.reshape(-1).to(torch.int32)
+
+    # sort by expert_id so experts are contiguous for the matmul
+    topk_indx = torch.argsort(flat_indx, stable=True)
+    gate_indx = torch.argsort(topk_indx, stable=True)
+    gate_scal = flat_scal[topk_indx]
+
+    # histogram of tokens over experts
+    hist = torch.histc(flat_indx, bins=n_expts_tot, max=n_expts_tot - 1).int()
+
+    # pack the matmul data structure
+    gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
+    scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
+    # compute expt_data
+    expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
+    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
+    return routing_data, gather_indx, scatter_indx
diff --git a/vllm/kvprune/triton_kernels/routing_details/__init__.py b/vllm/kvprune/triton_kernels/routing_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/routing_details/_expt_data.py b/vllm/kvprune/triton_kernels/routing_details/_expt_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd625868fb668d1a317e193ec4d5ec24a4da6206
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/routing_details/_expt_data.py
@@ -0,0 +1,75 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _cdiv_pow2(n, log2_k):
+    # Ceiling division of n by 2**log2_k: add (2**log2_k - 1), then shift right.
+    return (n + ((1 << log2_k) - 1)) >> log2_k
+
+
+@triton.jit
+def _expt_data_memset(
+    Hist,
+    n_expts_tot,
+    MDStarts,
+    tile_starts_stridem,
+    MDTileInfo,
+    first_tile_dim_log2,
+    SIZES: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    # Two jobs share one grid, selected by program id:
+    #   * pid in [0, SIZES]: fill row `pid` of MDStarts with the exclusive prefix sum,
+    #     over experts, of ceil(Hist[e] / 2**tile_dim_log2). Row 0 uses
+    #     tile_dim_log2 == 0 (raw token counts); row r > 0 uses
+    #     first_tile_dim_log2 + r - 1.
+    #   * pid > SIZES: fill one BLOCK-sized chunk of MDTileInfo with 0xFFFFFFFF.
+    # All stores are unmasked, so both buffers must be padded to whole BLOCKs.
+    pid = tl.program_id(0)
+
+    if pid <= SIZES:
+        MDStarts += pid * tile_starts_stridem
+        # Running total carried across BLOCK-sized chunks of the histogram.
+        x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
+        Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
+        tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
+
+        # The range covers n_expts_tot + 1 entries, so the slot just past the
+        # last expert is written too.
+        for i in range(0, n_expts_tot + 1, BLOCK):
+            offs_n = tl.arange(0, BLOCK) + i
+            mask_n0 = offs_n < n_expts_tot
+            # Entries past the last expert read as 0 and leave the running sum unchanged.
+            hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
+            hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
+
+            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
+            x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
+            # Inclusive sum minus the element itself gives the exclusive sum.
+            tl.store(Tile_ptrs, tile_starts - hist_tile)
+            Tile_ptrs += BLOCK
+
+    else:
+        pid -= SIZES + 1
+        TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
+        # All-ones bit pattern (presumably read back as -1 from an int32 buffer -- verify).
+        tl.store(TileInfoOut, 0xFFFFFFFF)
+
+
+@triton.jit
+def _expt_data_compute(
+    Hist,
+    MDTileStarts,
+    tile_starts_stridem,
+    MDTileInfo,
+    tile_info_stridem,
+    first_tile_dim_log2,
+    SIZES: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    # One program per (expert, block-size) pair: writes that expert's entries
+    # of MDTileInfo for the block size 2**(first_tile_dim_log2 + buff_id).
+    # Entry b of the expert is stored at MDTileStarts[buff_id][expt_id] + b
+    # with value (b << 16) + expt_id.
+    pid = tl.program_id(0)
+
+    expt_id = pid // SIZES
+    buff_id = pid % SIZES
+
+    # Select the row belonging to this block size in both buffers.
+    MDTileStarts += buff_id * tile_starts_stridem
+    MDTileInfo += buff_id * tile_info_stridem
+
+    n_tokens = tl.load(Hist + expt_id)
+    tile_dim_log2 = first_tile_dim_log2 + buff_id
+    n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
+
+    # First slot of this expert within the row.
+    tile_off = tl.load(MDTileStarts + expt_id)
+    MDTileInfo += tile_off
+
+    for block_off in range(0, n_blocks, BLOCK):
+        block_offs = block_off + tl.arange(0, BLOCK)
+        # Block index in the upper bits, expert id in the lower 16 bits.
+        data = (block_offs << 16) + expt_id
+        tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
diff --git a/vllm/kvprune/triton_kernels/routing_details/_routing_compute.py b/vllm/kvprune/triton_kernels/routing_details/_routing_compute.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b097cc1cc8c1117363f031cfc9a785b94a7d5ed
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/routing_details/_routing_compute.py
@@ -0,0 +1,241 @@
+import triton
+import triton.language as tl
+
+from ._expt_data import _expt_data_compute, _expt_data_memset
+
+
+@triton.jit
+def _routing_compute_expt_offs(
+ ExpertHist,
+ FinalExpertOffs,
+ hist_size, # histogram
+ BLOCK_N: tl.constexpr,
+):
+ loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
+ x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
+ for i in range(loop_iterations):
+ offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < hist_size
+ hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
+ tok_starts = tl.cumsum(hist2, 0) - hist2 + x
+ x += tl.sum(hist2, 0)
+ tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
+ offs_n += BLOCK_N
+
+
+@triton.jit
+def _routing_compute_indx_offs(
+ PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id
+):
+ offs_m = tl.arange(0, BLOCK_M)
+ # iterate over input data
+ curr_sum = 0
+ for _ in range(0, shape_pm, BLOCK_M):
+ offs = offs_m * stride_pm + expt_id * stride_pn
+ curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
+ out = tl.cumsum(curr, 0) + curr_sum
+ curr_sum += tl.sum(curr, 0)
+ tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
+ offs_m += BLOCK_M
+
+
+@triton.jit
+def _keyed_add(x, y):
+ # we keep the key in the upper 16 bits of a uint32:
+ key_mask: tl.constexpr = 0xFFFF0000
+
+ kx = x & key_mask
+ ky = y & key_mask
+ z = tl.where(kx == ky, x + y - kx, y)
+ return z
+
+
+@triton.jit
+def _routing_compute_indx(
+ pid_m,
+ GatherIndx,
+ ScatterIndx,
+ GateScal,
+ ExptScal,
+ ExptIndx,
+ PartialOffs,
+ stride_pm,
+ stride_pn,
+ TokensStart,
+ n_tokens,
+ BLOCK_M: tl.constexpr,
+ N_EXPTS_ACT: tl.constexpr,
+):
+ if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
+ n_tokens = tl.load(n_tokens)
+ n_gates = n_tokens * N_EXPTS_ACT
+
+ tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
+
+ local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
+ offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
+ expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
+
+ # stable-sort by expert ID:
+ kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
+ kv_pairs = tl.sort(kv_pairs, 0)
+ expert = kv_pairs >> 16
+ offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xFFFF)
+ mask = expert != 0xFFFF
+ gate_scal = tl.load(ExptScal + offs, mask=mask)
+
+ # compute run lengths in expert-sorted order:
+ x = kv_pairs & 0xFFFF0000 | 0x00000001
+ expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
+ exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF
+
+ gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
+ gates += tl.load(TokensStart + expert, mask=mask)
+ gates += exclusive_run_lengths
+
+ tl.store(ScatterIndx + offs, gates, mask=mask)
+ tl.store(GatherIndx + gates, offs, mask=mask)
+ tl.store(GateScal + gates, gate_scal, mask=mask)
+
+
+@triton.jit
+def _combined_routing_compute(
+ GatherIndx,
+ ScatterIndx,
+ GateScal,
+ ExptScal,
+ ExptIndx,
+ PartialOffs,
+ stride_pm,
+ stride_pn,
+ TokensStart,
+ n_tokens,
+ BLOCK_M: tl.constexpr,
+ N_EXPTS_ACT: tl.constexpr,
+ Hist,
+ MDTileStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ tile_info_stridem,
+ first_tile_dim_log2,
+ SIZES: tl.constexpr,
+ BLOCK: tl.constexpr,
+ blocks2a,
+):
+ pid = tl.program_id(0)
+ if pid < blocks2a:
+ _expt_data_compute(
+ Hist,
+ MDTileStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ tile_info_stridem,
+ first_tile_dim_log2,
+ SIZES,
+ BLOCK,
+ )
+ else:
+ pid -= blocks2a
+ _routing_compute_indx(
+ pid,
+ GatherIndx,
+ ScatterIndx,
+ GateScal,
+ ExptScal,
+ ExptIndx,
+ PartialOffs,
+ stride_pm,
+ stride_pn,
+ TokensStart,
+ n_tokens,
+ BLOCK_M,
+ N_EXPTS_ACT,
+ )
+
+
+@triton.jit
+def _routing_clear_bitmatrix(
+ Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr
+):
+ pid_m = tl.program_id(0)
+ cutoff_word = cutoff // 32
+ cutoff_bit = cutoff % 32
+ cutoff_mask = (1 << (cutoff_bit)) - 1
+ for start_n in range(0, shape_bn, BLOCK_N):
+ offs_n = start_n + tl.arange(0, BLOCK_N)
+ values = tl.load(
+ Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn
+ )
+ values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
+ values = tl.where(offs_n > cutoff_word, 0, values)
+ tl.store(
+ Bitmatrix + pid_m * stride_bm + offs_n * stride_bn,
+ values,
+ mask=offs_n < shape_bn,
+ )
+
+
+@triton.jit
+def _combined_routing_memset(
+ Indx,
+ size,
+ sentinel,
+ BLOCK: tl.constexpr,
+ ExpertHist,
+ FinalExpertOffs,
+ hist_size,
+ n_expts_tot,
+ PartialHist,
+ shape_pm,
+ stride_pm,
+ stride_pn,
+ MDStarts,
+ tile_starts_stridem,
+ blocks1a,
+ MDTileInfo,
+ first_tile_dim_log2,
+ SIZES: tl.constexpr,
+ BLOCK_A: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ """
+ This kernel essentially combines 6 different pieces of functionality,
+ statically branching on the value of tl.program_id(0) to decide which
+ codepath to take.
+
+ pid == 0: create the token cumsum
+ 1 <= pid <= SIZES: create a tile cumsum
+ SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
+ blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
+ pid == blocks1a + n_expts_tot: compute_expt_offs
+ pid > blocks1a + n_expts_tot: initialise Indx to sentinel
+
+ As each of these is a relatively trivial workload, launching them from
+ this single trampoline is beneficial as they can execute on different
+ streaming multiprocessors in parallel.
+ """
+
+ pid = tl.program_id(0)
+
+ if pid < blocks1a:
+ _expt_data_memset(
+ ExpertHist,
+ n_expts_tot,
+ MDStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ first_tile_dim_log2,
+ SIZES,
+ BLOCK_A,
+ )
+ elif pid == n_expts_tot + blocks1a:
+ _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+ elif pid < n_expts_tot + blocks1a:
+ _routing_compute_indx_offs(
+ PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a
+ )
+ else:
+ offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
+ mask = offs < size
+ tl.store(Indx + offs, sentinel, mask=mask)
diff --git a/vllm/kvprune/triton_kernels/specialize.py b/vllm/kvprune/triton_kernels/specialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf44d70cb47664e6a837ec4cf0d28f04fbb1c16
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/specialize.py
@@ -0,0 +1,143 @@
+import inspect
+import re
+import textwrap
+import types
+import triton
+
+
+def cacheable(f):
+ """
+ A decorator that allows you to write something of the form:
+
+ @cacheable
+ def my_kernel(): return (expression dynamically defining a kernel)
+
+ such that it interacts gracefully with triton cache and preload.
+ """
+
+ g = f()
+ g.fn.__name__ = f.__name__
+ g.fn.__module__ = f.__module__
+ g.fn.__qualname__ = f.__qualname__
+ g.__name__ = f.__name__
+ g.__module__ = f.__module__
+ g.__qualname__ = f.__qualname__
+ g._fn_name = f"{f.__module__}.{f.__qualname__}"
+ return g
+
+
+def define_kernel(src, module, attrs=None, **extra_globals):
+ """
+ Dynamically create a Triton function or kernel from a src string,
+ linking any symbols in the kernel to objects specified by extra_globals.
+ """
+
+ # create template function
+ def _empty_fn():
+ pass
+
+ gdict = dict(**(_empty_fn.__globals__))
+ gdict.update(extra_globals)
+ f = types.FunctionType(_empty_fn.__code__, gdict)
+ f.__module__ = module.__name__
+
+ src = textwrap.dedent(src)
+ src = src[src.find("def ") :]
+
+ stored_functions = []
+ function_name = src[4:].split("(")[0].strip()
+
+ exec_globals = gdict
+ exec_globals.update({"stored_functions": stored_functions})
+ exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)
+
+ f.__signature__ = inspect.signature(stored_functions[0])
+ f.__name__ = function_name
+ f.__doc__ = stored_functions[0].__doc__
+
+ if attrs is None:
+ attrs = dict()
+ f = triton.JITFunction(f, **attrs)
+ f._unsafe_update_src(src)
+ return f
+
+
+def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
+ assert isinstance(fn, triton.runtime.jit.JITFunction)
+ if name is None:
+ name = f"{fn.__name__}"
+ # Get original source code
+ src = inspect.getsource(fn.fn)
+ src = textwrap.dedent(src)
+ lines = src.split("\n")
+ # Skip decorator and def line
+ def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
+ # separate header vs body LOC
+ header_end = def_idx
+ while not lines[header_end].rstrip().endswith(":"):
+ header_end += 1
+ body_lines = lines[header_end + 1 :]
+ header_lines = lines[def_idx : header_end + 1]
+ # clean-up header
+ header_clean = [
+ l.split("#", 1)[0].strip() # keep code, discard comment
+ for l in header_lines
+ if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines
+ ]
+ # decompose arguments
+ header_src = " ".join(header_clean) # turn it into a single line
+ m = re.search(r"\((.*)\)\s*:", header_src)
+ if not m:
+ raise ValueError("Could not parse function header")
+ args_str = m.group(1)
+ args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
+ non_specialized_args = []
+ for arg in args:
+ arg_key = arg.split(":")[0].split("=")[0].strip()
+ new_args = tuples.get(arg_key, [arg])
+ if arg_key not in constants:
+ non_specialized_args += new_args
+ # add global symbols
+ spec_fns = {
+ v.__name__: v
+ for k, v in constants.items()
+ if isinstance(v, triton.runtime.jit.JITFunction)
+ }
+ globals = spec_fns | fn.get_capture_scope()
+ # build new source code and define kernel dynamically
+ new_signature = f"def {name}({', '.join(non_specialized_args)}):"
+ constexpr_lines = [
+ f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}"
+ for key, value in constants.items()
+ ]
+ tuple_lines = [
+ f" {key} = {'(' + ','.join(value) + (',' if len(value) >= 1 else '') + ')'}"
+ for key, value in tuples.items()
+ ]
+ new_src = "\n".join(
+ ["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines
+ )
+ # find function parameters
+ sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
+ params = list(sig.parameters.values())[2:]
+ attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
+
+ # make a new repr which appends the repr of the specialized functions.
+ base_repr = attrs["repr"]
+
+ def new_repr(specialization):
+ ret = base_repr(specialization)
+ for spec_fn in spec_fns.values():
+ spec_repr = spec_fn.repr(None)
+ if spec_repr:
+ spec_repr = spec_repr.strip("_")
+ if spec_repr:
+ ret += f"_{spec_repr}"
+ return ret
+
+ attrs["repr"] = new_repr
+
+ if do_not_specialize:
+ attrs["do_not_specialize"] = do_not_specialize
+ ret = define_kernel(new_src, module, attrs, **globals)
+ return ret
diff --git a/vllm/kvprune/triton_kernels/swiglu.py b/vllm/kvprune/triton_kernels/swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4b69fa49b4f8fb419bacac72dfe42973584b8e
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/swiglu.py
@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+from vllm.kvprune.triton_kernels.numerics import InFlexData, OutFlexData
+import torch
+import triton
+from .swiglu_details._swiglu import _swiglu, _swiglu_fn
+from vllm.kvprune.triton_kernels import target_info
+
+
+@dataclass(frozen=True)
+class FlexCtx:
+ out_data: OutFlexData = OutFlexData()
+ inp_data: InFlexData = InFlexData()
+ saturate_inf: bool = False
+
+
+@dataclass(frozen=True)
+class PrecisionConfig:
+ limit: float
+ flex_ctx: FlexCtx = FlexCtx()
+
+
+swiglu_fn = _swiglu_fn
+
+
+class SwiGLU(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, a, alpha, precision_config, routing_data):
+ N = a.shape[-1]
+ M = a.numel() // N
+ assert a.stride()[-1] == 1
+ assert a.shape[-1] % 2 == 0
+ out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
+ flex_ctx = precision_config.flex_ctx
+ # optimization hyperparameters
+ BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
+ num_warps = 4
+ kwargs = {"maxnreg": 64} if not target_info.is_hip() else {}
+ # launch semi-persistent kernel
+ N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
+ num_sms = target_info.num_sms()
+ if routing_data is not None:
+ waves_per_sm = 32 if target_info.is_hip() else 128
+ num_pid = num_sms * (waves_per_sm // num_warps)
+ M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
+ grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
+ else:
+ M_BLOCKS = triton.cdiv(M, BLOCK_M)
+ if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
+ grid = (8 * num_sms,)
+ else:
+ grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
+ n_tokens = None
+ if routing_data is not None:
+ n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
+ _swiglu[grid](
+ flex_ctx.out_data.reinterpret(out),
+ flex_ctx.out_data.expected_scale,
+ flex_ctx.out_data.actual_scale,
+ flex_ctx.out_data.checksum_scale,
+ flex_ctx.inp_data.reinterpret(a),
+ flex_ctx.inp_data.scale,
+ alpha,
+ M,
+ N // 2,
+ a.shape[-1],
+ 1,
+ out.shape[-1],
+ 1,
+ precision_config.limit,
+ n_tokens,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ EVEN_N=(N // 2) % BLOCK_N == 0,
+ M_BLOCKS=M_BLOCKS,
+ N_BLOCKS=N_BLOCKS,
+ flexpoint_saturate_inf=flex_ctx.saturate_inf,
+ num_warps=num_warps,
+ **kwargs,
+ )
+ out = out.view(a.shape[:-1] + out.shape[-1:])
+ return out
+
+
+def swiglu(a, alpha, precision_config, routing_data=None):
+ return SwiGLU.apply(a, alpha, precision_config, routing_data)
+
+
+def swiglu_torch(a, alpha, precision_config):
+ limit = precision_config.limit
+ a_gelu = a[..., ::2]
+ if limit is not None:
+ a_gelu = a_gelu.clamp(max=limit)
+ a_linear = a[..., 1::2]
+ if limit is not None:
+ a_linear = a_linear.clamp(min=-limit, max=limit)
+
+ out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
+ out = out_gelu * (a_linear + 1)
+ return out
diff --git a/vllm/kvprune/triton_kernels/swiglu_details/__init__.py b/vllm/kvprune/triton_kernels/swiglu_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/swiglu_details/_swiglu.py b/vllm/kvprune/triton_kernels/swiglu_details/_swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bdd8410513f148dfe4e3f3b74c12950a1554bde
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/swiglu_details/_swiglu.py
@@ -0,0 +1,141 @@
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
+ load_scale,
+ float_to_flex,
+ update_scale,
+)
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def clip(x, limit, clip_lower: tl.constexpr):
+ res = tl.minimum(x, limit)
+ if clip_lower:
+ res = tl.maximum(-limit, res)
+ return res
+
+
+@triton.jit
+def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
+ return tl.max(
+ tl.reshape(
+ tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True
+ ),
+ axis=1,
+ )
+
+
+def swiglu_repr(specialization):
+ signature = specialization.signature
+ constants = specialization.constants
+ convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
+ dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]])
+ blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]])
+ return f"_swiglu_{dtypes}_{blocks}"
+
+
+def swiglu_launch_metadata(grid, kernel, args):
+ M, N = args["M"], args["N"]
+ ret = dict()
+ ret["name"] = f"{kernel.name} [M = {M}, N = {N}]"
+ A, Out = args["A"], args["Out"]
+ ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size()
+ return ret
+
+
+@triton.jit
+def compute_swiglu(gelu, linear, scale, alpha, limit):
+ gelu = gelu.to(tl.float32) * scale
+ if limit is not None:
+ gelu = clip(gelu, limit, clip_lower=False)
+ linear = linear.to(tl.float32) * scale
+ if limit is not None:
+ linear = clip(linear, limit, clip_lower=True)
+ s = gelu / (1 + tl.exp(-alpha * gelu))
+ return tl.fma(s, linear, s) # (s * (linear + 1))
+
+
+@triton.jit(repr=lambda _: "_swiglu")
+def _swiglu_fn(input, alpha, limit):
+ gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
+ return compute_swiglu(gelu, linear, 1.0, alpha, limit)
+
+
+@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
+def _swiglu(
+ Out,
+ OutExpectedScale,
+ OutActualScale,
+ OutChecksumScale,
+ A,
+ AScale,
+ alpha,
+ M,
+ N,
+ stride_am,
+ stride_an,
+ stride_outm,
+ stride_outn,
+ limit: tl.constexpr,
+ NTokens,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ EVEN_N: tl.constexpr,
+ M_BLOCKS,
+ N_BLOCKS,
+ flexpoint_saturate_inf: tl.constexpr,
+):
+ if NTokens is not None:
+ M = tl.load(NTokens)
+ M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
+
+ local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)
+
+ a_scale = load_scale(AScale)
+ out_expected_scale = load_scale(OutExpectedScale)
+
+ for pid in tl.range(
+ tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2
+ ):
+ pid_m = pid // N_BLOCKS
+ pid_n = pid % N_BLOCKS
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ mask_m = off_m < M
+ mask_n = off_n < N
+ packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
+ packed_mask_n = packed_off_n < N
+ packed_mask_n = tl.max_constancy(packed_mask_n, [16])
+ # load a
+ packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
+ packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
+ if EVEN_N:
+ a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
+ else:
+ if pid_n * BLOCK_N + BLOCK_N <= N:
+ a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
+ else:
+ packed_mask = mask_m[:, None] & packed_mask_n[None, :]
+ a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.0)
+ a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
+ out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
+ # update flexpoint stats and divide by scale
+ # we don't need masking because of the `other` when loading `A`
+ if OutActualScale is not None:
+ absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
+ local_max = tl.maximum(local_max, absmax)
+ out = float_to_flex(
+ out,
+ out_expected_scale,
+ None, # ActualScale: local absmax is tracked and updated after the loop
+ OutChecksumScale,
+ None,
+ Out,
+ flexpoint_saturate_inf,
+ )
+ mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
+ tl.store(
+ Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask
+ )
+
+ update_scale(local_max, OutActualScale, Out)
diff --git a/vllm/kvprune/triton_kernels/target_info.py b/vllm/kvprune/triton_kernels/target_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ae4303c512241455cc8aed5a85a2edb1c1c8eb
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/target_info.py
@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+from triton.language.target_info import (
+ cuda_capability_geq,
+ is_cuda,
+ is_hip,
+ is_hip_cdna3,
+ is_hip_cdna4,
+)
+
+__all__ = [
+ "cuda_capability_geq",
+ "get_cdna_version",
+ "has_tma_gather",
+ "has_native_mxfp",
+ "is_cuda",
+ "is_hip",
+ "is_hip_cdna3",
+ "is_hip_cdna4",
+ "num_sms",
+]
+
+
+@triton.constexpr_function
+def get_cdna_version():
+ """
+ Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
+ only supports 3 (gfx942) or 4 (gfx950). Returns -1 if it is not AMD
+ hardware or unsupported architecture
+ """
+ target = tl.target_info.current_target()
+ if target.backend != "hip":
+ return -1
+ if target.arch == "gfx942":
+ return 3
+ if target.arch == "gfx950":
+ return 4
+ return -1
+
+
+@triton.constexpr_function
+def has_tma_gather():
+ return cuda_capability_geq(10, 0)
+
+
+@triton.constexpr_function
+def has_native_mxfp():
+ return cuda_capability_geq(10, 0)
+
+
+def num_sms():
+ return torch.cuda.get_device_properties(0).multi_processor_count
diff --git a/vllm/kvprune/triton_kernels/tensor.py b/vllm/kvprune/triton_kernels/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6992e942365b2cf52701be8d013f174dd4458784
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor.py
@@ -0,0 +1,227 @@
+from dataclasses import dataclass, fields
+from typing import Type
+
+import torch
+from triton.tools.tensor_descriptor import TensorDescriptor
+from triton.tools.ragged_tma import create_ragged_descriptor
+
+from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
+from .target_info import cuda_capability_geq
+from .tensor_details.layout import Layout, StridedLayout
+
+
+@dataclass
+class Storage:
+ data: torch.Tensor
+ layout: Layout = None
+
+ def __post_init__(self):
+ assert isinstance(self.data, torch.Tensor)
+ if self.layout is None:
+ self.layout = StridedLayout(self.data.shape)
+
+ @property
+ def device(self):
+ return self.data.device
+
+ def is_tma_compliant(self):
+ # TMAs didn't exist until Hopper
+ if not cuda_capability_geq(9, 0):
+ return False
+ # TMAs only exist for 2D, 3D, 5D inputs
+ if len(self.data.shape) not in [2, 3, 5]:
+ return False
+ # TMAs need at most one stride equal to 1
+ # and all other strides divisible by 16
+ strides = list(self.data.stride())
+ try:
+ major_dim = strides.index(1)
+ except ValueError:
+ major_dim = -1
+ ndim = self.data.ndim
+ bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
+ compliant = [
+ strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim
+ ]
+ return all(compliant)
+
+ def make_dense_tma(self, block_shape, transpose=False):
+ strides = list(self.data.stride())
+ shape = list(self.data.shape)
+ transpose = self.data.stride()[-1] != 1
+ if transpose:
+ block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
+ shape = shape[:-2] + [shape[-1], shape[-2]]
+ strides = strides[:-2] + [strides[-1], strides[-2]]
+ if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
+ indx = strides.index(1)
+ block_shape[indx] = block_shape[indx] // 2
+ if shape[-1] % 128 != 0:
+ raise ValueError(
+ "inner shape need to be multiple of 128 for "
+ "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
+ )
+ block_shape = self.layout.swizzle_block_shape(block_shape)
+ return TensorDescriptor(self.data, shape, strides, block_shape)
+
+ def make_tma(self, block_shape, mode, transpose=False):
+ if mode in ["dense", "gather", "scatter"]:
+ return self.make_dense_tma(block_shape, transpose)
+ assert mode == "ragged"
+ ragged_dim = len(self.data.shape) - 2
+ return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
+
+
+@dataclass
+class IntegerType:
+ bitwidth: int
+
+
+@dataclass
+class FloatType:
+ bitwidth_exponent: int
+ bitwidth_mantissa: int
+ is_signed: bool
+
+ def __post_init__(self):
+ self.bitwidth = (
+ int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa
+ )
+
+
+BIT = IntegerType(1)
+FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
+
+
+def bitwidth(type: IntegerType | FloatType | torch.dtype):
+ if isinstance(type, torch.dtype):
+ return type.itemsize * 8
+ return type.bitwidth
+
+
+@dataclass
+class Tensor:
+ storage: Storage | torch.Tensor
+ dtype: IntegerType | FloatType | torch.dtype = None
+ shape: list[int] | None = None
+ shape_max: list[int] | None = None
+
+ def __post_init__(self):
+ # set storage
+ if isinstance(self.storage, torch.Tensor):
+ self.storage = Storage(self.storage)
+ # initialize dtype
+ if self.dtype is None:
+ self.dtype = self.storage.data.dtype
+ if bitwidth(self.dtype) < 8 and self.shape is None:
+ raise ValueError("shape must be provided for sub-byte types")
+ # initialize shape
+ if self.shape is None:
+ self.shape = list(self.storage.data.shape)
+ # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
+ is_int = lambda s: isinstance(s, int)
+ is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
+ assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
+ # initialize shape_max
+ if self.shape_max is None:
+ self.shape_max = [None] * len(self.shape)
+ for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
+ if smax is not None and not is_int(smax):
+ raise ValueError(
+ f"shape_max[{i}] must be `int` or `None`; got {type(smax)}"
+ )
+ if smax is None:
+ self.shape_max[i] = s
+ # validate shape_max: all elements must be `int`
+ assert all(map(is_int, self.shape_max))
+
+ # torch compatibility layer
+ @property
+ def ndim(self):
+ return len(self.shape)
+
+ @property
+ def device(self):
+ return self.storage.device
+
+ def stride(self, i=None):
+ return self.storage.data.stride() if i is None else self.storage.data.stride(i)
+
+ def data_ptr(self):
+ return self.storage.data.data_ptr()
+
+ def numel(self):
+ return self.storage.data.numel()
+
+ def element_size(self):
+ return bitwidth(self.dtype) // 8
+
+ @property
+ def data(self):
+ t = self.storage
+ return t.data if isinstance(t, Storage) else t
+
+ def dim(self):
+ return self.ndim
+
+ def size(self, i=None):
+ if i is None:
+ return self.shape
+ return self.shape[i]
+
+
+@dataclass
+class Bitmatrix(Tensor):
+ """
+ Represents a boolean matrix in a packed format where each element occupies
+ a single bit of memory.
+
+ scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
+ with the actual bitmatrix to avoid having to launch a separate memset
+ kernel when we call Bitmatrix::sum().
+ """
+
+ scratchpad: torch.Tensor = None
+
+ def __init__(self, storage, shape, shape_max=None, scratchpad=None):
+ super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
+ self.scratchpad = scratchpad
+
+ def sum(self, partials_block_size):
+ _, n_cols = self.shape
+ dev = self.device
+ if self.scratchpad is None:
+ self.scratchpad = clear_sums(n_cols, dev)
+ out_ret = self.scratchpad[:n_cols]
+ self.scratchpad = None # throw error if we try to sum again
+ return sum_bitmatrix_rows(self, out_ret, partials_block_size)
+
+
+def get_layout(tensor: torch.Tensor | Tensor | None):
+ if tensor is None:
+ return None
+ if isinstance(tensor, Tensor):
+ return tensor.storage.layout
+ return StridedLayout
+
+
+def wrap_torch_tensor(torch_tensor, dtype=None):
+ if dtype is None:
+ dtype = torch_tensor.dtype
+ shape = list(torch_tensor.shape)
+ shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(
+ dtype
+ )
+ return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
+
+
+def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
+ assert isinstance(tensor, Tensor)
+ old_storage = tensor.storage
+ old_data = old_storage.layout.unswizzle_data(old_storage.data)
+ new_layout = layout_cls(old_data.shape, **layout_kwargs)
+ new_data = new_layout.swizzle_data(old_data)
+ attrs = {
+ k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"
+ }
+ return Tensor(Storage(new_data, new_layout), **attrs)
diff --git a/vllm/kvprune/triton_kernels/tensor_details/__init__.py b/vllm/kvprune/triton_kernels/tensor_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout.py b/vllm/kvprune/triton_kernels/tensor_details/layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..98122f3517a593b1bc479c43d8d64fb64191a7af
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout.py
@@ -0,0 +1,40 @@
+from .layout_details.base import Layout
+from .layout_details.blackwell_scale import BlackwellMXScaleLayout
+from .layout_details.blackwell_value import BlackwellMXValueLayout
+from .layout_details.hopper_scale import HopperMXScaleLayout
+from .layout_details.hopper_value import HopperMXValueLayout
+from .layout_details.cdna4_scale import CDNA4MXScaleLayout
+from .layout_details.strided import StridedLayout
+from ..target_info import cuda_capability_geq, is_hip_cdna4
+
+__all__ = [
+ "Layout",
+ "BlackwellMXValueLayout",
+ "BlackwellMXScaleLayout",
+ "HopperMXScaleLayout",
+ "HopperMXValueLayout",
+ "CDNA4MXScaleLayout",
+ "StridedLayout",
+]
+
+
+def make_default_matmul_mxfp4_w_layout(mx_axis: int):
+ if cuda_capability_geq(10):
+ # return StridedLayout, dict()
+ return BlackwellMXValueLayout, dict()
+ elif cuda_capability_geq(9):
+ return HopperMXValueLayout, {"mx_axis": mx_axis}
+ else:
+ return StridedLayout, dict()
+
+
+def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
+ if is_hip_cdna4():
+ return CDNA4MXScaleLayout, dict()
+ else:
+ if cuda_capability_geq(10):
+ return BlackwellMXScaleLayout, dict()
+ elif cuda_capability_geq(9):
+ return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+
+ return StridedLayout, dict()
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/__init__.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/base.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d23dab8f42abd1d87bf77c08c3b64c1efe4d3e3
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/base.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+
+
+class Layout(ABC):
+ def __init__(self, shape) -> None:
+ self.initial_shape = shape
+
+ @abstractmethod
+ def swizzle_data(self, data):
+ pass
+
+ @abstractmethod
+ def unswizzle_data(self, data):
+ pass
+
+ @abstractmethod
+ def swizzle_block_shape(self, block_shape):
+ pass
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/blackwell_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54a300cfdd906dec1a78aaf4f48259529659cdf
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/blackwell_scale.py
@@ -0,0 +1,81 @@
+import math
+import triton
+import triton.language as tl
+import torch
+from .base import Layout
+
+SWIZZLE_ALIGN_INNER = 8
+SWIZZLE_SIZE_INNER = 4
+SWIZZLE_SIZE_OUTER = 128
+
+
+class BlackwellMXScaleLayout(Layout):
+ name: str = "BLACKWELL_SCALE"
+
+ def __init__(self, shape) -> None:
+ super().__init__(shape)
+ (
+ *self.leading_shape,
+ self.K,
+ self.N,
+ ) = shape
+ self.B = math.prod(self.leading_shape)
+ self.ALIGN_K = 8
+ self.ALIGN_N = 128
+ self.SWIZZLE_K = 4
+ self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
+ self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
+
+ def swizzle_data(self, data):
+ data = torch.nn.functional.pad(
+ data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)
+ )
+ data = data.transpose(-1, -2).contiguous()
+ data = data.reshape(
+ self.B,
+ self.N_pad // self.ALIGN_N,
+ self.ALIGN_N // 32,
+ 32,
+ self.K_pad // self.SWIZZLE_K,
+ self.SWIZZLE_K,
+ )
+ data = data.transpose(2, 4).contiguous()
+ data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
+ return data
+
+ def unswizzle_data(self, data):
+ data = data.reshape(
+ self.B,
+ self.N_pad // self.ALIGN_N,
+ self.K_pad // self.SWIZZLE_K,
+ 32,
+ self.ALIGN_N // 32,
+ self.SWIZZLE_K,
+ )
+ data = data.transpose(2, 4)
+ data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
+ data = data.transpose(-1, -2)
+ return data[..., : self.K, : self.N]
+
+ def swizzle_block_shape(self, block_shape):
+ MX_PACK_DIVISOR = 32
+ MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
+ return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
+
+
+@triton.jit
+def unswizzle_mx_scale_bw(
+ x,
+ SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
+ SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
+ ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER,
+):
+ shape_0: tl.constexpr = x.shape[0]
+ shape_1: tl.constexpr = x.shape[1]
+ tl.static_assert(shape_1 % SIZE_OUTER == 0)
+ tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
+ x = x.reshape(
+ shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER
+ )
+ x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
+ return x
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/blackwell_value.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/blackwell_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..622744888b91eb0c99ba6d9c7fb150acb2d89702
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/blackwell_value.py
@@ -0,0 +1,37 @@
+import torch
+from .base import Layout
+
+
+class BlackwellMXValueLayout(Layout):
+    """Value layout that pads the unit-stride dimension to a multiple of 128."""
+
+    name: str = "BLACKWELL_VALUE"
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+        # Kept so that unswizzle_data can trim the padding away again.
+        self.shape = shape
+
+    def swizzle_data(self, data):
+        """Pad the stride-1 dimension of ``data`` up to a multiple of 128.
+
+        The tensor is temporarily permuted so that its dimensions are ordered
+        by decreasing stride, padded along the (then) last dimension, and
+        permuted back, so the logical dimension order is unchanged.
+        """
+        rank = data.ndim
+        # Dimension indices by decreasing (stride, index).
+        row_major_order = sorted(
+            range(rank), key=lambda d: (data.stride(d), d), reverse=True
+        )
+        # Inverse permutation: position of every original dim in the order above.
+        restore_order = sorted(range(rank), key=lambda pos: row_major_order[pos])
+        # Amount needed to round the unit-stride dimension up to 128.
+        unit_stride_dim = data.stride().index(1)
+        extent = data.shape[unit_stride_dim]
+        padding = (extent + 128 - 1) // 128 * 128 - extent
+        padded = torch.nn.functional.pad(data.permute(row_major_order), (0, padding))
+        return padded.permute(restore_order)
+
+    def unswizzle_data(self, data: torch.Tensor):
+        """Crop ``data`` back to the shape recorded at construction time."""
+        assert data.ndim == len(self.shape), (
+            "Rank mismatch between data and recorded shape"
+        )
+        window = tuple(
+            slice(0, min(data.size(axis), self.shape[axis]))
+            for axis in range(data.ndim)
+        )
+        return data[window]
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are unaffected by this layout."""
+        return block_shape
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/cdna4_scale.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/cdna4_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..beecaee3e12d93294df0365010966e15d625635e
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/cdna4_scale.py
@@ -0,0 +1,50 @@
+import triton
+import triton.language as tl
+from .base import Layout
+
+NON_K_PRESHUFFLE_BLOCK_SIZE = 32
+
+
+class CDNA4MXScaleLayout(Layout):
+    """Pre-shuffled scale layout: (..., SCALE_K, N) -> (..., SCALE_K * 32, N // 32)."""
+
+    name: str = "CDNA4_SCALE"
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+
+    def swizzle_data(self, data):
+        """Shuffle a 2-D ``(SCALE_K, N)`` or 3-D ``(E, SCALE_K, N)`` tensor.
+
+        Returns a tensor whose last two dims are
+        ``(SCALE_K * 32, N // 32)``, as a transposed view of a contiguous
+        buffer. ``N`` must be divisible by 32 and ``SCALE_K`` by 8, otherwise
+        the ``view`` below raises.
+        """
+        block_shape = data.shape
+        # Validate the rank up front instead of after the shuffle work.
+        assert len(block_shape) in (2, 3)
+        SCALE_K = block_shape[-2]
+        N = block_shape[-1]
+        n_groups = N // NON_K_PRESHUFFLE_BLOCK_SIZE
+        data = data.transpose(-1, -2)
+        # N is split into (n_groups, 2, 16) and SCALE_K into (SCALE_K // 8, 2, 4, 1).
+        data = data.view(-1, n_groups, 2, 16, SCALE_K // 8, 2, 4, 1)
+        data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
+        swizzled_tail = (n_groups, SCALE_K * NON_K_PRESHUFFLE_BLOCK_SIZE)
+        if len(block_shape) == 3:
+            data = data.reshape(block_shape[0], *swizzled_tail)
+        else:
+            data = data.reshape(*swizzled_tail)
+        return data.transpose(-1, -2)
+
+    def unswizzle_data(self, data):
+        """Not implemented for this layout."""
+        raise NotImplementedError()
+
+    def swizzle_block_shape(self, block_shape):
+        """Return the block shape after swizzling the last two dims.
+
+        Accepts any sequence (list or tuple) and always returns a list.
+        """
+        SCALE_K = block_shape[-2]
+        N = block_shape[-1]
+        # list(...) so that tuple inputs do not fail on `tuple + list`.
+        return list(block_shape[:-2]) + [
+            N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+            SCALE_K * NON_K_PRESHUFFLE_BLOCK_SIZE,
+        ]
+
+
+@triton.jit
+def unswizzle_mx_scale_cdna4(
+ x,
+ BLOCK_N: tl.constexpr,
+ MX_SCALE_BLOCK_K: tl.constexpr,
+ N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
+):
+ # Device-side un-shuffle of a pre-shuffled scale block; the result has shape
+ # (BLOCK_N, MX_SCALE_BLOCK_K).
+ # View the block as 7 factors: N groups, K groups of 8, then 4, 16, 2, 2, 1.
+ x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
+ # Reorder so the N factors (axes 0, 5, 3) come first and the K factors
+ # (axes 1, 4, 2, 6) last.
+ x = x.permute(0, 5, 3, 1, 4, 2, 6)
+ x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+ return x
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/hopper_scale.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/hopper_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef61e889b2c4c38bad4832bd160734a4b492b26
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/hopper_scale.py
@@ -0,0 +1,91 @@
+import torch
+import triton
+import triton.language as tl
+from .base import Layout
+
+
+class HopperMXScaleLayout(Layout):
+    """Swizzled scale layout: (..., M, K) -> (..., M_pad // 32, K_pad * 32)."""
+
+    name: str = "HOPPER_SCALE"
+
+    def __init__(self, shape, mx_axis, num_warps=8) -> None:
+        # `n & (n - 1) == 0` alone also accepts 0, which would later produce a
+        # zero-sized swizzle tile (modulo by zero in swizzle_data).
+        assert num_warps > 0 and num_warps & (num_warps - 1) == 0, (
+            "num_warps must be a power of 2"
+        )
+        super().__init__(shape)
+        self.mx_axis = mx_axis
+        self.num_warps = num_warps
+        *self.leading_shape, _, _ = shape
+
+    def _maybe_mT(self, data):
+        """Transpose the last two dims when mx_axis is the second-to-last dim."""
+        if self.mx_axis == len(self.leading_shape):
+            return data.contiguous().mT
+        return data
+
+    def swizzle_data(self, data):
+        """Pad and swizzle ``data``.
+
+        The second-to-last dim is zero-padded to a multiple of
+        ``2 * num_warps * 2 * 8`` and the last dim to a multiple of 2 before
+        the elements are regrouped.
+        """
+        data = self._maybe_mT(data).contiguous()
+        *batch, M, K = data.shape
+        SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
+        SWIZZLE_ALIGN_K = 2
+        pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
+        pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
+        # F.pad lists the last dimension first.
+        data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
+        *batch, M, K = data.shape
+        assert data.is_contiguous()
+        assert M % (2 * self.num_warps * 2 * 8) == 0 and K % 2 == 0, (
+            f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
+        )
+        b = len(batch)
+        # M is split into 5 factors and K into 2 factors.
+        data = data.reshape(
+            *batch,
+            M // (2 * self.num_warps * 2 * 8),
+            2,
+            self.num_warps,
+            2,
+            8,
+            K // 2,
+            2,
+        )
+        perm = [0, 2, 5, 1, 4, 6, 3]
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.flatten(-5, -1)
+        data = data.flatten(-3, -2)
+        assert data.shape[-2] == M // 32
+        assert data.shape[-1] == K * 32
+        data = self._maybe_mT(data)
+        return data
+
+    def unswizzle_data(self, data):
+        """Invert the regrouping done by ``swizzle_data``.
+
+        The result has the padded extents; padding added by ``swizzle_data``
+        is not trimmed here.
+        """
+        data = self._maybe_mT(data)
+        *batch, M, K = data.shape
+        b = len(batch)
+        data = data.reshape(
+            *batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2
+        )
+        perm = [0, 3, 1, 6, 4, 2, 5]
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.reshape(*batch, M * 32, K // 32)
+        data = self._maybe_mT(data)
+        return data
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are returned unchanged."""
+        return block_shape
+
+
+@triton.jit
+def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
+ """
+ Triton inverse of swizzle_mxfp4_scale_hopper
+ """
+ # Only 2-D blocks are supported; the result has shape (M * 32, K // 32)
+ # (transposed back when mx_axis == 0).
+ tl.static_assert(len(x.shape) == 2, "NYI")
+ # implementation assumes mxfp data is packed along the last dimension
+ x = x.trans() if mx_axis == 0 else x
+ M: tl.constexpr = x.shape[0]
+ K: tl.constexpr = x.shape[1]
+ tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
+ tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
+ # Same 7-factor split and axis order as HopperMXScaleLayout.unswizzle_data.
+ x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
+ x = x.trans(0, 3, 1, 6, 4, 2, 5)
+ x = x.reshape(M * 32, K // 32)
+ # implementation assumed mxfp data is packed along the last dimension
+ x = x.trans() if mx_axis == 0 else x
+ return x
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/hopper_value.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/hopper_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ddfadf09427f519bc9867094c7855d9d12eac7
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/hopper_value.py
@@ -0,0 +1,362 @@
+import torch
+import triton
+import triton.language as tl
+from .base import Layout
+
+
+def right_shift_unsigned(x, shift):
+    """Shift ``x`` right by ``shift`` bits as if it were an unsigned 32-bit value.
+
+    ``>>`` is an arithmetic shift, so the bits shifted in from the left are
+    masked off, keeping only the low ``32 - shift`` bits of the result.
+    """
+    low_bits_mask = (1 << (32 - shift)) - 1
+    return (x >> shift) & low_bits_mask
+
+
+# -----------------------------------------------------------------------
+# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
+# 1000000111000000 (first fp4)
+# 1000000111000000 (second fp4)
+# 1000000111000000 (third fp4)
+# 0110110000000000 (fourth fp4)
+# This is done so that dequantization can be done in 14 SASS instructions
+# -----------------------------------------------------------------------
+
+
+def _compress_fp4(x):
+    """Spread the low nibble of ``x`` into a 16-bit pattern (as int32).
+
+    Bit 3 is moved to bit 15 and bits 0-2 are moved to bits 6-8.
+    """
+    bits = x.to(torch.int32)
+    top_bit = (bits & 0x8) << 12
+    low_three = (bits & 0x7) << 6
+    return top_bit | low_three
+
+
+def _compress_fourth(x):
+    """Spread the low nibble of ``x`` into the bit slots used by the fourth value.
+
+    Bit 3 goes to bit 14, bits 1-2 go to bits 10-11 and bit 0 goes to bit 13.
+    The result is int32.
+    """
+    bits = x.to(torch.int32)
+    from_bit3 = (bits & 0x8) << 11
+    from_bits12 = (bits & 0x6) << 9
+    from_bit0 = (bits & 0x1) << 13
+    return from_bit3 | from_bits12 | from_bit0
+
+
+def _pack_bits(x: torch.Tensor, mx_axis: int):
+    """Interleave every group of four bytes along the last dim into one int32.
+
+    Each byte contributes its low nibble to the lower 16 bits and its high
+    nibble to the upper 16 bits. The packed int32 tensor is returned
+    reinterpreted as uint8. ``mx_axis`` is accepted but not used here.
+    """
+    x = x.contiguous()
+    assert x.shape[-1] % 4 == 0, (
+        "Input tensor must have a last dimension divisible by 4"
+    )
+    quads = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
+
+    def both_nibbles(compress, byte):
+        # low nibble -> lower half-word, high nibble -> upper half-word
+        return compress(byte) | (compress(byte >> 4) << 16)
+
+    first = both_nibbles(_compress_fp4, quads[..., 0])
+    second = both_nibbles(_compress_fp4, quads[..., 1])
+    third = both_nibbles(_compress_fp4, quads[..., 2])
+    fourth = both_nibbles(_compress_fourth, quads[..., 3])
+    packed = first | right_shift_unsigned(second, 3)
+    packed = packed | right_shift_unsigned(third, 6)
+    packed = packed | fourth
+    assert packed.is_contiguous()
+    return packed.view(torch.uint8)
+
+
+# -----------------------------------------------------------------------
+# inverse operation of _pack_bits
+# -----------------------------------------------------------------------
+
+
+def _bf16_to_fp4e2m1(x):
+ # Collapse a 16-bit pattern back into a 4-bit value stored in a uint8.
+ # 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
+ assert x.dtype == torch.int16
+ # Bit 15 (A) becomes bit 3 of the result.
+ s = (right_shift_unsigned(x, 15) & 0x1) << 3
+ # Bits 6-8 (BCD) become bits 0-2 of the result.
+ em = right_shift_unsigned(x, 6) & 0x7
+ return (s | em).to(torch.uint8)
+
+
+def _bf16x2_to_fp4e2m1x2(x):
+ # Convert two 16-bit patterns held in one int32 into two 4-bit values packed
+ # in one uint8 (upper half-word -> high nibble, lower half-word -> low nibble).
+ # 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
+ assert x.dtype == torch.int32
+ # Split into the lower and upper 16-bit halves.
+ lo = (x & 0xFFFF).to(torch.int16)
+ hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
+ ret_lo = _bf16_to_fp4e2m1(lo)
+ ret_hi = _bf16_to_fp4e2m1(hi)
+ return ret_lo | (ret_hi << 4)
+
+
+def _unpack_bits(x, mx_axis: int):
+ # Host-side inverse of _pack_bits: every int32 word is expanded into four
+ # int32 patterns which are then converted back to packed 4-bit pairs (uint8).
+ # `mx_axis` is accepted but not used in this function.
+ x = x.view(torch.int32)
+ # Bit slots of one value inside each 16-bit half (bit 15 and bits 6-8).
+ m = 0b10000001110000001000000111000000
+ # Pieces of the fourth value, which _pack_bits stored in different slots.
+ a = (x << 1) & 0b10000000000000001000000000000000
+ b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
+ c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
+ # First three values: shift by 0 / 3 / 6 and mask; fourth: a | b | c.
+ unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
+ # Put the four values next to each other along the last dim.
+ x = torch.stack(unpacked, dim=-1)
+ x = x.flatten(-2, -1)
+ x = _bf16x2_to_fp4e2m1x2(x)
+ return x
+
+
+# -----------------------------------------------------------------------
+
+
+class HopperMXValueLayout(Layout):
+ # Swizzled + bit-twiddled value layout; swizzle_data / unswizzle_data are
+ # meant to be inverses of each other (up to the padding trimmed at the end).
+ name: str = "HOPPER_VALUE"
+
+ def __init__(self, shape, mx_axis, mma_version=3):
+ # shape: (*leading, K, N); mx_axis must index a dim of `shape`.
+ # mma_version: swizzle_data only accepts 2 or 3.
+ super().__init__(shape)
+ assert mx_axis in range(len(shape))
+ self.mx_axis = mx_axis
+ self.mma_version = mma_version
+ (
+ *self.leading_shape,
+ self.K,
+ self.N,
+ ) = shape
+
+ def _maybe_mT(self, data):
+ # Transpose the last two dims when mx_axis is the second-to-last dim.
+ if self.mx_axis == len(self.leading_shape):
+ return data.mT
+ return data
+
+ def swizzle_data(self, data):
+ """
+ Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
+ (*, M // 4, K * 4) such that:
+
+ 1) Groups contiguously all the elements owned by the same thread of 4
+ mma tiles along the K axis. The following animation shows a similar
+ grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
+ as done here:
+ https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
+
+ 2) Moves the elements belonging to thread 4-7 to be contiguous with those
+ from thread 0-3. This is done to get a full cache line when loading them
+ from HBM.
+
+ mx_axis selects the lhs or rhs of the matmul.
+
+ WARNING: Assumes that the matmul will be done in bf16 or fp16!
+ Implementing it for fp8 is as easy as making the tile size (8, 8)
+ """
+ batch = data.ndim - 2
+ assert batch >= 0
+ assert self.mma_version in (2, 3)
+ data = self._maybe_mT(data)
+ init_shape = data.shape
+
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
+ # Every u8 represents 2 mxfp4 elements
+ u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
+
+ # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
+ # Each tuple below holds (factor for dim -2, factor for dim -1).
+ contig = (1, u8_kwidth)
+ scott_trick = (2, 1)
+ threads = (4, 4)
+ warp_tile = (2, 2)
+ k_tile = (1, 4 // u8_kwidth)
+
+ sizes = list(data.shape[:-2])
+ pads = []
+ # [rest, K, tile, threads] per dimension
+ # `pack` is the product of the five factors of a dim: 16 for dim -2 and
+ # 32 for dim -1 (for either value of u8_kwidth); each dim is padded up
+ # to a multiple of its pack size.
+ for i, (a, b, c, s, d) in enumerate(
+ zip(k_tile, warp_tile, threads, scott_trick, contig)
+ ):
+ pack = a * b * c * s * d
+ size = data.shape[batch + i]
+ pad = (pack - size % pack) % pack
+ pads += [(0, pad)]
+ sizes.append((size + pad) // pack)
+ sizes += [a, b, c, s, d]
+
+ # F.pad expects the padding of the last dim first, hence the reversal.
+ pads = tuple(x for t in pads[::-1] for x in t)
+ data = torch.nn.functional.pad(data, pads)
+ init_shape = data.shape
+ # 0: rest[0]
+ # 1: k_tile[0]
+ # 2: warp_tile[0]
+ # 3: threads[0]
+ # 4: scott_trick[0]
+ # 5: contig[0]
+ # 6: rest[1]
+ # 7: k_tile[1]
+ # 8: warp_tile[1]
+ # 9: threads[1]
+ # 10: scott_trick[1]
+ # 11: contig[1]
+ data = data.view(*sizes)
+ # Axis order produced by `perm` (read off the legend above):
+ # [rest[0], threads[0], rest[1], scott_trick[1], scott_trick[0], threads[1],
+ # k_tile[1], k_tile[0], warp_tile[1], warp_tile[0], contig[0], contig[1]]
+ perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
+ perm = list(range(batch)) + [batch + p for p in perm]
+ data = data.permute(*perm).contiguous()
+ # These are views
+ # First merge the last 10 axes, then merge rest[0] with threads[0].
+ data = data.flatten(-10, -1)
+ data = data.flatten(-3, -2)
+ assert data.is_contiguous()
+ assert data.shape[-2] == init_shape[-2] // 4
+ assert data.shape[-1] == init_shape[-1] * 4
+ # twiddle the bits
+ data = _pack_bits(data, self.mx_axis)
+ data = self._maybe_mT(data)
+ return data
+
+ def unswizzle_data(self, data):
+ # Undo swizzle_data: un-twiddle the bits, undo the regrouping, then trim
+ # the result to the (K, N) recorded at construction time.
+ data = self._maybe_mT(data)
+ data = _unpack_bits(data, self.mx_axis)
+ *batch, M, K = data.shape
+ # We have two times the elements if we already upcasted to bfloat16
+ mult = 2 if data.dtype == torch.bfloat16 else 1
+ assert M % 4 == 0, "M must be divisible by 4"
+ assert K % (4 * 8 * 2 * 2 * mult) == 0, (
+ f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
+ )
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
+ # Every u8 represents 2 mxfp4 elements
+ u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
+ # Split M into 2 factors and K into 6 factors (8 axes after the batch dims).
+ data = data.reshape(
+ *batch,
+ M // 4,
+ 4,
+ K // (4 * 8 * 2 * 2 * mult),
+ 2,
+ 4,
+ 8 // u8_kwidth,
+ 2,
+ u8_kwidth * mult,
+ )
+ b = len(batch)
+ perm = [0, 6, 1, 3, 2, 5, 4, 7]
+ perm = list(range(b)) + [b + p for p in perm]
+ data = data.permute(*perm)
+ data = data.reshape(*batch, M * 4, K // 4)
+ data = self._maybe_mT(data)
+ # Drop whatever padding swizzle_data added.
+ return data[..., : self.K, : self.N]
+
+ def swizzle_block_shape(self, block_shape):
+ # Block shapes are returned unchanged.
+ return block_shape
+
+
+@triton.jit
+def _unshuffle_triton(x, mma_version: tl.constexpr):
+ """
+ Triton inverse of swizzle_mxfp4_value_hopper
+ """
+ # Takes a 2-D block of shape (M, K) and returns one of shape (M * 4, K // 4),
+ # using the same reshape / axis order as HopperMXValueLayout.unswizzle_data.
+ tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
+ # if mx_axis == 0:
+ # x = x.trans()
+
+ # We have two times the elements if we already upcasted to bfloat16
+ mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
+ M: tl.constexpr = x.shape[0]
+ K: tl.constexpr = x.shape[1]
+ tl.static_assert(M % 4 == 0, "M must be divisible by 4")
+ tl.static_assert(
+ K % (4 * 8 * 2 * 2 * mult) == 0,
+ f"K must be divisible by {4 * 8 * 2 * 2 * mult}",
+ )
+
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
+ # Every u8 represents 2 mxfp4 elements
+ u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
+ # Split M into 2 factors and K into 6 factors.
+ x = x.reshape(
+ M // 4,
+ 4,
+ K // (4 * 8 * 2 * 2 * mult),
+ 2,
+ 4,
+ 8 // u8_kwidth,
+ 2,
+ u8_kwidth * mult,
+ )
+ x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
+ x = x.reshape(M * 4, K // 4)
+ # if mx_axis == 0:
+ # x = x.trans()
+ return x
+
+
+@triton.jit
+def _unpack_fp4_to_bf16_triton(x):
+ # Expand packed values into bfloat16 using inline PTX. The bit masks and
+ # shifts in the PTX are the same ones _unpack_bits applies on the host; each
+ # extracted pattern is additionally multiplied by `bias` (mul.bf16x2).
+ # For now we implement just H100 support (mul.bf16x2)
+ # A100 support is possible via fma
+ r0, r1 = tl.inline_asm_elementwise(
+ r"""
+ {
+ .reg .b32 b, c, d<7>, scale;
+ .reg .b32 bias;
+ mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
+ // We add the missing bias to the scale directly
+ and.b32 $0, $4, 0b10000001110000001000000111000000;
+ mul.bf16x2 $0, $0, bias;
+ shl.b32 b, $4, 3;
+ and.b32 $1, b, 0b10000001110000001000000111000000;
+ mul.bf16x2 $1, $1, bias;
+ shl.b32 c, $4, 6;
+ and.b32 $2, c, 0b10000001110000001000000111000000;
+ mul.bf16x2 $2, $2, bias;
+ // Unpack last two elements
+ shl.b32 d0, $4, 1;
+ and.b32 d1, d0, 0b10000000000000001000000000000000;
+ shr.b32 d2, $4, 3;
+ and.b32 d3, d2, 0b00000001100000000000000110000000;
+ or.b32 d4, d1, d3;
+ shr.b32 d5, $4, 7;
+ and.b32 d6, d5, 0b00000000010000000000000001000000;
+ or.b32 $3, d4, d6;
+ mul.bf16x2 $3, $3, bias;
+ }
+ """,
+ constraints="=r,=r,=r,=r,r",
+ args=[x],
+ dtype=(tl.bfloat16, tl.bfloat16),
+ is_pure=True,
+ pack=4,
+ )
+ # Concat each pack of 4
+ # join() adds a trailing axis of size 2 holding (r0, r1); the reshape /
+ # trans / reshape below folds that axis back into the column axis.
+ x = tl.join(r0, r1)
+ x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
+ x = x.trans(0, 1, 3, 2)
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
+ return x
+
+
+@triton.jit
+def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
+ """
+ Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
+ (x << 0) & 0b1000000111000000
+ (x << 3) & 0b1000000111000000
+ (x << 6) & 0b1000000111000000
+ ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
+ """
+ # Steps: unpack `x` (2-D uint8 block) to bf16, undo the value swizzle,
+ # convert `scale` to bf16, broadcast it over groups of 32 along mx_axis and
+ # multiply. The unshuffle is hard-wired to mma_version=3.
+ # upcast values to bfloat16
+ tl.static_assert(len(x.shape) == 2)
+ tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
+ tl.static_assert(x.shape[1] % 4 == 0)
+ tl.static_assert(x.dtype == tl.uint8)
+ # The helpers work along the last dim, so transpose around them if needed.
+ if mx_axis == 0:
+ x = x.trans()
+ x = _unpack_fp4_to_bf16_triton(x)
+ x = _unshuffle_triton(x, mma_version=3)
+ if mx_axis == 0:
+ x = x.trans()
+
+ # upcast scale to bfloat16
+ # Add bias missing from the bf16 upcasting sequence
+ # triton / LLVM generates terrible code for this sequence
+ # scale = scale.to(tl.uint16)
+ # scale = scale << 7
+ # scale = scale.to(tl.bfloat16, bitcast=True)
+ # The PTX below is the hand-written replacement for the commented-out
+ # sequence above: byte permutes followed by a left shift of 7.
+ scale = tl.inline_asm_elementwise(
+ r"""
+ {
+ prmt.b32 $0, $2, 0, 0x5140;
+ shl.b32 $0, $0, 7;
+ prmt.b32 $1, $2, 0, 0x7362;
+ shl.b32 $1, $1, 7;
+ }
+ """,
+ constraints="=r,=r,r",
+ args=[scale],
+ dtype=tl.bfloat16,
+ is_pure=True,
+ pack=4,
+ )
+ # Broadcast scale
+ # Insert a new axis right after mx_axis, expand it to 32 and fold it away so
+ # that `scale` ends up with the same shape as `x`.
+ scale = scale.expand_dims(mx_axis + 1)
+ scale = scale.broadcast_to(
+ scale.shape[: mx_axis + 1] + [32] + scale.shape[mx_axis + 2 :]
+ )
+ scale = scale.reshape(x.shape)
+
+ # Combine scale and x
+ x = x * scale
+ return x
diff --git a/vllm/kvprune/triton_kernels/tensor_details/layout_details/strided.py b/vllm/kvprune/triton_kernels/tensor_details/layout_details/strided.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbfd9248fca219eb94dae358cafd7fac6e082cd1
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/tensor_details/layout_details/strided.py
@@ -0,0 +1,17 @@
+from .base import Layout
+
+
+class StridedLayout(Layout):
+ # Pass-through layout: data and block shapes are returned untouched.
+ # Unlike the other layouts in this package, `name` is None here.
+ name: str = None
+
+ def __init__(self, shape) -> None:
+ super().__init__(shape)
+
+ def swizzle_data(self, data):
+ # No rearrangement: the input object itself is returned.
+ return data
+
+ def unswizzle_data(self, data):
+ # No rearrangement: the input object itself is returned.
+ return data
+
+ def swizzle_block_shape(self, block_shape):
+ # Block shapes are unchanged as well.
+ return block_shape
diff --git a/vllm/kvprune/triton_kernels/testing.py b/vllm/kvprune/triton_kernels/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..c16cdbe6897b8a8c875080c93aa5cc8713768615
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/testing.py
@@ -0,0 +1,215 @@
+import enum
+import functools
+import os
+import subprocess
+import sys
+import torch
+from vllm.kvprune.triton_kernels.numerics import (
+ MAX_FINITE_FLOAT8E4B8,
+ MAX_FINITE_FLOAT8E4NV,
+ MAX_FINITE_FLOAT8E5,
+)
+
+
+def assert_equal(ref, tri):
+    """Assert that ``ref`` and ``tri`` are equal.
+
+    Tensors are compared element-wise (every element must match); any other
+    value is compared with ``==``.
+    """
+    if not isinstance(ref, torch.Tensor):
+        assert ref == tri
+        return
+    assert torch.all(ref == tri)
+
+
+def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
+    """
+    Compare reference values against obtained values.
+
+    Both tensors are converted to float32 and scaled by the largest finite
+    magnitude of ``ref``; the element-wise relative error (relative to
+    ``max(rms(ref), |ref|)``) must then stay below ``maxtol`` (default 2e-2)
+    at every element and below ``rmstol`` (default 4e-3) in RMS.
+
+    When ``tri`` has a 1-byte dtype, ``ref`` is first cast to that dtype; if
+    both already share it, the comparison is exact.
+
+    Raises ``AssertionError`` on shape mismatch, on differing infinite
+    elements (position or sign) and when a tolerance is exceeded. Returns
+    ``None`` otherwise. ``description`` only labels the printed report.
+    """
+    if tri.dtype.itemsize == 1:
+        ref_as_type = ref.to(tri.dtype)
+        if ref.dtype == tri.dtype:
+            assert torch.all(ref_as_type == tri)
+            return
+        ref = ref_as_type
+
+    if ref.numel() == 0:
+        return
+
+    if maxtol is None:
+        maxtol = 2e-2
+    if rmstol is None:
+        rmstol = 4e-3
+
+    # cast to float32:
+    ref = ref.to(torch.float32).detach()
+    tri = tri.to(torch.float32).detach()
+    assert ref.shape == tri.shape, (
+        f"Tensors must have same size {ref.shape=} {tri.shape=}"
+    )
+
+    # deal with infinite elements:
+    inf_mask_ref = torch.isinf(ref)
+    inf_mask_tri = torch.isinf(tri)
+    assert torch.equal(inf_mask_ref, inf_mask_tri), (
+        "Tensor must have same infinite elements"
+    )
+    # The masks only tell *where* the infinities are; also require matching
+    # signs, otherwise +inf vs -inf would be zeroed out below and pass.
+    assert torch.equal(ref[inf_mask_ref], tri[inf_mask_tri]), (
+        "Tensor must have same infinite elements"
+    )
+    refn = torch.where(inf_mask_ref, 0, ref)
+    trin = torch.where(inf_mask_tri, 0, tri)
+
+    # normalise so that RMS calculation doesn't overflow:
+    eps = 1.0e-30
+    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
+    refn *= multiplier
+    trin *= multiplier
+
+    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
+
+    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
+    max_err = torch.max(rel_err).item()
+    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
+
+    if verbose:
+        print(f"{description} maximum relative error = {max_err} (threshold = {maxtol})")
+        print(f"{description} RMS relative error = {rms_err} (threshold = {rmstol})")
+
+    if max_err > maxtol:
+        # Report (at most 1000 of) the offending coordinates before failing.
+        bad_idxs = torch.nonzero(rel_err > maxtol)
+        num_nonzero = bad_idxs.size(0)
+        bad_idxs = bad_idxs[:1000]
+        print(
+            "%d / %d mismatched elements (shape = %s) at coords %s"
+            % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
+        )
+
+        bad_idxs = bad_idxs.unbind(-1)
+        print("ref values: ", ref[tuple(bad_idxs)].cpu())
+        print("tri values: ", tri[tuple(bad_idxs)].cpu())
+
+    assert max_err <= maxtol, (
+        f"{description}: maximum relative error {max_err} exceeds {maxtol}"
+    )
+    assert rms_err <= rmstol, (
+        f"{description}: RMS relative error {rms_err} exceeds {rmstol}"
+    )
+
+
+class ComputeSanitizerTool(enum.Enum):
+ # Tools selectable through compute_sanitizer(tools_to_check=[...]); each
+ # member's value is passed verbatim as `--tool=<value>` on the command line.
+ MEMCHECK = "memcheck"
+ RACECHECK = "racecheck"
+ SYNCCHECK = "synccheck"
+ INITCHECK = "initcheck"
+
+
+def compute_sanitizer(**target_kwargs):
+    """
+    Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
+    to expose potential memory access errors.
+
+    Decorator options (consumed here, never matched against the test kwargs):
+      * ``clear_torch_cache`` (bool, default False): call
+        ``torch.cuda.empty_cache()`` before deciding whether to run the sanitizer.
+      * ``tools_to_check`` (list of ``ComputeSanitizerTool``, default
+        ``[MEMCHECK]``): one subprocess is launched per tool.
+
+    Every remaining entry of ``target_kwargs`` must be present with the same
+    value in the test's keyword arguments for the sanitizer to run.
+    The wrapped test must receive a ``request_fixture`` keyword argument; its
+    ``node.callspec.id`` is used to re-select the test in the subprocess.
+    If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
+    Setting the environment variable ``SKIP_COMPUTE_SANITIZER=1`` runs the
+    test directly.
+    Running tests under compute sanitizer requires launching subprocess and is slow,
+    so use sparingly
+    """
+
+    def decorator(test_fn):
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
+                test_fn(*args, **kwargs)
+                return
+
+            import psutil
+
+            # Work on a per-call copy. Popping from the shared `target_kwargs`
+            # would drop `tools_to_check` / `clear_torch_cache` after the first
+            # invocation, so later (e.g. parametrized) calls would silently
+            # fall back to the defaults.
+            match_kwargs = dict(target_kwargs)
+            if match_kwargs.pop("clear_torch_cache", False):
+                torch.cuda.empty_cache()
+            tools_to_check = match_kwargs.pop(
+                "tools_to_check", [ComputeSanitizerTool.MEMCHECK]
+            )
+            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
+            # isinstance() instead of `in ComputeSanitizerTool`: containment
+            # checks with non-members behave differently across Python versions.
+            unknown_tools = [
+                tool
+                for tool in tools_to_check
+                if not isinstance(tool, ComputeSanitizerTool)
+            ]
+            assert not unknown_tools, f"{unknown_tools=}"
+
+            ppid_name = psutil.Process(os.getppid()).exe()
+            run_compute_sanitizer = match_kwargs.items() <= kwargs.items()
+            if "run_sanitizer" in kwargs:
+                run_compute_sanitizer &= kwargs["run_sanitizer"]
+            # Do not recurse when this process was itself started by compute-sanitizer.
+            if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
+                for tool in tools_to_check:
+                    # path of the file that defines the test
+                    path = os.path.realpath(test_fn.__globals__["__file__"])
+                    env = {
+                        "PATH": os.environ["PATH"],
+                        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
+                        "TORCH_SHOW_CPP_STACKTRACES": "1",
+                        "CUDA_LAUNCH_BLOCKING": "1",
+                    }
+                    if "CUDA_VISIBLE_DEVICES" in os.environ:
+                        env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
+                    assert "request_fixture" in kwargs, (
+                        "memcheck'ed test must have a (possibly unused) `request_fixture` argument"
+                    )
+                    test_id = kwargs["request_fixture"].node.callspec.id
+                    target = f"{path}::{test_fn.__name__}[{test_id}]"
+                    cmd = [
+                        "compute-sanitizer",
+                        "--target-processes=application-only",
+                        "--destroy-on-device-error=context",
+                        f"--tool={tool.value}",
+                        sys.executable,
+                        "-m",
+                        "pytest",
+                        "-vsx",
+                        target,
+                    ]
+                    # Forward the checksum options of the outer pytest run.
+                    for opt in ["--update_checksum", "--ignore_checksum_error"]:
+                        if opt in sys.argv:
+                            cmd.append(opt)
+                    out = subprocess.run(
+                        cmd,
+                        stdout=subprocess.PIPE,
+                        stderr=subprocess.STDOUT,
+                        env=env,
+                    )
+                    test_output = out.stdout
+                    if isinstance(test_output, bytes):
+                        test_output = test_output.decode(errors="replace")
+                    sanitizer_ok = (
+                        "ERROR SUMMARY: 0 errors" in test_output
+                        or "RACECHECK SUMMARY: 0 hazards displayed" in test_output
+                    )
+
+                    fail = False
+                    if not sanitizer_ok:
+                        print("compute-sanitizer returned an error")
+                        fail = True
+                    elif out.returncode != 0:
+                        print(
+                            "The test failed due to some other reason: consider running without compute-sanitizer to verify."
+                        )
+                        print(f"{out.returncode=}")
+                        fail = True
+
+                    if fail:
+                        print("*****************************************************")
+                        print("******************** TEST OUTPUT ********************")
+                        print("*****************************************************")
+                        print(test_output)
+                        print("*****************************************************")
+                        print("****************** TEST OUTPUT END ******************")
+                        print("*****************************************************")
+                        # Raised explicitly so the failure survives `python -O`.
+                        raise AssertionError(
+                            f"compute-sanitizer ({tool.value}) run failed for {target}"
+                        )
+            else:
+                test_fn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def compute_actual_scale(x, dtype):
+    """Return ``max(|x|)`` divided by the largest finite value of ``dtype``.
+
+    ``dtype`` must be one of ``torch.float8_e5m2``, ``torch.float8_e4m3fn`` or
+    ``torch.float8_e4m3fnuz``; any other dtype raises ``KeyError``.
+    """
+    finite_limits = {
+        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
+        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
+        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
+    }
+    limit = finite_limits[dtype]
+    return x.abs().max() / limit
diff --git a/vllm/kvprune/triton_kernels/topk.py b/vllm/kvprune/triton_kernels/topk.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a00c65634cff34e9a0a71d73ab9c9dd2eab7760
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/topk.py
@@ -0,0 +1,157 @@
+import torch
+import triton
+from vllm.kvprune.triton_kernels.topk_details._topk_forward import _topk_forward
+from vllm.kvprune.triton_kernels.topk_details import _topk_backward
+from vllm.kvprune.triton_kernels.tensor import Tensor, Bitmatrix
+from typing import Optional, Union
+
+
+def topk_forward(
+ x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None
+):
+ # Launch the `_topk_forward` kernel on a 2-D input and return
+ # (y_vals, y_indx, bitmatrix).
+ # Only dim == 1 and return_bitmatrix == True are supported (asserted below),
+ # and the number of columns must be below 32768.
+ # A plain torch.Tensor is wrapped into `Tensor`; `n_rows`, when given,
+ # replaces the row count of the logical shape while shape_max keeps the
+ # allocated one.
+ if not isinstance(x, Tensor):
+ x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
+ x_shape_max = [x.shape[0], x.shape[1]]
+ x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
+ cdiv = lambda a, b: (a + b - 1) // b
+ BLOCK_M = 32
+ BLOCK_N = 32
+ BLOCK_S = 128
+ assert len(x.shape) == 2
+ assert x.shape_max[-1] < 32768
+ assert dim == 1
+ assert return_bitmatrix
+ n_rows, n_cols = x.shape
+ n_rows_max, _ = x.shape_max
+ dev = x.device
+ # scratchpad tensors
+ # NOTE: these are not returned
+ # NOTE(review): y_vals and y_indx *are* part of the return value below; the
+ # note above presumably refers to an earlier version -- confirm.
+ y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
+ if y_indx is not None:
+ use_provided_indx = True
+ else:
+ y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
+ use_provided_indx = False
+ # create bitmatrix in transposed memory layout:
+ # allocated as (words, rows padded to 32) and then viewed as (rows, words).
+ n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
+ n_cols_words = n_cols_pad // 32
+ bitmatrix = torch.empty(
+ (n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev
+ )
+ bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
+ s_blocks = cdiv(n_cols, BLOCK_S)
+ s_cols = s_blocks * BLOCK_S
+ scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev)
+ # One program per BLOCK_M rows, but at least one per scratchpad block.
+ pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
+ _topk_forward[(pids,)](
+ x,
+ x.stride(0), # inputs
+ y_vals,
+ y_indx,
+ y_vals.stride(0),
+ use_provided_indx, # output [topk]
+ bitmatrix,
+ bitmatrix.stride(0),
+ bitmatrix.stride(1), # output [bitmatrix]
+ n_rows,
+ n_cols, # shapes
+ scratchpad,
+ BLOCK_S,
+ s_blocks, # thing to memset to zero
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N, # tunable parameter
+ APPLY_SOFTMAX=apply_softmax,
+ N_EXPTS_PAD=n_cols_pad,
+ N_EXPTS_ACT=k, # constants
+ )
+ # Wrap the raw storage together with its logical shape and the scratchpad.
+ bitmatrix_shape = [n_rows, n_cols_words * 32]
+ bitmatrix_shape_max = [n_rows_max, None]
+ bitmatrix = Bitmatrix(
+ bitmatrix,
+ shape=bitmatrix_shape,
+ shape_max=bitmatrix_shape_max,
+ scratchpad=scratchpad,
+ )
+ return y_vals, y_indx, bitmatrix
+
+
+def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
+    """Compute the gradient of the top-k output with respect to ``x``.
+
+    One kernel program is launched per row of ``dy_vals``. ``dy_vals`` must
+    have ``k`` entries in its last dimension. Returns a tensor shaped like
+    ``x``.
+    """
+    # The module-level name `_topk_backward` comes from
+    # `from ...topk_details import _topk_backward`, which binds the *submodule*
+    # (the package __init__ defines nothing), and a module cannot be indexed
+    # with a launch grid. Import the jitted kernel itself.
+    from vllm.kvprune.triton_kernels.topk_details._topk_backward import (
+        _topk_backward as _topk_backward_kernel,
+    )
+
+    assert dy_vals.shape[-1] == k
+    # tl.arange in the kernel needs a power-of-two extent.
+    n_expts_pad = triton.next_power_of_2(x.shape[-1])
+    dx = torch.empty_like(x)
+    _topk_backward_kernel[(dy_vals.shape[0],)](
+        y_indx,
+        y_indx.stride(0),
+        dy_vals,
+        dy_vals.stride(0),
+        x,
+        x.stride(0),  # inputs
+        dx,  # outputs
+        dx.stride(0),
+        x.shape[0],
+        n_rows,
+        x.shape[-1],
+        APPLY_SOFTMAX=apply_softmax,
+        N_EXPTS_ACT=k,
+        N_EXPTS_PAD=n_expts_pad,
+    )
+    return dx
+
+
+class TopK(torch.autograd.Function):
+    """Autograd wrapper pairing ``topk_forward`` with ``topk_backward``."""
+
+    @staticmethod
+    def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
+        """Run the forward kernel and stash what the backward pass needs."""
+        outputs = topk_forward(
+            x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows
+        )
+        y_vals, y_indx, bitmatrix = outputs
+        ctx.save_for_backward(x, y_indx)
+        ctx.apply_softmax = apply_softmax
+        ctx.k = k
+        ctx.n_rows = n_rows
+        return y_vals, y_indx, bitmatrix
+
+    @staticmethod
+    def backward(ctx, dy_vals, _0, _1):
+        """Return the gradient for ``x``; the other six inputs get ``None``."""
+        x, y_indx = ctx.saved_tensors
+        grad_x = topk_backward(
+            x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax
+        )
+        return (grad_x,) + (None,) * 6
+
+
+def topk(
+    x: Union[Tensor, torch.Tensor],
+    k: int,
+    apply_softmax: bool = True,
+    dim: int = 1,
+    return_bitmatrix: bool = True,
+    y_indx: Optional[torch.Tensor] = None,
+    n_rows: Optional[int] = None,
+):
+    """
+    Differentiable top-k selection over the columns of a 2-D input.
+
+    This is a thin entry point that dispatches to the ``TopK`` autograd
+    function; the input may be a ``Tensor`` wrapper or a ``torch.Tensor``.
+
+    Parameters
+    ----------
+    x : Union[triton_kernels.Tensor, torch.Tensor]
+        Input of shape (n_tokens, n_expts).
+    k : int
+        Number of top elements to retrieve per row.
+    apply_softmax : bool, default True
+        Forwarded to the kernels as their APPLY_SOFTMAX flag.
+    dim : int, default 1
+        Dimension along which to compute top-k. The forward pass asserts
+        that this is 1.
+    return_bitmatrix : bool, default True
+        The forward pass asserts that this is truthy; a bitmatrix is always
+        produced.
+    y_indx : torch.Tensor, optional
+        Pre-allocated (n_tokens, k) index tensor. When given, it is handed to
+        the kernel as provided indices instead of a freshly allocated buffer.
+    n_rows : int, optional
+        Number of rows to apply top-k on. If None, all rows of `x` are used.
+
+    Returns
+    -------
+    (expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
+    """
+    return TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
diff --git a/vllm/kvprune/triton_kernels/topk_details/__init__.py b/vllm/kvprune/triton_kernels/topk_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune/triton_kernels/topk_details/_topk_backward.py b/vllm/kvprune/triton_kernels/topk_details/_topk_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..eebe481771543a05cfab5741bf1a0c875248f70d
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/topk_details/_topk_backward.py
@@ -0,0 +1,51 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _topk_backward(
+ Yi,
+ stride_ym, # topk indices
+ DY,
+ stride_dym, # output gradient values
+ X,
+ stride_xm, # input values
+ DX,
+ stride_dxm, # input gradient values
+ n_rows,
+ NRows,
+ n_expts_tot,
+ APPLY_SOFTMAX: tl.constexpr,
+ N_EXPTS_ACT: tl.constexpr,
+ N_EXPTS_PAD: tl.constexpr,
+):
+ # One program handles one row: it zeroes the row of DX and then writes the
+ # gradient only at the N_EXPTS_ACT selected column indices.
+ pid_m = tl.program_id(0)
+ # When NRows is given it is loaded from memory and overrides n_rows.
+ if NRows is not None:
+ n_rows = tl.load(NRows)
+ # Programs beyond the valid row count do nothing.
+ if pid_m >= n_rows:
+ return
+ # Advance all pointers to this program's row.
+ Yi += pid_m * stride_ym
+ DY += pid_m * stride_dym
+ X += pid_m * stride_xm
+ DX += pid_m * stride_dxm
+ # --
+ offs_xn = tl.arange(0, N_EXPTS_PAD)
+ offs_yn = tl.arange(0, N_EXPTS_ACT)
+ # N_EXPTS_PAD may exceed the real column count; mask the tail.
+ mask_xn = offs_xn < n_expts_tot
+ # recompute softmax
+ # (over the selected entries only: x is gathered at the top-k indices)
+ y_indx = tl.load(Yi + offs_yn)
+ x = tl.load(X + y_indx)
+ x = x.to(tl.float32)
+ y = tl.softmax(x)
+ # compute input-gradient
+ dy = tl.load(DY + offs_yn)
+ dy = dy.to(tl.float32)
+ s = tl.sum(y * dy, 0)
+ # write-back input gradient
+ # First clear the whole row, then (after the barrier) scatter the non-zero
+ # entries so they are not overwritten by the zero fill.
+ tl.store(DX + offs_xn, 0, mask=mask_xn)
+ tl.debug_barrier()
+ if APPLY_SOFTMAX:
+ dx = y * (dy - s)
+ else:
+ dx = dy
+ tl.store(DX + y_indx, dx)
diff --git a/vllm/kvprune/triton_kernels/topk_details/_topk_forward.py b/vllm/kvprune/triton_kernels/topk_details/_topk_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf27ba999cca1a2b8fe63f1c386680c77ea4cec9
--- /dev/null
+++ b/vllm/kvprune/triton_kernels/topk_details/_topk_forward.py
@@ -0,0 +1,183 @@
+import triton
+import triton.language as tl
+
+
@triton.jit
def get_topmask_and_fullmask(x):
    # Build two constant arrays with the shape and dtype of `x`:
    #   tm_arr: only the most significant bit set,
    #   fm_arr: every bit set.
    # `x` must already be an unsigned integer tensor (floats passed as raw bits).
    tl.static_assert(
        x.dtype.is_int_unsigned(), "floating-point value must be passed as bits"
    )
    tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
    fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
    tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
    fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
    return tm_arr, fm_arr
+
+
@triton.jit
def fpval_to_key(x):
    # Map raw float bits to an unsigned key: values with the top (sign) bit set
    # have every bit flipped, all others have only the top bit flipped.
    # NOTE(review): this is the usual order-preserving float -> unsigned mapping,
    # so unsigned comparison of keys should follow float ordering.
    tm, fm = get_topmask_and_fullmask(x)
    return x ^ tl.where((x & tm) != 0, fm, tm)
+
+
@triton.jit
def key_to_fpval(x):
    # Inverse of `fpval_to_key`: keys with the top bit clear get every bit
    # flipped, keys with the top bit set get only the top bit flipped, which
    # restores the original raw float bits.
    tm, fm = get_topmask_and_fullmask(x)
    return x ^ tl.where((x & tm) == 0, fm, tm)
+
+
# stable top-k tie-breaks to value with smaller index:
# a smaller index yields a larger key (N_EXPTS_PAD - indx), so among equal
# values the entry with the smaller index compares greater.
@triton.jit
def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
    return N_EXPTS_PAD - indx
+
+
# Inverse of `indx_to_key`; the mapping N_EXPTS_PAD - v is its own inverse.
@triton.jit
def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
    return N_EXPTS_PAD - indx
+
+
@triton.jit
def streaming_topk(
    X,
    stride_xm,
    n_expts_tot,
    offs_m,
    mask_m,
    N_EXPTS_PAD: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Select the N_EXPTS_ACT largest entries of each row addressed by `offs_m`,
    # scanning the N_EXPTS_PAD columns in chunks of BLOCK_N, starting with the
    # last chunk and walking towards column 0.
    #
    # Every candidate is packed into one wider unsigned integer:
    #   (sortable value key << 16) | index key
    # so a single integer top-k / max orders by value first and by index key
    # second. NOTE(review): the packing assumes the index key fits in the low
    # 16 bits -- confirm against the largest N_EXPTS_PAD used by callers.
    #
    # Returns (y_values, y_indices): values in X's element dtype and the
    # matching column indices, ordered by ascending index.
    x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
    x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
    if x_nbits < 16:
        # this ensures that we leave at least 16 bits for expert index
        # even if the input dtype is smaller than 16 bits:
        y_nbits: tl.constexpr = 32
    else:
        y_nbits: tl.constexpr = x_nbits * 2
    x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
    x_dtype: tl.constexpr = X.dtype.element_ty

    # subtract 1 from loop iterations because we peel the first (masked) iteration:
    loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
    offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
    # Only this last chunk is masked against n_expts_tot; the chunks handled in
    # the loop below are loaded with the row mask alone.
    mask_n = offs_x_n[None, :] < n_expts_tot

    # first iteration:
    X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
    # masked-out lanes read -inf
    x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
    x = fpval_to_key(x.to(x_utype, bitcast=True))
    x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
    acc = tl.topk(x, N_EXPTS_ACT, dim=1)

    # subsequent iterations:
    # (unrolled with tl.static_range when there are at most 4 of them)
    for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
        acc = tl.bitonic_merge(acc)  # ensure sorted ascending for the merge
        X_ptrs -= BLOCK_N
        offs_x_n -= BLOCK_N
        x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
        x = fpval_to_key(x.to(x_utype, bitcast=True))
        x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
        # element-wise max of the running result and this chunk's top-k
        acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))

    # rotate expert index into upper 16 bits:
    # 0000vvvvvvvviiii --> iiii0000vvvvvvvv
    acc = (acc << (y_nbits - 16)) | (acc >> 16)
    # sort in ascending order of expert (descending order of key)
    acc = tl.sort(acc, dim=1, descending=True)
    # iiii0000vvvvvvvv --> 0000iiii:
    y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
    y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
    # iiii0000vvvvvvvv --> vvvvvvvv:
    y_values_raw = acc.to(x_utype)
    y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

    return y_values, y_indices
+
+
@triton.jit
def _topk_forward(
    X,
    stride_xm,  # inputs
    Yv,
    Yi,
    stride_ym,  # topk values/indices
    USE_PROVIDED_INDX: tl.constexpr,
    Bits,
    stride_rm: tl.constexpr,
    stride_rn: tl.constexpr,  # bitmatrix
    n_rows,
    n_expts_tot,  # shape
    S,
    BLOCK_S: tl.constexpr,
    s_blocks,  # thing to memset
    APPLY_SOFTMAX: tl.constexpr,  # constant
    BLOCK_M: tl.constexpr,
    N_EXPTS_PAD: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Forward top-k over rows of X. Each program handles BLOCK_M rows and:
    #   1. obtains N_EXPTS_ACT (value, index) pairs per row, either from the
    #      indices already stored in Yi (USE_PROVIDED_INDX) or via streaming_topk;
    #   2. optionally applies a softmax over the selected values;
    #   3. writes the values to Yv (and the indices to Yi when it computed them);
    #   4. records the selected indices as set bits in the Bits matrix, one
    #      32-bit word per 32 columns.
    # Independently of the rows, the first `s_blocks` programs each zero BLOCK_S
    # int32 entries of S.
    pid = tl.program_id(0)
    # n_rows may be passed as a pointer, in which case the count is loaded from it.
    if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
        n_rows = tl.load(n_rows)

    # memset of S happens before the early exit so it runs even for programs
    # that have no rows to process
    if pid < s_blocks:
        tl.store(
            S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)
        )

    if pid * BLOCK_M >= n_rows:
        # early exit:
        return

    tl.static_assert(BLOCK_N % 32 == 0)
    tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
    x_dtype: tl.constexpr = X.dtype.element_ty

    # load logits
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_y_n = tl.arange(0, N_EXPTS_ACT)
    mask_m = offs_m[:, None] < n_rows
    if USE_PROVIDED_INDX:
        # indices come from the caller; gather the matching values from X
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        y_indices = tl.load(Yi_ptrs, mask=mask_m)
        Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
        y_values = tl.load(Xv_ptrs, mask=mask_m)
    else:
        y_values, y_indices = streaming_topk(
            X,
            stride_xm,
            n_expts_tot,
            offs_m,
            mask_m,  #
            N_EXPTS_PAD,
            N_EXPTS_ACT,
            BLOCK_N,
        )

    # normalize selected values (computed in fp32, cast back to X's dtype)
    if APPLY_SOFTMAX:
        y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(
            x_dtype
        )

    # write back; note that Yv and Yi are both addressed with stride_ym
    Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
    tl.store(Yv_ptrs, y_values, mask=mask_m)
    if not USE_PROVIDED_INDX:
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        tl.store(Yi_ptrs, y_indices, mask=mask_m)

    # pack into bitmatrix: index i sets bit (i % 32) of word (i // 32)
    y_div = y_indices // 32
    y_rem = y_indices % 32
    loop_iterations = N_EXPTS_PAD // BLOCK_N
    for i in range(loop_iterations):
        # word offsets covered by this chunk of BLOCK_N columns
        offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
        y2 = tl.where(
            y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0
        )
        # OR together the bits contributed by all selected indices of a row
        r = tl.reduce_or(y2, axis=1)
        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
        tl.store(BitsPtrs, r, mask=mask_m)
diff --git a/vllm/kvprune/utils/__init__.py b/vllm/kvprune/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f61e74eab6605369645345370ef951e9141fef14
--- /dev/null
+++ b/vllm/kvprune/utils/__init__.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Shared helpers: Triton compat, layout bridge, context, sequences."""
+
+from vllm.kvprune.utils.layout_bridge import (
+ block_table_to_global_page_table,
+ build_batch_mapping,
+ build_page_table_head_major,
+ flatten_kv_cache_head_major,
+ flatten_kv_cache_plane,
+ write_head_major_flat_to_interleaved,
+)
+from vllm.kvprune.utils.triton_compat import (
+ autotune as triton_autotune,
+ cuda_capability_geq,
+ maybe_set_allocator,
+)
+
# Names re-exported from this package: the layout-bridge helpers plus the
# Triton compatibility helpers (``autotune`` is exposed as ``triton_autotune``).
__all__ = [
    "block_table_to_global_page_table",
    "build_batch_mapping",
    "build_page_table_head_major",
    "cuda_capability_geq",
    "flatten_kv_cache_head_major",
    "flatten_kv_cache_plane",
    "write_head_major_flat_to_interleaved",
    "maybe_set_allocator",
    "triton_autotune",
]
diff --git a/vllm/kvprune/utils/arguments.py b/vllm/kvprune/utils/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..7365a6965c645dfbaf109fe752b5d4e23e1d5155
--- /dev/null
+++ b/vllm/kvprune/utils/arguments.py
@@ -0,0 +1,436 @@
+import itertools
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+from vllm.kvprune.compression import CompressionMethod
+from vllm.kvprune.compression.compression_config import BatchCompressionParams
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.utils.sequence import Sequence
+from vllm.kvprune.utils.kv_dist import broadcast_from_tp_rank0
+from vllm.kvprune.utils.tp_utils import kv_heads_shard_divisor
+
+
@dataclass
class PrefillBatchArguments:
    """Arguments describing one packed prefill batch.

    Instances are produced by ``PackedTensorArguments.build_prefill_args``; the
    tensor fields are slices of its packed int64 / int32 buffers.
    """

    # Number of sequences in the batch and total number of prompt tokens.
    B: int
    N: int
    # Whether compression runs for this batch, with which method, and the
    # chunk size (-1 when chunked compression is disabled).
    do_compression: bool
    compression_method: CompressionMethod
    compression_chunk_size: int

    seq_ids: torch.Tensor

    # Flattened token ids / positions of all sequences, plus the cumulative
    # sequence-length offsets and maximum sequence length for queries and keys.
    input_ids: torch.Tensor
    positions: torch.Tensor
    cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    max_seqlen_q: int
    max_seqlen_k: int

    # Per-sequence retain budget, its maximum over the batch, and the number of
    # leading / trailing tokens per sequence that are protected.
    batch_tokens_to_retain: Optional[torch.Tensor]
    max_tokens_to_retain: Optional[int]
    protected_first: Optional[List[int]]
    protected_last: Optional[List[int]]

    # Random sketch matrix shared by all tensor-parallel ranks.
    PHI: Optional[torch.Tensor]

    # args needed for memory reservation
    context_lens: torch.Tensor
    max_new_tokens: torch.Tensor

    # Aligned with the kvpress ``CompactorPress`` blending default
    # (compression_ratio is used when blending is not given explicitly).
    compression_ratio: float = 1.0
+
class PackedTensorArguments:
    """Pack prefill batch metadata into preallocated buffers shared across TP ranks.

    Rank 0 fills two flat device buffers and broadcasts them; every other rank
    receives the buffers and decodes the same batch description from them.

    Layouts (``BMAX`` = ``config.max_num_seqs``, ``NMAX`` = ``max_batched_tokens``)::

        int64: [seq_ids (B)] [input_ids (N)] [positions (N)] [max_new_tokens (B)]
        int32: [header (6)] [cu_q (B+1)] [cu_k (B+1)] [retain (B)]
               [context_lens (B)] [protected_first (B)] [protected_last (B)]

    The int32 header holds ``B, N, do_compression, method, chunk_size,
    round(compression_ratio * 1e6)``.
    """

    # Number of int32 header fields (see the class docstring).
    _HEADER_LEN = 6
    # compression_ratio travels in the int32 header as round(ratio * _RATIO_SCALE).
    _RATIO_SCALE = 1_000_000

    def __init__(
        self,
        rank: int,
        max_batched_tokens: int,
        config: LLMConfig,
        seed: int = 42,
        *,
        device: torch.device | None = None,
        use_tp_group_for_collectives: bool = False,
    ) -> None:
        """Allocate the packed buffers and the random sketch matrix ``PHI``.

        Args:
            rank: Rank of this process; rank 0 builds batches, others receive.
            max_batched_tokens: Upper bound on the total prompt tokens per batch.
            config: Engine configuration (HF config, max sequences, page size,
                tensor-parallel size, sketch size).
            seed: Seed of the generator used to draw ``PHI``.
            device: Device for all buffers; defaults to ``cuda:{rank}``.
            use_tp_group_for_collectives: Forwarded to the broadcast helper to
                select the tensor-parallel group instead of the default group.
        """
        hf_config = config.hf_config
        self.rank = rank
        self.device = device if device is not None else torch.device(f"cuda:{rank}")
        self._use_tp_group = use_tp_group_for_collectives
        self.max_num_batches = config.max_num_seqs
        self.max_batched_tokens = max_batched_tokens
        _ws = kv_heads_shard_divisor()
        self.num_kv_heads = hf_config.num_key_value_heads // _ws
        self.world_size = config.tensor_parallel_size
        self.page_size = int(config.kvcache_page_size)
        # Many HF configs do not define ``head_dim``; derive it instead of
        # passing ``None`` into the PHI shape below.
        self.head_dim = self._resolve_head_dim(hf_config)
        self.sketch_dim = config.leverage_sketch_size
        self.model_dtype = hf_config.torch_dtype

        # i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
        self.i64_len_max = (
            self.max_num_batches + 2 * self.max_batched_tokens + self.max_num_batches
        )
        self.packed_context_i64 = torch.empty(
            self.i64_len_max, dtype=torch.int64, device=self.device
        )

        # i32 pack = [header] || [cu_q (BMAX+1)] || [cu_k (BMAX+1)] || [retain (BMAX)]
        #   || [context_lens (BMAX)] || [protected_first (BMAX)] || [protected_last (BMAX)]
        self.i32_len_max = (
            self._HEADER_LEN
            + 2 * (self.max_num_batches + 1)
            + 4 * self.max_num_batches
        )
        self.packed_context_i32 = torch.empty(
            self.i32_len_max, dtype=torch.int32, device=self.device
        )

        self.generator = torch.Generator(device=self.device).manual_seed(seed)
        self.PHI = torch.randn(
            (self.head_dim, self.sketch_dim),
            device=self.packed_context_i32.device,
            generator=self.generator,
        ).to(self.model_dtype) * (1 / math.sqrt(self.sketch_dim))

    @staticmethod
    def _resolve_head_dim(hf_config) -> int:
        """Return ``hf_config.head_dim``, or ``hidden_size // num_attention_heads``
        when the attribute is missing or ``None``."""
        head_dim = getattr(hf_config, "head_dim", None)
        if head_dim is None:
            head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        return int(head_dim)

    def _prefill_args_from_buffers(
        self,
        *,
        B: int,
        N: int,
        do_compression: bool,
        compression_method: CompressionMethod,
        compression_chunk_size: int,
        max_seqlen: int,
        max_retain: int,
        protected_first: List[int],
        protected_last: List[int],
        compression_ratio: float,
    ) -> PrefillBatchArguments:
        """Build ``PrefillBatchArguments`` whose tensors are views of the packed buffers."""
        s64 = self.packed_i64_slices(B, N)
        s32 = self.packed_i32_slices(B)
        i64 = self.packed_context_i64
        i32 = self.packed_context_i32
        return PrefillBatchArguments(
            B=B,
            N=N,
            do_compression=do_compression,
            compression_method=compression_method,
            compression_chunk_size=compression_chunk_size,
            seq_ids=i64[s64["seq_ids"]],
            input_ids=i64[s64["input_ids"]],
            positions=i64[s64["positions"]],
            cu_seqlens_q=i32[s32["cu_q"]],
            cu_seqlens_k=i32[s32["cu_k"]],
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            batch_tokens_to_retain=i32[s32["retain"]],
            max_tokens_to_retain=max_retain,
            PHI=self.PHI,
            context_lens=i32[s32["context_lens"]],
            max_new_tokens=i64[s64["max_new_tokens"]],
            protected_first=protected_first,
            protected_last=protected_last,
            compression_ratio=compression_ratio,
        )

    def _master_build_prefill(
        self, seqs: List[Sequence], batch_compression_params: BatchCompressionParams
    ) -> PrefillBatchArguments:
        """Rank-0 path: pack ``seqs`` into the buffers, broadcast them, return the args.

        Raises:
            ValueError: If the batch is empty or does not fit the preallocated buffers.
        """
        B = len(seqs)
        if B == 0:
            raise ValueError(
                "prefill batch is empty (scheduler should not call build_prefill with "
                "no sequences)"
            )
        # Explicit checks instead of ``assert``: they must survive ``python -O``,
        # otherwise an oversized batch only fails later with a shape mismatch.
        if B > self.max_num_batches:
            raise ValueError(
                f"prefill batch has {B} sequences but buffers hold at most "
                f"{self.max_num_batches}"
            )
        Ls = [x.prompt_len for x in seqs]

        N = sum(Ls)
        if N > self.max_batched_tokens:
            raise ValueError(
                f"prefill batch has {N} tokens but buffers hold at most "
                f"{self.max_batched_tokens}"
            )

        def pinned(values, dtype):
            # Pinned host staging tensor so the H2D copies below can be async.
            return torch.tensor(values, dtype=dtype, device="cpu", pin_memory=True)

        do_compression = (
            any(x.compression_params.compression_ratio < 1.0 for x in seqs)
            and batch_compression_params.compression_method != CompressionMethod.NONE
        )
        s64 = self.packed_i64_slices(B, N)
        s32 = self.packed_i32_slices(B)
        i64 = self.packed_context_i64
        i32 = self.packed_context_i32

        protected_first_list = [
            x.compression_params.protected_first_tokens for x in seqs
        ]
        protected_last_list = [x.compression_params.protected_last_tokens for x in seqs]
        # Per-sequence retain budget over all local KV heads, excluding the
        # protected prefix/suffix; never below 1.
        retain = pinned(
            [
                max(
                    int(
                        round(
                            x.compression_params.compression_ratio
                            * (L - s - e)
                            * self.num_kv_heads
                        )
                    ),
                    1,
                )
                for s, e, L, x in zip(
                    protected_first_list, protected_last_list, Ls, seqs
                )
            ],
            torch.int32,
        )
        i32[s32["protected_first"]].copy_(
            pinned(protected_first_list, torch.int32), non_blocking=True
        )
        i32[s32["protected_last"]].copy_(
            pinned(protected_last_list, torch.int32), non_blocking=True
        )

        compression_chunk_size = (
            batch_compression_params.chunk_size
            if batch_compression_params.do_chunked_compression
            else -1
        )
        min_compression_ratio = min(x.compression_params.compression_ratio for x in seqs)
        cr_scaled = int(round(float(min_compression_ratio) * self._RATIO_SCALE))
        cr_scaled = max(min(cr_scaled, 2_000_000_000), -2_000_000_000)
        # Report the same fixed-point value the peers decode from the header, so
        # every TP rank works with an identical compression_ratio.
        compression_ratio = cr_scaled / self._RATIO_SCALE
        header_host = pinned(
            [
                B,
                N,
                1 if do_compression else 0,
                batch_compression_params.compression_method.value,
                compression_chunk_size,
                cr_scaled,
            ],
            torch.int32,
        )
        i32[s32["retain"]].copy_(retain, non_blocking=True)
        i32[s32["header"]].copy_(header_host, non_blocking=True)
        max_seq_qk = max(Ls)

        # Prefill has no cached prefix, so query and key offsets coincide.
        cu = pinned(list(itertools.accumulate(Ls, initial=0)), torch.int32)
        i32[s32["cu_q"]].copy_(cu, non_blocking=True)
        i32[s32["cu_k"]].copy_(cu, non_blocking=True)
        i32[s32["context_lens"]].copy_(cu.diff(), non_blocking=True)

        i64[s64["seq_ids"]].copy_(
            pinned([x.seq_id for x in seqs], torch.int64), non_blocking=True
        )
        i64[s64["input_ids"]].copy_(
            pinned([tid for x in seqs for tid in x.prompt_token_ids], torch.int64),
            non_blocking=True,
        )
        positions = torch.cat(
            [
                torch.arange(L, dtype=torch.int64, device="cpu", pin_memory=True)
                for L in Ls
            ]
        )
        i64[s64["positions"]].copy_(positions, non_blocking=True)
        i64[s64["max_new_tokens"]].copy_(
            pinned(
                [seq.sampling_params.max_new_tokens for seq in seqs], torch.int64
            ),
            non_blocking=True,
        )

        # Padding candidates are selected per head on the Python side inside
        # `scores_to_retain_indices`, so this field is simply the largest true
        # keep budget in the batch.
        max_retain = int(retain.max().item())

        # Non-blocking H2D copies above must finish before NCCL broadcast, or peers can
        # receive stale/garbage packed buffers -> wrong prefill -> garbage tokens on TP>1.
        if i64.is_cuda:
            torch.cuda.synchronize(self.device)
        # PHI: rank 0's sketch matrix is broadcast so all TP ranks share one PHI for
        # leverage / compactor scores (same order as packed_context: i64, i32, PHI).
        broadcast_from_tp_rank0(i64, use_tp_group=self._use_tp_group)
        broadcast_from_tp_rank0(i32, use_tp_group=self._use_tp_group)
        if self.world_size > 1:
            broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)

        return self._prefill_args_from_buffers(
            B=B,
            N=N,
            do_compression=do_compression,
            compression_method=batch_compression_params.compression_method,
            compression_chunk_size=compression_chunk_size,
            max_seqlen=max_seq_qk,
            max_retain=max_retain,
            protected_first=protected_first_list,
            protected_last=protected_last_list,
            compression_ratio=compression_ratio,
        )

    def _peer_receive_prefill(self) -> PrefillBatchArguments:
        """Non-zero-rank path: receive the broadcast buffers and decode the batch."""
        i64 = self.packed_context_i64
        i32 = self.packed_context_i32
        # Must mirror the broadcast order used by `_master_build_prefill`.
        broadcast_from_tp_rank0(i64, use_tp_group=self._use_tp_group)
        broadcast_from_tp_rank0(i32, use_tp_group=self._use_tp_group)
        if self.world_size > 1:
            broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)
        # Header fields: B, N, do_compression, method, chunk_size, cr_scaled; must
        # match packed_i32_slices(B)["header"] for any B.
        header = i32[: self._HEADER_LEN].tolist()
        B, N = int(header[0]), int(header[1])

        s32 = self.packed_i32_slices(B)
        # Must match _master_build_prefill: max_seqlen_{q,k} = max(Ls), not cu_q.max()
        # (which equals total batch tokens N and breaks varlen attention on peers).
        max_seq_len = int(i32[s32["context_lens"]].max())
        max_retain = int(i32[s32["retain"]].max().item())
        return self._prefill_args_from_buffers(
            B=B,
            N=N,
            do_compression=bool(int(header[2])),
            compression_method=CompressionMethod(int(header[3])),
            compression_chunk_size=int(header[4]),
            max_seqlen=max_seq_len,
            max_retain=max_retain,
            protected_first=i32[s32["protected_first"]].tolist(),
            protected_last=i32[s32["protected_last"]].tolist(),
            compression_ratio=int(header[5]) / self._RATIO_SCALE,
        )

    @torch.inference_mode()
    def build_prefill_args(
        self,
        seqs: Optional[List[Sequence]] = None,
        batch_compression_params: Optional[BatchCompressionParams] = None,
    ) -> PrefillBatchArguments:
        """Return the prefill arguments for the current batch.

        Rank 0 must pass ``seqs`` and ``batch_compression_params``; other ranks
        ignore both and decode what rank 0 broadcast.

        Raises:
            ValueError: On rank 0 when an argument is missing or the batch is
                empty or too large for the buffers.
        """
        if self.rank != 0:
            return self._peer_receive_prefill()
        if seqs is None or batch_compression_params is None:
            raise ValueError(
                "rank 0 requires both seqs and batch_compression_params"
            )
        return self._master_build_prefill(seqs, batch_compression_params)

    def broadcast(self):
        """Broadcast the int64 buffer from rank 0 when running with TP > 1."""
        if self.world_size > 1:
            return broadcast_from_tp_rank0(
                self.packed_context_i64, use_tp_group=self._use_tp_group
            )
        return None

    @staticmethod
    def packed_i64_slices(B: int, N: int):
        """Return the slices of the int64 buffer for ``B`` sequences and ``N`` tokens."""
        return {
            "seq_ids": slice(0, B),
            "input_ids": slice(B, B + N),
            "positions": slice(B + N, B + 2 * N),
            "max_new_tokens": slice(B + 2 * N, 2 * B + 2 * N),
        }

    @staticmethod
    def packed_i32_slices(B: int):
        """Return the slices of the int32 buffer for a batch of ``B`` sequences."""
        # (name, length) in buffer order; each slice starts where the previous ends.
        layout = (
            ("header", PackedTensorArguments._HEADER_LEN),
            ("cu_q", B + 1),
            ("cu_k", B + 1),
            ("retain", B),
            ("context_lens", B),
            ("protected_first", B),
            ("protected_last", B),
        )
        slices = {}
        start = 0
        for name, length in layout:
            slices[name] = slice(start, start + length)
            start += length
        return slices
+
+
@dataclass
class DecodeBatchOutput:
    """Result of one decode step; either field may be ``None``."""

    # NOTE(review): presumably the sampled token ids and the ids of the
    # sequences they belong to, in matching order -- confirm against the producer.
    output_tokens: Optional[torch.Tensor]
    output_seq_ids: Optional[torch.Tensor]
+
+
@dataclass
class DecodeBatchArguments:
    """Accumulated per-sequence tensors for a decode batch.

    All tensor fields start as ``None`` and grow along dim 0 through
    :meth:`update`.
    """

    batch_mapping: Optional[torch.Tensor] = None
    token_ids: Optional[torch.Tensor] = None
    positions: Optional[torch.Tensor] = None
    max_ctx_lens: Optional[torch.Tensor] = None
    seq_ids: Optional[torch.Tensor] = None
    temps: Optional[torch.Tensor] = None
    desired_batch_occupancy: int = -1
    num_stashed_batches: int = 0

    def update(
        self,
        batch_mapping,
        token_ids,
        positions,
        max_ctx_lens,
        seq_ids,
        temps=None,
        desired_batch_occupancy: int = None,
    ):
        """Append new entries to every field and return ``self``.

        A field that is still ``None`` takes a clone of the incoming tensor;
        otherwise the incoming tensor is concatenated along dim 0. ``temps`` and
        ``desired_batch_occupancy`` are left untouched when passed as ``None``.
        """

        def grow(current, incoming):
            # First contribution is cloned so the caller's tensor is not aliased.
            if current is None:
                return incoming.clone()
            return torch.cat([current, incoming], dim=0)

        required = (
            ("batch_mapping", batch_mapping),
            ("token_ids", token_ids),
            ("positions", positions),
            ("max_ctx_lens", max_ctx_lens),
            ("seq_ids", seq_ids),
        )
        for field_name, incoming in required:
            setattr(self, field_name, grow(getattr(self, field_name), incoming))

        if temps is not None:
            self.temps = grow(self.temps, temps)

        if desired_batch_occupancy is not None:
            self.desired_batch_occupancy = desired_batch_occupancy

        return self
diff --git a/vllm/kvprune/utils/context.py b/vllm/kvprune/utils/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d44a34658d665ce2d238613f5fa3fbc5cf201bf
--- /dev/null
+++ b/vllm/kvprune/utils/context.py
@@ -0,0 +1,109 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+
+# Import from compression_config, not compression.__init__, to avoid circular imports
+# (compression -> compactor -> context -> compression).
+from vllm.kvprune.compression.compression_config import CompressionMethod
+from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
+
+
@dataclass
class CompressionContext:
    """Per-batch settings and tensors consumed by the KV compression methods."""

    compression_method: CompressionMethod = CompressionMethod.COMPACTOR

    # -1 is the default chunk size. NOTE(review): presumably meaning "no
    # chunked compression" -- confirm against the consumer.
    compression_chunk_size: int = -1
    batch_tokens_to_retain: torch.Tensor | None = None
    max_tokens_to_retain: int = 0
    context_lens: List[int] | None = None
    PHI: torch.Tensor | None = None

    # Compactor (optional hyperparameters aligned with kvpress ``CompactorPress``)
    sketch_dimension: int = 48
    sink_size_start: int = 8
    sink_size_end: int = 4
    compactor_blending: Optional[float] = None
    # Consistent with kvpress: this value (the request's compression_ratio) is
    # used when ``compactor_blending`` is not set.
    compression_ratio: Optional[float] = None

    protected_first_tokens: List[int] | None = None
    protected_last_tokens: List[int] | None = None

    # CriticalAdaKV
    wo_weight: Optional[torch.Tensor] = None
    critical_ada_epsilon: float = 1e-4
    critical_ada_first_stage_ratio: float = 0.5
    critical_ada_alpha_safeguard: float = 0.2
+
+
@dataclass
class Context:
    """Forward-pass state kept in the module-level ``_CONTEXT`` singleton.

    A default-constructed instance describes "no active batch": not prefill,
    no compression, and no tensors attached.
    """

    is_prefill: bool = False
    do_compression: bool = False

    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    # Set in ModelRunner.run_prefill before forward — avoids D2H inside compactor kernels.
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    batch_mapping: torch.Tensor | None = None
    max_bh_len: int = 0

    compression_context: CompressionContext | None = None
    # Optional side CUDA stream. NOTE(review): presumably the stream used for
    # KV-cache stores -- confirm against the callers.
    STORE_STREAM: torch.cuda.Stream | None = None

    key_split: int | None = None
    attention_schedule: KvpruneAttentionSchedule = (
        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
    )
+
+
# Module-level singleton holding the context of the forward pass in flight;
# replaced wholesale by `set_context` / `reset_context`.
_CONTEXT = Context()


def get_context():
    """Return the current module-level ``Context`` instance."""
    return _CONTEXT
+
+
def set_context(
    *,
    is_prefill,
    do_compression=False,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None,
    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    batch_mapping=None,
    max_bh_len=0,
    compression_context: CompressionContext = None,
    STORE_STREAM=None,
    key_split=None,
    attention_schedule=KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE,
):
    """Replace the module-level context with a fresh ``Context``.

    All arguments are keyword-only and map one-to-one onto the ``Context``
    fields of the same name; nothing from the previous context is carried over.
    """
    global _CONTEXT
    # Fields are bound by name so the mapping does not depend on the
    # declaration order of the dataclass.
    _CONTEXT = Context(
        is_prefill=is_prefill,
        do_compression=do_compression,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_q_host=cu_seqlens_q_host,
        cu_seqlens_k_host=cu_seqlens_k_host,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=STORE_STREAM,
        key_split=key_split,
        attention_schedule=attention_schedule,
    )
+
+
def reset_context():
    """Replace the module-level context with a default-constructed ``Context``."""
    global _CONTEXT
    _CONTEXT = Context()
diff --git a/vllm/kvprune/utils/helpers.py b/vllm/kvprune/utils/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e833b885ec2cc2372b1a267a7b361b535fd9d938
--- /dev/null
+++ b/vllm/kvprune/utils/helpers.py
@@ -0,0 +1,35 @@
+from collections.abc import Callable
+
+import torch
+
+
def maybe_execute_in_stream(
    fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
    """Call ``fn(*args, **kwargs)``, on ``STORE_STREAM`` when one is given.

    Without a stream this is a plain call. With a stream, the stream first
    waits on the default stream, ``fn`` runs with that stream active, and
    ``record_stream`` is called on the tensors involved: inputs (positional,
    keyword, and a tensor ``fn`` is bound to) against ``STORE_STREAM``, tensor
    outputs against the default stream.

    Returns:
        Whatever ``fn`` returns.
    """
    if STORE_STREAM is None:
        return fn(*args, **kwargs)

    input_tensors = [
        value
        for value in (*args, *kwargs.values())
        if isinstance(value, torch.Tensor)
    ]
    # A bound tensor method (e.g. ``t.copy_``) also uses its receiver.
    receiver = getattr(fn, "__self__", None)
    if isinstance(receiver, torch.Tensor):
        input_tensors.append(receiver)

    STORE_STREAM.wait_stream(torch.cuda.default_stream())
    # Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
    # The portable API is `torch.cuda.stream(stream)`.
    if hasattr(STORE_STREAM, "__enter__"):
        stream_ctx = STORE_STREAM
    else:
        stream_ctx = torch.cuda.stream(STORE_STREAM)
    with stream_ctx:
        output = fn(*args, **kwargs)

    for tensor in input_tensors:
        tensor.record_stream(STORE_STREAM)

    # Treat a single result like a one-element tuple; non-tensors are ignored.
    produced = output if isinstance(output, tuple) else (output,)
    for item in produced:
        if isinstance(item, torch.Tensor):
            item.record_stream(torch.cuda.default_stream())
    return output
diff --git a/vllm/kvprune/utils/kv_dist.py b/vllm/kvprune/utils/kv_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e13120cc909c77fe28a61bfc1ea18ab928cdd8
--- /dev/null
+++ b/vllm/kvprune/utils/kv_dist.py
@@ -0,0 +1,35 @@
+"""Distributed helpers for kvprune when embedded in vLLM (use TP process group)."""
+
+from __future__ import annotations
+
+import torch
+import torch.distributed as dist
+
+
def broadcast_from_tp_rank0(
    tensor: torch.Tensor, *, use_tp_group: bool
) -> None:
    """Broadcast ``tensor`` in place from group-local rank 0.

    With ``use_tp_group=False`` (standalone compactor subprocesses) the default
    process group is used, where world == tensor parallel size.

    With ``use_tp_group=True`` (embedded in a vLLM worker) vLLM's
    tensor-parallel group is used, so the collective cannot accidentally
    involve DP/PP ranks when the default group is global.
    """
    if use_tp_group:
        # Imported lazily, only on the embedded-in-vLLM path.
        from vllm.distributed.parallel_state import get_tp_group

        get_tp_group().broadcast(tensor, src=0)
    else:
        dist.broadcast(tensor, src=0)
+
+
def barrier_sync(*, use_tp_group: bool) -> None:
    """Barrier across the TP group, or the default group when ``use_tp_group`` is False.

    Group selection follows :func:`broadcast_from_tp_rank0`.
    """
    if use_tp_group:
        # Imported lazily, only on the embedded-in-vLLM path.
        from vllm.distributed.parallel_state import get_tp_group

        get_tp_group().barrier()
    else:
        dist.barrier()
diff --git a/vllm/kvprune/utils/layout_bridge.py b/vllm/kvprune/utils/layout_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..31321b2f7cf31db79880ecaef8e3601c8ff87662
--- /dev/null
+++ b/vllm/kvprune/utils/layout_bridge.py
@@ -0,0 +1,167 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Bridge vLLM paged KV layout to compactor Triton kernels.
+
+vLLM FlashAttention KV cache is shaped
+ [num_blocks, block_size, num_kv_heads, head_dim].
+Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table
+ global_page_table[batch, kv_head, logical_page] -> physical_page_id
+where each physical page holds ``block_size`` consecutive rows belonging to that
+KV head only.
+
+When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows:
+ row_index = physical_block_id * block_size + offset_in_block.
+
+When ``num_kv_heads > 1``, we permute to head-major
+``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to
+``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies
+a disjoint row range in the flat buffer. The page table is built so each
+logical compression page maps to ``global_row // PAGE_SIZE`` in that layout
+(see ``build_page_table_head_major``).
+"""
+
+from __future__ import annotations
+
+import torch
+
+
def _cdiv(n: int, d: int) -> int:
    """Return ``n`` divided by ``d``, rounded up (for positive ``d``)."""
    padded = n + d - 1
    return padded // d
+
+
def flatten_kv_cache_head_major(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``[nb, bs, H, D]`` caches as ``[H*nb*bs, D]`` copies in head-major order.

    Raises:
        ValueError: If the two caches differ in shape.
    """
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    flat_rows = num_heads * num_blocks * block_size

    def head_major(cache: torch.Tensor) -> torch.Tensor:
        # Move the head axis to the front, materialize, then merge the first
        # three axes into one row axis.
        return cache.permute(2, 0, 1, 3).contiguous().reshape(flat_rows, head_dim)

    return head_major(key_cache), head_major(value_cache)
+
+
def write_head_major_flat_to_interleaved(
    k_flat: torch.Tensor,
    v_flat: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> None:
    """Copy ``[H*nb*bs, D]`` head-major flats back into ``[nb, bs, H, D]`` caches.

    The target shape is taken from ``key_cache``; both caches are overwritten
    in place.
    """
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    head_major_shape = (num_heads, num_blocks, block_size, head_dim)
    # Undo the head-major permutation: (H, nb, bs, D) -> (nb, bs, H, D).
    k_interleaved = k_flat.view(head_major_shape).permute(1, 2, 0, 3)
    v_interleaved = v_flat.view(head_major_shape).permute(1, 2, 0, 3)
    key_cache.copy_(k_interleaved)
    value_cache.copy_(v_interleaved)
+
+
def build_page_table_head_major(
    block_table: torch.Tensor,
    num_kv_heads: int,
    num_blocks: int,
    block_size: int,
    page_size: int,
    max_batches: int,
) -> torch.Tensor:
    """Build ``[max_batches, H, max_chain]`` page table for head-major flat KV.

    Chains physical page ids in ``block_table`` order for each (batch, head).
    Each entry is ``global_row // page_size`` where ``global_row`` indexes rows
    in the head-major flat buffer (see ``flatten_kv_cache_head_major``).
    Negative block ids are skipped; unused trailing entries and batch rows
    beyond ``block_table.shape[0]`` stay zero.

    Args:
        block_table: ``[batch, max_blocks]`` integer tensor of block ids.
        num_kv_heads: Number of KV heads ``H`` in the flat buffer.
        num_blocks: Number of blocks per head in the flat buffer.
        block_size: Rows per block.
        page_size: Rows per compression page.
        max_batches: Leading dimension of the returned table.

    Returns:
        An int32 tensor on ``block_table.device``.

    Raises:
        ValueError: If the batch exceeds ``max_batches`` or a block id is
            ``>= num_blocks``.
    """
    bsz, max_blocks = block_table.shape
    if bsz > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")
    # Ceil division, written out so the function has no helper dependency.
    num_pages_per_block = (block_size + page_size - 1) // page_size
    max_chain = max_blocks * num_pages_per_block
    out = torch.zeros(
        (max_batches, num_kv_heads, max_chain),
        dtype=torch.int32,
        device=block_table.device,
    )
    if num_kv_heads <= 0:
        return out

    # Rows occupied by one head in the head-major flat buffer.
    head_stride = num_blocks * block_size
    # One device-to-host transfer for the whole table instead of a blocking
    # ``.item()`` per entry (repeated for every head).
    block_ids = block_table.to(torch.int64).tolist()

    table = []
    for b, row in enumerate(block_ids):
        # Page start rows relative to the beginning of a head's row range.
        page_starts = []
        for blk_i, bid in enumerate(row):
            if bid < 0:
                continue
            if bid >= num_blocks:
                raise ValueError(
                    f"block_table[{b},{blk_i}]={bid} out of range "
                    f"num_blocks={num_blocks}"
                )
            base_row = bid * block_size
            page_starts.extend(
                base_row + p * page_size for p in range(num_pages_per_block)
            )
        padding = [0] * (max_chain - len(page_starts))
        table.append(
            [
                [(h * head_stride + start) // page_size for start in page_starts]
                + padding
                for h in range(num_kv_heads)
            ]
        )

    if table and max_chain > 0:
        # Single write into the output instead of one scalar store per page.
        out[:bsz] = torch.tensor(table, dtype=torch.int32, device=out.device)
    return out
+
+
def flatten_kv_cache_plane(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(num_blocks, block_size, 1, D)`` caches as ``[num_blocks*block_size, D]``.

    Only the single-KV-head layout is supported, because only then does this
    flat row order match compactor row indexing (see the module docstring).

    Raises:
        ValueError: If ``num_kv_heads`` or the caches' head axis is not 1, or
            the two caches differ in shape.
    """
    if num_kv_heads != 1:
        raise ValueError(
            "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout"
        )
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")
    num_blocks, block_size, cache_heads, head_dim = key_cache.shape
    if cache_heads != 1:
        raise ValueError("expected num_kv_heads==1")
    # [num_blocks, block_size, 1, D] -> [num_blocks * block_size, D]
    flat_shape = (num_blocks * block_size, head_dim)
    # ``contiguous()`` hands back the same tensor when nothing needs copying.
    k_flat = key_cache.reshape(flat_shape).contiguous()
    v_flat = value_cache.reshape(flat_shape).contiguous()
    return k_flat, v_flat
+
+
def block_table_to_global_page_table(
    block_table: torch.Tensor,
    num_kv_heads: int,
    max_batches: int,
) -> torch.Tensor:
    """Build a ``[max_batches, HKV, num_logical_pages]`` int32 page table.

    Every KV head receives the same block ids as ``block_table``; batch rows
    beyond ``block_table.shape[0]`` stay zero.

    Raises:
        ValueError: If the batch exceeds ``max_batches``.
    """
    # block_table: [num_reqs_padded, max_num_blocks]
    batch, num_pages = block_table.shape
    if batch > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")
    table = torch.zeros(
        (max_batches, num_kv_heads, num_pages),
        dtype=torch.int32,
        device=block_table.device,
    )
    ids = block_table.to(torch.int32)[:batch]
    # The single-head case is just the first iteration of this loop.
    for head in range(num_kv_heads):
        table[:batch, head, :num_pages] = ids
    return table
+
+
def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor:
    """Return the identity mapping ``[0, 1, ..., num_reqs - 1]`` as int32 on ``device``.

    Entry ``i`` is the global batch row used for local batch index ``i``.
    """
    identity = torch.arange(num_reqs, dtype=torch.int32, device=device)
    return identity
diff --git a/vllm/kvprune/utils/sequence.py b/vllm/kvprune/utils/sequence.py
new file mode 100644
index 0000000000000000000000000000000000000000..489eea3998527f23d373c86f5b93b5c46c68c903
--- /dev/null
+++ b/vllm/kvprune/utils/sequence.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from itertools import count
+from typing import List
+
+from vllm.kvprune.compression.compression_config import SequenceCompressionParams
+from vllm.kvprune.config.sampling_params import SamplingParams
+
+
+class SequenceStatus(Enum):
+    """Lifecycle state of a :class:`Sequence`."""
+
+    # Explicit values equal to what ``auto()`` assigns (1, 2, 3 in order).
+    WAITING = 1
+    RUNNING = 2
+    FINISHED = 3
+
+
+@dataclass
+class Sequence:
+    """
+    One generation request: its prompt, the tokens produced so far, and the
+    sampling / compression settings attached to it.
+    """
+
+    # Class-level id generator; it has no annotation, so it is not a dataclass field.
+    _counter = count()
+
+    prompt_token_ids: List[int]
+    completion_token_ids: List[int] = field(default_factory=list)
+    sampling_params: SamplingParams = field(default_factory=SamplingParams)
+    compression_params: SequenceCompressionParams = field(
+        default_factory=SequenceCompressionParams
+    )
+    status: SequenceStatus = SequenceStatus.WAITING
+
+    # Drawn from the shared counter at construction; not accepted by __init__.
+    seq_id: int = field(init=False, default_factory=lambda: next(Sequence._counter))
+    num_tokens_processed: int = 0
+
+    @property
+    def num_prompt_tokens(self) -> int:
+        """Number of prompt tokens."""
+        return len(self.prompt_token_ids)
+
+    @property
+    def num_generated_tokens(self) -> int:
+        """Number of completion tokens appended so far."""
+        return len(self.completion_token_ids)
+
+    def add_new_token(self, token_id: int) -> None:
+        """Append ``token_id`` to the completion and advance the processed count.
+
+        The first appended token also adds the whole prompt length to
+        ``num_tokens_processed``.
+        """
+        is_first_token = not self.completion_token_ids
+        self.completion_token_ids.append(token_id)
+        if is_first_token:
+            self.num_tokens_processed += self.num_prompt_tokens
+        self.num_tokens_processed += 1
+
+    def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
+        """Return ``int(compression_ratio * prompt_len * num_kv_heads)``, at least 1."""
+        budget = (
+            self.compression_params.compression_ratio
+            * self.num_prompt_tokens
+            * num_kv_heads
+        )
+        return max(1, int(budget))
+
+    def __getstate__(self):
+        # Token lists are copied so the pickled state does not alias live lists.
+        return {
+            "prompt_token_ids": list(self.prompt_token_ids),
+            "completion_token_ids": list(self.completion_token_ids),
+            "sampling_params": self.sampling_params,
+            "compression_params": self.compression_params,
+            "status": self.status,
+            "seq_id": self.seq_id,
+            "num_tokens_processed": self.num_tokens_processed,
+        }
+
+    def __setstate__(self, state):
+        for list_field in ("prompt_token_ids", "completion_token_ids"):
+            setattr(self, list_field, list(state[list_field]))
+        for plain_field in (
+            "sampling_params",
+            "compression_params",
+            "status",
+            "seq_id",
+            "num_tokens_processed",
+        ):
+            setattr(self, plain_field, state[plain_field])
+
+    @property
+    def prompt_len(self) -> int:
+        """Same value as :attr:`num_prompt_tokens`."""
+        return self.num_prompt_tokens
+
+    @property
+    def completion_len(self) -> int:
+        """Same value as :attr:`num_generated_tokens`."""
+        return self.num_generated_tokens
diff --git a/vllm/kvprune/utils/tp_collectives.py b/vllm/kvprune/utils/tp_collectives.py
new file mode 100644
index 0000000000000000000000000000000000000000..855792aa8f524dcf94e8655a307d9fcc64721261
--- /dev/null
+++ b/vllm/kvprune/utils/tp_collectives.py
@@ -0,0 +1,48 @@
+"""Tensor-parallel collectives for kvprune (match vLLM TP process group when embedded)."""
+
+from __future__ import annotations
+
+import torch.distributed as dist
+
+
+def tensor_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
+    """All-reduce ``tensor`` across tensor-parallel ranks, updating it in place.
+
+    Returns ``tensor`` itself in every case, so call sites may either use the
+    return value or rely on the in-place update.
+
+    * No process group, or world size <= 1: ``tensor`` is returned untouched.
+    * vLLM model-parallel state is initialized (kvprune embedded in a vLLM
+      worker): vLLM's ``tensor_model_parallel_all_reduce`` is used. If it
+      returns a different tensor object, the reduced values are copied back
+      into ``tensor`` so callers that ignore the return value still see the
+      reduced result instead of their per-rank partial.
+    * Otherwise (standalone kvprune with only the default process group):
+      :func:`torch.distributed.all_reduce` on the default group.
+
+    Only the *detection* of vLLM's parallel state is best-effort. An error
+    raised by the collective itself propagates to the caller: falling back to
+    a second all-reduce on the default group after a failed TP all-reduce
+    would issue a different collective on a possibly different set of ranks.
+    """
+    if not dist.is_initialized() or dist.get_world_size() <= 1:
+        return tensor
+
+    vllm_tp_all_reduce = None
+    try:
+        from vllm.distributed.parallel_state import model_parallel_is_initialized
+
+        if model_parallel_is_initialized():
+            from vllm.distributed.communication_op import (
+                tensor_model_parallel_all_reduce as vllm_tp_all_reduce,
+            )
+    except Exception:
+        # vLLM's parallel state could not be imported or queried: use the default group.
+        vllm_tp_all_reduce = None
+
+    if vllm_tp_all_reduce is None:
+        dist.all_reduce(tensor)
+        return tensor
+
+    reduced = vllm_tp_all_reduce(tensor)
+    if reduced is not tensor:
+        # Out-of-place result: materialize it in the caller's tensor.
+        tensor.copy_(reduced)
+    return tensor
diff --git a/vllm/kvprune/utils/tp_utils.py b/vllm/kvprune/utils/tp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e829f1f400ba5a162a9c6a1164f708fa229cfa9
--- /dev/null
+++ b/vllm/kvprune/utils/tp_utils.py
@@ -0,0 +1,40 @@
+"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""
+
+from __future__ import annotations
+
+import torch.distributed as dist
+
+
+def tensor_parallel_rank_for_sharding() -> int:
+    """Return this process's rank for tensor-parallel sharding.
+
+    Uses vLLM's ``get_tensor_model_parallel_rank`` when it can be imported and
+    called. On any failure, falls back to :func:`torch.distributed.get_rank`
+    if a process group is initialized, and to ``0`` otherwise.
+    """
+    try:
+        from vllm.distributed.parallel_state import (
+            get_tensor_model_parallel_rank as _vllm_tp_rank,
+        )
+
+        return int(_vllm_tp_rank())
+    except Exception:
+        return int(dist.get_rank()) if dist.is_initialized() else 0
+
+
+def tensor_parallel_world_size_for_sharding() -> int:
+    """Return the world size used for tensor-parallel sharding.
+
+    Uses vLLM's ``get_tensor_model_parallel_world_size`` when it can be
+    imported and called. On any failure, falls back to
+    :func:`torch.distributed.get_world_size` if a process group is
+    initialized, and to ``1`` otherwise.
+    """
+    try:
+        from vllm.distributed.parallel_state import (
+            get_tensor_model_parallel_world_size as _vllm_tp_world_size,
+        )
+
+        return int(_vllm_tp_world_size())
+    except Exception:
+        return int(dist.get_world_size()) if dist.is_initialized() else 1
+
+
+def kv_heads_shard_divisor() -> int:
+    """Return the divisor for sharding KV heads: the tensor-parallel world size."""
+    divisor = tensor_parallel_world_size_for_sharding()
+    return divisor
diff --git a/vllm/kvprune/utils/triton_compat.py b/vllm/kvprune/utils/triton_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..89c5bc753fcff3915e4c2f713860841d3f635af5
--- /dev/null
+++ b/vllm/kvprune/utils/triton_compat.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import inspect
+import os
+from typing import Any, Callable, Mapping
+
+import torch
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+_cache_results_warned = False
+
+
+def _ensure_kvprune_triton_cache_dir() -> None:
+    """Point ``TRITON_CACHE_DIR`` at a stable directory unless it is already set.
+
+    When the variable is unset or empty, the directory
+    ``<VLLM_CACHE_ROOT or ~/.cache/vllm>/kvprune_triton_cache`` is created if
+    needed and exported as ``TRITON_CACHE_DIR`` for this process.
+    """
+    env = os.environ
+    if not env.get("TRITON_CACHE_DIR"):
+        root = env.get("VLLM_CACHE_ROOT", os.path.expanduser("~/.cache/vllm"))
+        target = os.path.join(root, "kvprune_triton_cache")
+        os.makedirs(target, exist_ok=True)
+        env["TRITON_CACHE_DIR"] = target
+
+
+# Runs at import time, when this module is first loaded.
+_ensure_kvprune_triton_cache_dir()
+
+
+def _filter_kwargs_for_callable(
+    fn: Callable[..., Any], kwargs: Mapping[str, Any]
+) -> dict[str, Any]:
+    """Return the subset of ``kwargs`` that ``fn`` can accept as keywords.
+
+    * If the signature of ``fn`` cannot be inspected, all of ``kwargs`` is
+      returned unchanged (as a new dict).
+    * If ``fn`` declares a ``**kwargs`` catch-all, every keyword is accepted,
+      so all of ``kwargs`` is returned.
+    * Otherwise only names that match a positional-or-keyword or keyword-only
+      parameter are kept; positional-only and ``*args`` names cannot be
+      passed by keyword and are dropped.
+    """
+    try:
+        params = inspect.signature(fn).parameters
+    except (TypeError, ValueError):
+        return dict(kwargs)
+
+    kind = inspect.Parameter
+    if any(p.kind is kind.VAR_KEYWORD for p in params.values()):
+        return dict(kwargs)
+
+    keyword_capable = (kind.POSITIONAL_OR_KEYWORD, kind.KEYWORD_ONLY)
+    accepted = {name for name, p in params.items() if p.kind in keyword_capable}
+    return {k: v for k, v in kwargs.items() if k in accepted}
+
+
+def autotune(*, configs, key, **kwargs):
+    """
+    Forward to ``triton.autotune``, dropping keyword arguments it does not accept.
+
+    Extra keyword arguments are filtered against the runtime signature of
+    ``triton.autotune`` via ``_filter_kwargs_for_callable``. If the caller
+    passed ``cache_results`` and it had to be dropped, a warning is logged the
+    first time this happens in the process.
+    """
+    global _cache_results_warned
+
+    import triton
+
+    supported = _filter_kwargs_for_callable(triton.autotune, kwargs)
+    cache_results_dropped = (
+        "cache_results" in kwargs and "cache_results" not in supported
+    )
+    if cache_results_dropped and not _cache_results_warned:
+        logger.warning_once(
+            "Current Triton build does not accept cache_results in triton.autotune; "
+            "kvprune autotune results may not persist across runs."
+        )
+        _cache_results_warned = True
+    return triton.autotune(configs=configs, key=key, **supported)
+
+
+def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
+    """
+    Install ``alloc_fn`` through ``triton.set_allocator`` when Triton provides it.
+
+    Returns True if the allocator was installed, False if this Triton build
+    has no ``set_allocator`` attribute (in which case nothing happens).
+    """
+    import triton
+
+    install = getattr(triton, "set_allocator", None)
+    if install is not None:
+        install(alloc_fn)
+        return True
+    return False
+
+
+def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
+    """
+    Host-side check that a CUDA device's compute capability is at least ``(major, minor)``.
+
+    Returns False when CUDA is unavailable. When ``device`` is None the
+    current CUDA device is queried, falling back to device 0 if that lookup
+    raises.
+    """
+    if torch.cuda.is_available():
+        if device is None:
+            try:
+                device = torch.cuda.current_device()
+            except Exception:
+                device = 0
+        # Tuple comparison orders by major first, then minor.
+        return torch.cuda.get_device_capability(device) >= (major, minor)
+    return False
diff --git a/vllm/kvprune_legacy_save/__init__.py b/vllm/kvprune_legacy_save/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c351684196503dd1c0777c160a700f201f47f288
--- /dev/null
+++ b/vllm/kvprune_legacy_save/__init__.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+KV-cache pruning (compactor-style) under ``vllm.kvprune``.
+
+Use the standard :class:`~vllm.LLM` and pass ``compression=`` to :meth:`~vllm.LLM.generate`
+with :class:`CompressionParams` when any prompt needs ``compression_ratio < 1``. The compactor
+``LLMEngine`` + ``PagedKVCache`` shares weights with vLLM (no second checkpoint).
+
+Subpackages (``attention``, ``kv_cache``, ``compression``, …) implement the compactor
+engine.
+"""
+
+from vllm.kvprune.compression.compression_config import CompressionMethod
+from vllm.kvprune.integration import CompressionParams
+
+__all__ = [
+ "CompressionMethod",
+ "CompressionParams",
+]
diff --git a/vllm/kvprune_legacy_save/attention/__init__.py b/vllm/kvprune_legacy_save/attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c5bb5b76552eb0d0f03cdfa04f36218699ba69
--- /dev/null
+++ b/vllm/kvprune_legacy_save/attention/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Sparse attention Triton kernels (varlen prefill, decode, compile helpers)."""
+
+from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
+
+__all__ = ["causal_sparse_varlen_with_cache"]
diff --git a/vllm/kvprune_legacy_save/attention/compile_kernels.py b/vllm/kvprune_legacy_save/attention/compile_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cc9fc31c48e45bbc74c5f17b4d64cc5f79ffcb5
--- /dev/null
+++ b/vllm/kvprune_legacy_save/attention/compile_kernels.py
@@ -0,0 +1,261 @@
+import argparse
+import logging
+import math
+
+import torch
+from vllm.kvprune.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def build_mock_paged_cache_from_lengths(
+    L_cache_per_b: torch.Tensor,
+    HKV: int,
+    D: int,
+    PAGE_SIZE: int,
+    N_LOGICAL_PAGES_MAX: int,
+    device,
+    dtype,
+):
+    """Build a random paged KV cache with a given number of cached tokens per batch.
+
+    Each ``(batch, kv_head, logical_page)`` triple gets its own physical page,
+    numbered consecutively in that order. For every batch ``b`` and every KV
+    head, the first ``L_cache_per_b[b]`` token rows are filled with standard
+    normal samples; all other rows stay zero.
+
+    The fill is vectorized (one ``randn`` call per cache) rather than looping
+    per token, so the individual random values differ from a per-row fill
+    under the same seed; which rows are filled is unchanged.
+
+    Args:
+        L_cache_per_b: integer tensor with one cached length per batch; each
+            must be ``<= PAGE_SIZE * N_LOGICAL_PAGES_MAX``.
+        HKV: number of KV heads.
+        D: row width of the cache.
+        PAGE_SIZE: tokens per physical page.
+        N_LOGICAL_PAGES_MAX: logical pages per ``(batch, kv_head)``.
+        device, dtype: placement and dtype of the K/V caches.
+
+    Returns:
+        ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where the
+        caches are ``[CACHE_SIZE, D]``, ``page_table`` is int32
+        ``[B, HKV, N_LOGICAL_PAGES_MAX]``, ``seq_lens_bh`` is int32
+        ``[B, HKV]`` and ``CACHE_SIZE = B * HKV * N_LOGICAL_PAGES_MAX * PAGE_SIZE``.
+    """
+    B = len(L_cache_per_b)
+    tokens_per_head = PAGE_SIZE * N_LOGICAL_PAGES_MAX
+    assert (L_cache_per_b <= tokens_per_head).all()
+
+    lengths = L_cache_per_b.to(device=device, dtype=torch.int32)
+    seq_lens_bh = lengths[:, None].expand(B, HKV).contiguous()
+
+    num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
+    CACHE_SIZE = num_phys_pages * PAGE_SIZE
+
+    K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
+    V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
+
+    # Unique physical page per (b, h, lp), numbered in row-major order.
+    page_table = torch.arange(
+        num_phys_pages, device=device, dtype=torch.int32
+    ).view(B, HKV, N_LOGICAL_PAGES_MAX)
+
+    # Because pages are consecutive, token i of (b, h) lives at cache row
+    # (b * HKV + h) * tokens_per_head + i; mark the rows with i < length.
+    token_pos = torch.arange(tokens_per_head, device=device, dtype=torch.int32)
+    valid_rows = (token_pos[None, None, :] < seq_lens_bh[:, :, None]).reshape(-1)
+    num_valid = int(valid_rows.sum().item())
+    K_cache[valid_rows] = torch.randn((num_valid, D), device=device, dtype=dtype)
+    V_cache[valid_rows] = torch.randn((num_valid, D), device=device, dtype=dtype)
+
+    return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
+
+
+def autotune_causal_sparse_varlen_with_cache(
+    *,
+    max_length: int = 16384,
+    HKV: int = 8,
+    HQ: int = 32,
+    D: int = 128,
+    PAGE_SIZE: int = 128,
+    device: str = "cuda",
+    dtype=torch.float16,
+):
+    """
+    Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.
+
+    The sweep uses lengths ``0``, ``256`` and every power of two from ``512``
+    up to (but excluding) ``max_length``. For every (cache length, append
+    length) pair except ``(0, 0)``, a mock paged cache for a batch of 4
+    identical sequences is built and the kernel wrapper is called once with
+    random inputs.
+
+    The page table is sized in *pages*: ``ceil(longest cache length /
+    PAGE_SIZE)`` logical pages per (batch, kv head). Sizing it in tokens would
+    allocate ``PAGE_SIZE`` times more cache rows than the sweep can ever use.
+    """
+    import itertools
+
+    import tqdm
+
+    B = 4
+
+    # D must be a power of two (kernel requirement).
+    assert (D & (D - 1)) == 0
+
+    lengths_to_sweep = [0, 256]
+    i = 9
+    while (v := (1 << i)) < max_length:
+        lengths_to_sweep.append(v)
+        i += 1
+
+    # Pages needed to hold the longest cached prefix in the sweep. The sweep
+    # always contains 256, so that length is covered even for tiny max_length.
+    max_cache_len = max(max_length, max(lengths_to_sweep))
+    N_LOGICAL_PAGES_MAX = (max_cache_len + PAGE_SIZE - 1) // PAGE_SIZE
+
+    combos = list(itertools.product(lengths_to_sweep, repeat=2))
+    logger.info(
+        "tuning kernels. this may take a few minutes, "
+        "but only needs to be run once per LLMConfig"
+    )
+
+    # Identity batch mapping (local batch index == global)
+    batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
+    sm_scale = 1.0 / math.sqrt(D)
+
+    for cache_l, append_l in tqdm.tqdm(combos):
+        if cache_l + append_l == 0:
+            continue
+
+        L_cache_per_b = torch.tensor(
+            [cache_l] * B,
+            device=device,
+            dtype=torch.int32,
+        )
+        # Rebuilt for every pair; whether the kernel wrapper mutates the cache
+        # is not visible here, so the cache is not shared between calls.
+        K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
+            build_mock_paged_cache_from_lengths(
+                L_cache_per_b=L_cache_per_b,
+                HKV=HKV,
+                D=D,
+                PAGE_SIZE=PAGE_SIZE,
+                N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+                device=device,
+                dtype=dtype,
+            )
+        )
+
+        # Every sequence appends the same number of tokens, so the cumulative
+        # offsets and maxima are known on the host without device reads.
+        cu_seqlens_qk = torch.tensor(
+            [append_l * n for n in range(B + 1)], dtype=torch.int32, device=device
+        )
+        N = append_l * B
+        max_seqlen_q = append_l
+        max_seqlen_k = cache_l
+
+        q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
+        k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
+        v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
+
+        causal_sparse_varlen_with_cache(
+            q=q_raw,
+            k_cache=K_cache,
+            v_cache=V_cache,
+            k=k_append_raw,
+            v=v_append_raw,
+            seq_lens_bh=seq_lens_bh,
+            global_page_table=page_table,
+            batch_mapping=batch_mapping,
+            cu_seqlens_q=cu_seqlens_qk,
+            HKV=HKV,
+            PAGE_SIZE=PAGE_SIZE,
+            sm_scale=sm_scale,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k_cache=max_seqlen_k,
+        )
+
+
+def _parse_args() -> argparse.Namespace:
+    """Parse the command-line options of the autotune script from ``sys.argv``."""
+    parser = argparse.ArgumentParser(
+        # Each fragment ends with a space so the concatenated sentences do not run together.
+        description=(
+            "Autotune Triton kernels. "
+            "Results are cached, so this should only need to be run once per configuration. "
+            "This script doesn't need to be run, as the kernels will be autotuned at runtime "
+            "if no cached autotuning data exists. Running this beforehand will prevent run-time "
+            "autotuning, which will accelerate compactor-vllm at inference time."
+        )
+    )
+    parser.add_argument(
+        "--max-length",
+        type=int,
+        default=16384,
+        help="Maximum total sequence length to consider.",
+    )
+    parser.add_argument(
+        "--HKV",
+        type=int,
+        default=8,
+        help="Number of KV heads.",
+    )
+    parser.add_argument(
+        "--HQ",
+        type=int,
+        default=32,
+        help="Number of query heads.",
+    )
+    parser.add_argument(
+        "--D",
+        type=int,
+        default=128,
+        help="Per-head hidden dimension (must be power of 2).",
+    )
+    parser.add_argument(
+        "--page-size",
+        type=int,
+        default=128,
+        help="Page size (tokens per physical page).",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="float16",
+        help=(
+            "Dtype for tensors: one of "
+            "{float16, fp16, half, bfloat16, bf16, float32, fp32}."
+        ),
+    )
+    parser.add_argument(
+        "--log-level",
+        type=str,
+        default="INFO",
+        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
+        help="Logging level.",
+    )
+    return parser.parse_args()
+
+
+def _resolve_dtype(dtype_str: str):
+    """Map a case-insensitive dtype name to the matching ``torch`` dtype.
+
+    Raises:
+        ValueError: if the name is not one of the recognised aliases.
+    """
+    aliases = {
+        torch.float16: ("float16", "fp16", "half"),
+        torch.bfloat16: ("bfloat16", "bf16"),
+        torch.float32: ("float32", "fp32"),
+    }
+    wanted = dtype_str.lower()
+    for torch_dtype, names in aliases.items():
+        if wanted in names:
+            return torch_dtype
+    raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+
+def main():
+    """Script entry point: parse arguments, configure logging, run the autotune sweep."""
+    args = _parse_args()
+    # force=True replaces any handlers already installed on the root logger;
+    # without it basicConfig is a no-op once logging has been configured and
+    # --log-level would be ignored.
+    logging.basicConfig(
+        level=getattr(logging, args.log_level.upper()),
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+        force=True,
+    )
+
+    dtype = _resolve_dtype(args.dtype)
+    logger.info(
+        "Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
+        "device=%s, dtype=%s",
+        args.max_length,
+        args.HKV,
+        args.HQ,
+        args.D,
+        args.page_size,
+        args.device,
+        dtype,
+    )
+
+    autotune_causal_sparse_varlen_with_cache(
+        max_length=args.max_length,
+        HKV=args.HKV,
+        HQ=args.HQ,
+        D=args.D,
+        PAGE_SIZE=args.page_size,
+        device=args.device,
+        dtype=dtype,
+    )
+
+
+if __name__ == "__main__":
+    # main() configures logging from --log-level, so no logging setup is done here.
+    main()
diff --git a/vllm/kvprune_legacy_save/attention/fa_paged_bridge.py b/vllm/kvprune_legacy_save/attention/fa_paged_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf4e4bf2c1aa6a18b4b1c9f5ec359b87ff47b1a
--- /dev/null
+++ b/vllm/kvprune_legacy_save/attention/fa_paged_bridge.py
@@ -0,0 +1,192 @@
+# SPDX-License-Identifier: Apache-2.0
+"""FlashAttention paths over compactor paged KV (materialize + FA ops).
+
+Used when :class:`~vllm.kvprune.config.engine_config.KvpruneAttentionSchedule`
+selects FlashAttention for prefill and/or decode while KV **writes** remain on
+Triton (``prefill_store_*``, ``decode_store_kv``). Matches the reference checks
+in ``vllm/compactor-vllm/tests/test_triton_attention.py``.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+
+if TYPE_CHECKING:
+ pass
+
+
+def materialize_kv_for_flash_prefill(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    L_cache_per_b: torch.Tensor,
+    k_append: torch.Tensor,
+    v_append: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    H_kv: int,
+    PAGE_SIZE: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Build packed K/V for :func:`flash_attn_varlen_func` (cache prefix + append).
+
+    For each sequence ``b`` the output holds, back to back, the
+    ``L_cache_per_b[b]`` cached rows gathered through
+    ``page_table[batch_mapping[b], head, token // PAGE_SIZE]`` followed by that
+    sequence's slice of ``k_append`` / ``v_append`` as delimited by
+    ``cu_seqlens_q``.
+
+    Lengths and offsets are read from the device once up front and the gather
+    is one indexed copy per sequence, instead of a device read per token.
+
+    Args:
+        k_cache, v_cache: global caches indexed by physical row, ``[CACHE_SIZE, D]``.
+        page_table: ``[rows, heads, logical_pages]`` physical page ids.
+        batch_mapping: local batch index -> row of ``page_table``.
+        L_cache_per_b: cached token count per sequence.
+        k_append, v_append: new tokens, ``[N, H_kv, D]``.
+        cu_seqlens_q: cumulative append lengths, ``B + 1`` entries.
+        H_kv: number of KV heads (must equal ``k_append.shape[1]``).
+        PAGE_SIZE: tokens per physical page.
+
+    Returns:
+        ``(K_total, V_total, cu_seqlens_k)`` with K/V of shape
+        ``[sum(cache + append), H_kv, D]`` in the cache dtype and int32
+        cumulative key lengths of ``B + 1`` entries.
+    """
+    device = k_cache.device
+    dtype = k_cache.dtype
+    B = cu_seqlens_q.numel() - 1
+    N, H_kv_raw, D = k_append.shape
+    assert H_kv_raw == H_kv
+
+    # One host transfer per metadata tensor rather than one per element.
+    cache_lens = [int(n) for n in L_cache_per_b.tolist()]
+    q_bounds = [int(n) for n in cu_seqlens_q.tolist()]
+    table_rows = batch_mapping.tolist()
+
+    k_bounds = [0]
+    for b in range(B):
+        appended = q_bounds[b + 1] - q_bounds[b]
+        k_bounds.append(k_bounds[-1] + cache_lens[b] + appended)
+    cu_seqlens_k = torch.tensor(k_bounds, device=device, dtype=torch.int32)
+
+    total_k = k_bounds[-1]
+    K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
+    V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
+
+    for b in range(B):
+        Lc = cache_lens[b]
+        q_start = q_bounds[b]
+        La = q_bounds[b + 1] - q_start
+        dst = k_bounds[b]
+
+        if Lc > 0:
+            pos = torch.arange(Lc, device=page_table.device)
+            # [H_kv, Lc] physical page of every cached token, per head.
+            pages = page_table[int(table_rows[b]), :H_kv][:, pos // PAGE_SIZE]
+            rows = (pages.to(torch.int64) * PAGE_SIZE + pos % PAGE_SIZE).to(device)
+            # cache[rows] is [H_kv, Lc, D]; the packed layout is token-major.
+            K_total[dst : dst + Lc] = k_cache[rows].transpose(0, 1)
+            V_total[dst : dst + Lc] = v_cache[rows].transpose(0, 1)
+
+        if La > 0:
+            K_total[dst + Lc : dst + Lc + La] = k_append[q_start : q_start + La]
+            V_total[dst + Lc : dst + Lc + La] = v_append[q_start : q_start + La]
+
+    return K_total, V_total, cu_seqlens_k
+
+
+def flash_prefill_from_paged(
+    q: torch.Tensor,
+    k_append: torch.Tensor,
+    v_append: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    *,
+    seq_lens_bh_before: torch.Tensor,
+    global_page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    max_seqlen_q: int,
+    PAGE_SIZE: int,
+    HKV: int,
+    sm_scale: float | None,
+) -> torch.Tensor:
+    """Run causal varlen FlashAttention over (paged cache prefix + appended tokens).
+
+    The cached prefix length of each sequence is the maximum of
+    ``seq_lens_bh_before`` along dim 1 (presumably the KV-head axis; confirm
+    against the caller). The prefix and the appended K/V are packed by
+    :func:`materialize_kv_for_flash_prefill` and passed to
+    ``flash_attn_varlen_func`` with ``causal=True``. ``sm_scale`` is forwarded
+    unchanged as ``softmax_scale`` (including ``None``).
+    """
+    cached_len_per_seq = seq_lens_bh_before.max(dim=1).values.to(torch.int32)
+    packed_k, packed_v, cu_seqlens_k = materialize_kv_for_flash_prefill(
+        k_cache,
+        v_cache,
+        global_page_table,
+        batch_mapping,
+        cached_len_per_seq,
+        k_append,
+        v_append,
+        cu_seqlens_q,
+        HKV,
+        PAGE_SIZE,
+    )
+    key_lengths = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
+    longest_key_seq = int(key_lengths.max().item())
+    attn_out = flash_attn_varlen_func(
+        q,
+        packed_k,
+        packed_v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=longest_key_seq,
+        softmax_scale=sm_scale,
+        causal=True,
+    )
+    return attn_out
+
+
+def materialize_kv_cache_for_flash_decode(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    L_cache_per_b: torch.Tensor,
+    H_kv: int,
+    PAGE_SIZE: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Dense ``[B, S, H_kv, D]`` cache for :func:`flash_attn_func` decode.
+
+    ``S`` is the largest entry of ``L_cache_per_b`` (0 for an empty batch).
+    For sequence ``b`` the first ``L_cache_per_b[b]`` positions are gathered
+    through ``page_table[batch_mapping[b], head, token // PAGE_SIZE]``; the
+    remaining positions are zero padding.
+
+    Lengths are read from the device once and each sequence is gathered with
+    one indexed copy instead of a device read per token.
+    """
+    device = k_cache.device
+    dtype = k_cache.dtype
+    cache_lens = [int(n) for n in L_cache_per_b.tolist()]
+    table_rows = batch_mapping.tolist()
+    B = len(cache_lens)
+    D = k_cache.shape[1]
+
+    seqlen_cache_max = max(cache_lens, default=0)
+    K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
+    V_flash = torch.zeros_like(K_flash)
+
+    for b, Lc in enumerate(cache_lens):
+        if Lc == 0:
+            continue
+        pos = torch.arange(Lc, device=page_table.device)
+        # [H_kv, Lc] physical page of every cached token, per head.
+        pages = page_table[int(table_rows[b]), :H_kv][:, pos // PAGE_SIZE]
+        rows = (pages.to(torch.int64) * PAGE_SIZE + pos % PAGE_SIZE).to(device)
+        # cache[rows] is [H_kv, Lc, D]; the dense layout is token-major.
+        K_flash[b, :Lc] = k_cache[rows].transpose(0, 1)
+        V_flash[b, :Lc] = v_cache[rows].transpose(0, 1)
+
+    return K_flash, V_flash
+
+
+def flash_decode_from_paged(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    *,
+    seq_lens_bh: torch.Tensor,
+    global_page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    PAGE_SIZE: int,
+    HKV: int,
+    sm_scale: float | None,
+) -> torch.Tensor:
+    """Decode step via FA: ``decode_store_kv`` has already appended the new K/V row.
+
+    ``q`` is ``[B, HQ, D]`` (one query token per sequence) and the result has
+    the same shape. The cache length of each sequence is the maximum of
+    ``seq_lens_bh`` along dim 1 (presumably the KV-head axis; confirm against
+    the caller).
+
+    The dense cache from :func:`materialize_kv_cache_for_flash_decode` is
+    zero-padded up to the longest sequence. Attending over it directly would
+    let the padded all-zero keys take softmax weight for every shorter
+    sequence, so only the valid rows are packed and attention runs through
+    ``flash_attn_varlen_func`` with per-sequence key lengths. When all
+    sequences have the same length this is the same computation as dense
+    attention over the full cache.
+
+    If ``sm_scale`` is None, ``1 / sqrt(D)`` is used.
+    """
+    L_cache_per_b = seq_lens_bh.max(dim=1).values.to(torch.int32)
+    K_flash, V_flash = materialize_kv_cache_for_flash_decode(
+        k_cache,
+        v_cache,
+        global_page_table,
+        batch_mapping,
+        L_cache_per_b,
+        HKV,
+        PAGE_SIZE,
+    )
+    B, HQ, D = q.shape
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    dev = K_flash.device
+    max_len = K_flash.shape[1]
+    lens = L_cache_per_b.to(dev)
+    # [B, S] mask of real (non-padding) key positions; boolean indexing keeps
+    # batch-major order, matching the cumulative lengths below.
+    valid = torch.arange(max_len, device=dev)[None, :] < lens[:, None]
+    K_packed = K_flash[valid]
+    V_packed = V_flash[valid]
+
+    cu_seqlens_k = torch.zeros(B + 1, device=dev, dtype=torch.int32)
+    cu_seqlens_k[1:] = torch.cumsum(lens, dim=0)
+    # Exactly one query token per sequence.
+    cu_seqlens_q = torch.arange(B + 1, device=dev, dtype=torch.int32)
+
+    # One query position attends to all valid keys of its sequence (no causal mask).
+    return flash_attn_varlen_func(
+        q,
+        K_packed,
+        V_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=1,
+        max_seqlen_k=max_len,
+        softmax_scale=sm_scale,
+        causal=False,
+    )
diff --git a/vllm/kvprune_legacy_save/attention/sparse_decode_kernel.py b/vllm/kvprune_legacy_save/attention/sparse_decode_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..77e2a6abd4965bb0b49bed80864d5a14d2ada344
--- /dev/null
+++ b/vllm/kvprune_legacy_save/attention/sparse_decode_kernel.py
@@ -0,0 +1,401 @@
+import functools
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.kvprune.utils.triton_compat import (
+ autotune as triton_autotune,
+ maybe_set_allocator,
+)
+
+
+def head_sparse_decode_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seq_lens_bh: torch.Tensor,
+    global_page_table: torch.Tensor,
+    batch_mapping: torch.Tensor,
+    HKV: int,
+    PAGE_SIZE: int,
+    sm_scale: float = None,
+    key_split: int = None,
+):
+    """
+    Decode-time head-sparse attention over a paged KV cache.
+
+    This is a wrapper around the Triton decode kernel used during incremental
+    generation. For each batch, we read the cached keys and values from a
+    global paged KV buffer, attend with one new query token, and return the
+    attention output.
+
+    The KV cache is stored in a single global K/V tensor of shape
+    ``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
+    (batch, kv_head, token_idx) is mapped to a physical row in the cache by:
+
+    1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
+    2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.
+
+    Grouped-query attention (GQA / MQA) is supported by passing more query
+    heads than KV heads (``HQ`` must be a multiple of ``HKV``).
+
+    Args:
+        :param q: Query tensor of shape ``[B, HQ, D]`` or ``[B, HQ, 1, D]``
+            (the length-1 query axis is squeezed), containing the new decode
+            token for each sequence in the launch batch. Must be contiguous.
+        :param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
+            backing buffer for all (batch, head) KV pages.
+        :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
+        :param seq_lens_bh: Tensor of shape ``[B, HKV]`` (converted to int32)
+            giving, for each local batch index and KV head, the number of
+            valid cached tokens in the paged KV cache.
+        :param global_page_table: Tensor of shape
+            ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
+            ``(true_batch_idx, kv_head, logical_page)`` to a physical page id
+            in the global cache.
+        :param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
+            index used by this call to the true batch row used to index
+            ``global_page_table``.
+        :param HKV: Number of KV heads.
+        :param PAGE_SIZE: Number of tokens stored per physical KV page; must be
+            divisible by 32.
+        :param sm_scale: Optional scaling factor applied to the attention logits
+            before softmax. If ``None``, ``1 / sqrt(D)`` is used.
+        :param key_split: Optional number of splits along the key sequence length.
+            If > 1, partial results are computed per split and reduced in a
+            second kernel. If ``None`` or 0, a heuristic is used.
+
+    Returns:
+        :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
+            device and dtype as ``q``.
+    """
+
+    with torch.cuda.device(q.device):
+        if q.ndim == 4:
+            B, HQ, S, D = q.shape
+            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
+            q = q.squeeze(-2)
+        else:
+            assert q.ndim == 3
+            B, HQ, D = q.shape
+
+        CACHE_SIZE = k.shape[0]
+        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
+        GROUP_M = HQ // HKV
+        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"
+
+        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
+        # Kernel compute dtype, chosen once for both stages. Any dtype other
+        # than float8_e5m2 / bfloat16 is run as float16.
+        if FP8:
+            kernel_dtype = tl.float8e5
+        elif q.dtype == torch.bfloat16:
+            kernel_dtype = tl.bfloat16
+        else:
+            kernel_dtype = tl.float16
+
+        seq_lens_bh = seq_lens_bh.to(torch.int32)
+        assert B <= 32767, "too many batches"
+        assert global_page_table.shape[1] == HKV
+        assert q.is_contiguous()
+        assert (D & (D - 1)) == 0, "D must be a power of 2"
+        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
+
+        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale
+        if not key_split:
+            # None or 0: pick the split count heuristically. The max length is
+            # bucketed to the next power of two above it so the cached
+            # heuristic sees few distinct inputs.
+            key_split = num_splits_heuristic(
+                B * HKV,
+                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
+                num_sms=torch.cuda.get_device_properties(
+                    q.device
+                ).multi_processor_count,
+                max_splits=12,
+            )
+
+        maybe_set_allocator(
+            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
+        )
+
+        # stage 1 scratch
+        # NOTE(review): these buffers are uninitialised and the stage-1 kernel
+        # returns early for (b, h) pairs whose length is 0, so those rows are
+        # never written. Confirm callers never pass zero-length heads, or that
+        # their output rows are ignored.
+        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
+        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)
+        # processes all queries for a KV head together
+        # pointers are lowercase, CONSTANTS are upper
+        grid1 = (B, HKV, key_split)
+        _varkv_stage1_groupM[grid1](
+            q=q,
+            k=k,
+            v=v,
+            mid_o=mid_o,
+            mid_lse=mid_lse,
+            page_table_bhl=global_page_table,
+            batch_mapping=batch_mapping,
+            seq_lens_bh=seq_lens_bh.contiguous(),
+            SM_SCALE=sm_scale,
+            B=B,
+            HKV=HKV,
+            HQ=HQ,
+            CACHE_SIZE=CACHE_SIZE,
+            STRIDE_LBS=mid_lse.stride(0),
+            STRIDE_LS=mid_lse.stride(1),
+            STRIDE_LH=mid_lse.stride(2),
+            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
+            D=D,
+            KEY_SPLIT=key_split,
+            GROUP_M=GROUP_M,
+            DTYPE=kernel_dtype,
+            PAGE_SIZE=PAGE_SIZE,
+        )
+
+        if key_split == 1:
+            # A single split already holds the final result; no reduction needed.
+            return mid_o.squeeze(1).contiguous()
+
+        # reduce partial results across splits
+        output = torch.empty_like(q)
+        grid2 = (B, HQ)
+        _varkv_stage2_reduce[grid2](
+            mid_o=mid_o,
+            mid_lse=mid_lse,
+            output=output,
+            STRIDE_LBS=mid_lse.stride(0),
+            STRIDE_LS=mid_lse.stride(1),
+            STRIDE_LH=mid_lse.stride(2),
+            STRIDE_OBS=output.stride(0),
+            STRIDE_OH=output.stride(1),
+            B=B,
+            HQ=HQ,
+            D=D,  # type: ignore
+            KEY_SPLIT=key_split,  # type: ignore
+            DTYPE=kernel_dtype,
+        )
+        return output
+
+
+# similar to flash attention split heuristic
+@functools.lru_cache(maxsize=128)
+def num_splits_heuristic(
+    total_mblocks: int,
+    max_seq_len: int,
+    num_sms: int,
+    max_splits: int,
+) -> int:
+    """Choose how many key-axis splits to launch for decode attention.
+
+    Returns 1 when ``total_mblocks`` already reaches 80% of ``num_sms`` or
+    when ``max_seq_len <= 1024``. Otherwise candidate split counts
+    ``1..min(max_splits, num_sms)`` are scored, stopping before any count
+    that would leave 512 or fewer keys per split. The score is wave
+    efficiency: ``waves / ceil(waves)`` with
+    ``waves = total_mblocks * splits / num_sms``. The smallest split count
+    scoring at least 75% of the best score is returned.
+    """
+    sms_nearly_full = total_mblocks >= 0.8 * num_sms
+    if sms_nearly_full or max_seq_len <= 1024:
+        return 1
+
+    efficiencies = []
+    for splits in range(1, min(max_splits, num_sms) + 1):
+        if (max_seq_len / splits) <= 512:
+            break
+        waves = float(total_mblocks * splits) / float(num_sms)
+        efficiencies.append(waves / math.ceil(waves) if waves > 0 else 0.0)
+
+    # Scores are never negative, so an empty list behaves like a best score of 0.
+    cutoff = 0.75 * max(efficiencies, default=0.0)
+    good_enough = (
+        splits
+        for splits, score in enumerate(efficiencies, start=1)
+        if score >= cutoff
+    )
+    return next(good_enough, 1)
+
+
+def prune_invalid_configs(configs, _, **kwargs):
+    """Keep only the configs whose ``BLOCK_N`` (default 0) does not exceed ``PAGE_SIZE``.
+
+    ``PAGE_SIZE`` is read from the keyword arguments; the second positional
+    argument is ignored.
+    """
+    page_limit = kwargs["PAGE_SIZE"]
+
+    def fits_in_page(cfg):
+        return cfg.kwargs.get("BLOCK_N", 0) <= page_limit
+
+    return [cfg for cfg in configs if fits_in_page(cfg)]
+
+
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
+ num_warps=w,
+ num_stages=s,
+ )
+ for BLOCK_N in [32, 64, 128]
+ for MIN_BLOCK_KV in [8]
+ for s in [2, 3, 4]
+ for w in [4, 8]
+ for ws in [True, False]
+ ],
+ key=[
+ "HKV",
+ "GROUP_M",
+ "D",
+ "PAGE_SIZE", # "B"
+ ],
+ cache_results=True,
+ prune_configs_by={"early_config_prune": prune_invalid_configs},
+)
+@triton.jit
+def _varkv_stage1_groupM(
+ q, # [B, HQ, D] contiguous
+ k, # GLOBAL cache: [CACHE_SIZE, D], contiguous
+ v, # GLOBAL cache: [CACHE_SIZE, D], contiguous
+ mid_o,
+ mid_lse,
+ page_table_bhl, # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
+ batch_mapping, # int32 [B] maps local pid_b -> true batch index
+ seq_lens_bh, # int32 [B*H_kv] valid tokens per (b,h)
+ SM_SCALE,
+ B,
+ HKV,
+ HQ,
+ CACHE_SIZE, # CACHE_SIZE = N_PAGES * PAGE_SIZE
+ STRIDE_LBS,
+ STRIDE_LS,
+ STRIDE_LH,
+ # constexprs
+ N_LOGICAL_PAGES_MAX: tl.constexpr, # page table width per (b,h)
+ D: tl.constexpr,
+ KEY_SPLIT: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ DTYPE: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ MIN_BLOCK_KV: tl.constexpr,
+ WARPSPEC: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+):
+ pid_b = tl.program_id(0) # batch
+ pid_kvh = tl.program_id(1) # kv head
+ pid_s = tl.program_id(2) # split
+
+ # valid length L for this (b,h)
+ bh_stride = HKV
+ L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
+ if L == 0:
+ return
+
+ tl.assume(L > 0)
+
+ # split sizing on logical token axis [0..L)
+ base = tl.cdiv(L, KEY_SPLIT)
+ per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
+ split_start = pid_s * per_split_len
+ split_end = tl.minimum(split_start + per_split_len, L)
+
+ # query heads mapped to this kv head
+ base_qh = pid_kvh * GROUP_M
+ GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
+ offs_m = tl.arange(0, GROUP_M_PAD)
+ mask_m = offs_m < GROUP_M
+ offs_d = tl.arange(0, D)
+
+ # load Q tile [M, D]
+ q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
+ q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE) # [M, D]
+
+ # streaming softmax state per query
+ e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
+ e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
+ acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)
+
+ if split_end > split_start:
+ # logical pages covering [split_start, split_end)
+ lp0 = split_start // PAGE_SIZE
+ lp1 = tl.cdiv(split_end, PAGE_SIZE) # exclusive
+
+ mapped_b = tl.load(batch_mapping + pid_b)
+ tl.assume(mapped_b >= 0)
+ # page table base for this (b,h)
+ pt_stride = N_LOGICAL_PAGES_MAX
+ pt_base = (mapped_b * HKV + pid_kvh) * pt_stride
+
+ for lp in tl.range(lp0, lp1):
+ phys = tl.load(
+ page_table_bhl + pt_base + lp, cache_modifier=".cg"
+ ) # physical page id
+ # bounds within the logical page
+ local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
+ local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)
+
+ page_base = phys * PAGE_SIZE
+ page_base = tl.multiple_of(page_base, BLOCK_N)
+ for s in tl.range(local_start, local_end, BLOCK_N):
+ s = tl.multiple_of(s, MIN_BLOCK_KV)
+ offs_bn = tl.arange(0, BLOCK_N)
+ key_idx = page_base + s + offs_bn
+ k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
+ k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
+ qk = tl.dot(q, k_blk.T) * SM_SCALE # [M, BN]
+
+ offs_n = s + tl.arange(0, BLOCK_N)
+ mask_n = offs_n < local_end
+ qk = tl.where(mask_n[None, :], qk, -float("inf"))
+
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max) # [M]
+ re_scale = tl.exp(e_max - n_e_max) # [M]
+ acc = acc * re_scale[:, None] # [M, D]
+ v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
+ v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
+ p = tl.exp(qk - n_e_max[:, None]) # [M, BN]
+ acc = tl.dot(p.to(DTYPE), v_blk, acc)
+
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ # write mid outputs [M, D] for this split
+ tmp = (acc / e_sum[:, None]).to(DTYPE)
+ row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
+ mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
+ tl.store(mid_ptrs, tmp, mask=mask_m[:, None])
+
+ ml_ptrs = (
+ mid_lse
+ + pid_b * STRIDE_LBS
+ + pid_s * STRIDE_LS
+ + (base_qh + offs_m) * STRIDE_LH
+ )
+ safe_sum = tl.where(mask_m, e_sum, 1.0)
+ tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
+ else:
+ # empty split
+ zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
+ row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
+ mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
+ tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
+ ml_ptrs = (
+ mid_lse
+ + pid_b * STRIDE_LBS
+ + pid_s * STRIDE_LS
+ + (base_qh + offs_m) * STRIDE_LH
+ )
+ tl.store(ml_ptrs, -float("inf"), mask=mask_m)
+
+
@triton.jit
def _varkv_stage2_reduce(
    mid_o,  # per-split partial outputs; read at row (pid_b, split, head) with D contiguous
    mid_lse,  # per-split log-sum-exp values, addressed through STRIDE_LBS / STRIDE_LS / STRIDE_LH
    output,  # final result, addressed through STRIDE_OBS / STRIDE_OH with D contiguous
    STRIDE_LBS,
    STRIDE_LS,
    STRIDE_LH,
    STRIDE_OBS,
    STRIDE_OH,
    B,
    HQ,
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    DTYPE: tl.constexpr,
):
    # Stage 2 of the split-KV decode: one program per (program_id(0), program_id(1))
    # pair merges the KEY_SPLIT partial results of that pair into a single output
    # row, using a running-max log-sum-exp combination.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    offs_d = tl.arange(0, D)

    # across split LSE combine
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([D], dtype=tl.float32)

    for s in tl.range(KEY_SPLIT):
        # mid_o is indexed as a flat [.., KEY_SPLIT, HQ] row table with D elements per row.
        row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
        tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
        tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
        tlogic = tl.load(tl_ptr)

        # Rescale the running accumulator to the new maximum before adding this
        # split, weighted by exp(lse_split - max).
        # NOTE(review): if the first LSE(s) read are -inf, exp(e_max - n_e_max)
        # evaluates exp(-inf - (-inf)) which is NaN — confirm callers never
        # reach this kernel with an all-empty (b, h).
        n_e_max = tl.maximum(e_max, tlogic)
        old_scale = tl.exp(e_max - n_e_max)
        acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
        e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
        e_max = n_e_max

    # Normalise by the combined weight and write the D-vector for this (b, h).
    o = (acc / e_sum).to(DTYPE)
    o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
    tl.store(o_ptr, o)
diff --git a/vllm/kvprune_legacy_save/attention/sparse_varlen_kernel.py b/vllm/kvprune_legacy_save/attention/sparse_varlen_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec1af5cc19d45ef6dca36d3a2f7bdab36a462164
--- /dev/null
+++ b/vllm/kvprune_legacy_save/attention/sparse_varlen_kernel.py
@@ -0,0 +1,455 @@
+import logging
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.kvprune.utils.triton_compat import (
+ autotune as triton_autotune,
+ cuda_capability_geq,
+ maybe_set_allocator,
+)
+
+logger = logging.getLogger(__name__)
+
+
def causal_sparse_varlen_with_cache(
    q,
    k,
    v,
    k_cache,
    v_cache,
    seq_lens_bh,
    global_page_table,
    batch_mapping,
    cu_seqlens_q,
    max_seqlen_q: int,
    max_seqlen_k_cache: int,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale=None,
):
    """
    Causal prefill attention over a paged KV cache plus a block of newly
    appended tokens in a packed batch format.

    This function wraps the Triton kernel
    ``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
    a batch of variable-length sequences, where:
        • Past keys/values are stored in a paged global KV cache
          (``k_cache``, ``v_cache``) and indexed via ``global_page_table``.

        • New tokens for this step are given as K/V blocks (``k``, ``v``)
          together with a packed query block ``q``.

    Grouped-query attention (GQA / MQA) is supported: ``HQ`` must be divisible
    by ``HKV``.

    Args:
        q: Packed queries, ``[N, HQ, D]``; ``D`` must be a power of two.
        k, v: Newly appended keys/values, viewable as ``[N, HKV, D]``.
        k_cache, v_cache: Global caches, ``[CACHE_SIZE, D]`` with
            ``CACHE_SIZE`` a multiple of ``PAGE_SIZE``.
        seq_lens_bh: Cached length per (batch, kv head); cast to int32 and
            read by the kernel at ``b * HKV + head``.
        global_page_table: Page table whose last dimension is the number of
            logical pages per (batch, kv head). It is passed to the kernel
            as-is (assumed contiguous — TODO confirm at call sites).
        batch_mapping: Maps the local batch index to the page-table batch row.
        cu_seqlens_q: ``[B + 1]`` cumulative query lengths.
        max_seqlen_q: Upper bound on per-sequence query length; sizes the grid
            and feeds the autotune key.
        max_seqlen_k_cache: Upper bound on cached length; only used (clamped
            to at least 1) for the autotune key.
        HKV: Number of KV heads.
        PAGE_SIZE: Tokens per cache page.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(D)``.

    Returns:
        Attention output with the same shape and dtype as ``q`` (``[N, HQ, D]``).
    """
    assert q.ndim == 3, "q should be [N, HQ, D]"
    N, HQ, D = q.shape
    assert (D & (D - 1)) == 0, "D must be power of two"

    B = cu_seqlens_q.numel() - 1
    assert B > 0
    assert HQ % HKV == 0, (
        "Number of query heads must be divisible by the number of key/value heads"
    )
    H_g = HQ // HKV
    # view Q as [HKV, N, QUERY_GROUP_SIZE, D]
    out = torch.empty_like(q)
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    out = out.view(N, HKV, H_g, D).permute(1, 0, 2, 3)

    # K_app/V_app: [N, HKV, D] -> [HKV, N, D]
    k_app = k.view(N, HKV, D).permute(1, 0, 2)
    v_app = v.view(N, HKV, D).permute(1, 0, 2)

    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
    seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
    # int32, not int16: a 16-bit mapping wraps for batch rows >= 32768 and the
    # kernel multiplies the loaded value by HKV when forming the page-table
    # offset, which would overflow even sooner in 16-bit arithmetic.
    batch_mapping = batch_mapping.to(dtype=torch.int32, device=q.device)

    N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
    CACHE_SIZE = k_cache.shape[0]
    assert v_cache.shape[0] == CACHE_SIZE
    assert k_cache.shape[1] == D and v_cache.shape[1] == D
    assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # strides for Q [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
    # [G, N, D]
    STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
    STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()

    # OUT [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()
    # launch grid
    maybe_set_allocator(
        lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
    )
    # The kernel addresses the last dimension with unit stride (offs_d is added
    # without a stride factor), so it must be contiguous for all four tensors.
    assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
        "final dimension must be contiguous"
    )

    def grid(META):
        # One program per (kv head, batch element, BLOCK_M-sized query tile).
        return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])

    # On a fresh batch, max_seqlen_k_cache==0 (no KV prefix yet). Passing
    # `triton.next_power_of_2(0)` into autotune constexpr keys breaks
    # kernel selection / tuning and can yield garbage outputs.
    _k_max_autotune = max(int(max_seqlen_k_cache), 1)
    AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
    AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)
    _causal_head_sparse_varlen_with_cache[grid](
        Q=q,
        K_cache=k_cache,
        V_cache=v_cache,
        K_app=k_app,
        V_app=v_app,
        cu_seqlens_qk=cu_seqlens_q,
        seq_lens_bh=seq_lens_bh,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        OUT=out,
        HKV=HKV,
        QUERY_GROUP_SIZE=H_g,
        PAGE_SIZE=PAGE_SIZE,
        N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
        STRIDE_Q_G=STRIDE_Q_G,
        STRIDE_Q_N=STRIDE_Q_N,
        STRIDE_Q_H=STRIDE_Q_H,
        STRIDE_KC=STRIDE_KC,
        STRIDE_VC=STRIDE_VC,
        STRIDE_KA_G=STRIDE_KA_G,
        STRIDE_KA_N=STRIDE_KA_N,
        STRIDE_VA_G=STRIDE_VA_G,
        STRIDE_VA_N=STRIDE_VA_N,
        STRIDE_OUT_G=STRIDE_OUT_G,
        STRIDE_OUT_N=STRIDE_OUT_N,
        STRIDE_OUT_H=STRIDE_OUT_H,
        sm_scale=sm_scale,
        D=D,
        AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
        AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
    )
    return out.permute(1, 0, 2, 3).view(N, HQ, D)  # already contiguous
+
+
# Hand-picked launch configurations; ``get_autotune_configs`` returns this list
# when ``cuda_capability_geq(9, 0)`` holds.
autotune_configs_cc9 = [
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 64, "WARPSPEC": True}, num_warps=16, num_stages=3
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 64, "WARPSPEC": True}, num_warps=8, num_stages=3
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 32, "WARPSPEC": True}, num_warps=8, num_stages=4
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 32, "WARPSPEC": True}, num_warps=8, num_stages=3
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 32, "WARPSPEC": False}, num_warps=4, num_stages=3
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 16, "WARPSPEC": True}, num_warps=8, num_stages=3
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 16, "WARPSPEC": True}, num_warps=8, num_stages=4
    ),
    triton.Config(
        {"BLOCK_N": 64, "BLOCK_M": 16, "WARPSPEC": False}, num_warps=4, num_stages=4
    ),
    triton.Config(
        {"BLOCK_N": 32, "BLOCK_M": 32, "WARPSPEC": True}, num_warps=8, num_stages=4
    ),
    triton.Config(
        {"BLOCK_N": 32, "BLOCK_M": 32, "WARPSPEC": False}, num_warps=8, num_stages=4
    ),
    triton.Config(
        {"BLOCK_N": 32, "BLOCK_M": 16, "WARPSPEC": False}, num_warps=8, num_stages=3
    ),
    triton.Config(
        {"BLOCK_N": 32, "BLOCK_M": 16, "WARPSPEC": False}, num_warps=4, num_stages=4
    ),
]

# Fallback list used when ``cuda_capability_geq(9, 0)`` is false: the cross
# product of BLOCK_N in {16, 32}, BLOCK_M = 64, num_warps in {4, 8} and
# num_stages in {2, 3}, always with WARPSPEC=True.
autotune_configs_cc8 = [
    triton.Config(
        {"BLOCK_N": BN, "BLOCK_M": BM, "WARPSPEC": True}, num_warps=w, num_stages=s
    )
    for BN in [16, 32]
    for BM in [64]
    for w in [4, 8]
    for s in [2, 3]
]
+
+
def prune_invalid_configs(configs, _, **kwargs):
    """Drop autotune configs that combine ``BLOCK_N == 32`` with ``num_stages == 4``.

    The signature matches Triton's ``early_config_prune`` hook
    (``configs, named_args, **kwargs``); the second positional argument and
    the keyword arguments are ignored.

    ``triton.Config`` stores ``num_stages`` as an attribute rather than inside
    ``Config.kwargs`` (which only holds meta-parameters such as ``BLOCK_N``),
    so the attribute is consulted first. The ``kwargs`` lookup is kept as a
    fallback for config-like objects that carry it as a meta-parameter.

    Returns:
        A new list with the surviving configs, in their original order.
    """

    def _num_stages(conf):
        stages = getattr(conf, "num_stages", None)
        if stages is None:
            stages = conf.kwargs.get("num_stages")
        return stages

    return [
        conf
        for conf in configs
        if not (conf.kwargs.get("BLOCK_N") == 32 and _num_stages(conf) == 4)
    ]
+
+
def get_autotune_configs():
    """Pick the autotune config list for the current device.

    Returns ``autotune_configs_cc9`` when ``cuda_capability_geq(9, 0)`` is
    truthy and ``autotune_configs_cc8`` otherwise.
    """
    use_cc9 = cuda_capability_geq(9, 0)
    return autotune_configs_cc9 if use_cc9 else autotune_configs_cc8
+
+
@triton_autotune(
    configs=get_autotune_configs(),
    key=[
        "HKV",
        "QUERY_GROUP_SIZE",
        "D",
        "PAGE_SIZE",
        "AUTOTUNE_MAX_K_LEN",
        "AUTOTUNE_MAX_Q_LEN",
    ],
    cache_results=True,
)
@triton.jit
def _causal_head_sparse_varlen_with_cache(
    Q,  # [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
    K_cache,
    V_cache,  # [CACHE_SIZE, D]
    K_app,
    V_app,  # [HKV, N, D]
    cu_seqlens_qk,  # [B+1]
    seq_lens_bh,  # [B, HKV]
    page_table,  # [B_total, HKV, N_LOGICAL_PAGES_MAX]
    batch_mapping,  # [B], maps local b -> global batch index
    OUT,  # [HKV, N, QUERY_GROUP_SIZE, D]
    #
    HKV: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    N_LOGICAL_PAGES_MAX,
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_KC,
    STRIDE_VC,
    STRIDE_KA_G,
    STRIDE_KA_N,
    STRIDE_VA_G,
    STRIDE_VA_N,
    STRIDE_OUT_G,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    WARPSPEC: tl.constexpr,
    AUTOTUNE_MAX_Q_LEN: tl.constexpr,  # used for autotune key
    AUTOTUNE_MAX_K_LEN: tl.constexpr,  # used for autotune key
):
    # Causal attention for one (kv head, batch element, query tile). Every query
    # head of the kv-head group is processed together: the tile holds
    # BLOCK_M * QUERY_GROUP_SIZE rows. Keys come from two sources, visited in
    # order: (1) the paged cache, (2) the tokens appended in this step, with a
    # causal mask applied only where a key index can exceed a query index.
    # A streaming (online) softmax in base 2 keeps running max / sum per row.
    TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    pid_g = tl.program_id(0)  # kv_head id in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # query-tile id within batch

    # batch segment [qb, qe) in N
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
    seq_len_append = off_b1 - off_b

    q_start = off_b + pid_m * BLOCK_M
    q_end = tl.minimum(q_start + BLOCK_M, off_b1)
    # number of queries in this tile for this batch
    M = q_end - q_start
    # Tiles past the end of this sequence have nothing to do.
    if M <= 0:
        return

    # cached length for (b, kv_head=pid_g)
    L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)
    # row indices flattened over [QUERY_GROUP_SIZE, BLOCK_M]
    offs_row = tl.arange(0, TOTAL_N_QUERIES)
    row_m = offs_row % BLOCK_M
    row_h = offs_row // BLOCK_M
    # valid rows: only those with row_m < M
    row_mask = row_m < M

    # global query index per row
    q_idx = q_start + row_m
    offs_d = tl.arange(0, D)
    # Q tile: [TOTAL_N_QUERIES, D]
    # Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
    q_ptrs = (
        Q
        + pid_g * STRIDE_Q_G
        + q_idx[:, None] * STRIDE_Q_N
        + row_h[:, None] * STRIDE_Q_H
        + offs_d[None, :]
    )
    q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)

    # Streaming-softmax state per row: running max, running sum, weighted V sum.
    e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
    acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)

    offs_block_n = tl.arange(0, BLOCK_N)
    # 1.44269504 ~= log2(e): logits are pre-scaled so exp2 can replace exp below.
    qk_scale = sm_scale * 1.44269504

    # 1) attend over cached K/V
    if L_cache > 0:
        # map local (b) to global batch index
        mapped_b = tl.load(batch_mapping + pid_b)
        pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
        # iterate logical pages
        num_lp = tl.cdiv(L_cache, PAGE_SIZE)
        for lp in tl.range(0, num_lp):
            # can overflow in 32 bits so upcast
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            page_start = phys * PAGE_SIZE
            # how many valid tokens in this page for this (b,g)
            remain = L_cache - lp * PAGE_SIZE
            page_len = tl.minimum(PAGE_SIZE, remain)
            # iterate over this page in BLOCK_N chunks
            for ks in tl.range(0, page_len, BLOCK_N):
                offs_n = ks + offs_block_n
                mask_n = offs_n < page_len

                key_idx = page_start + offs_n
                k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]

                k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                qk = tl.dot(q, k.T) * qk_scale  # [TOTAL_N_QUERIES, BN]
                # Masked slots get a large finite negative value (not -inf).
                qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)

                # softmax update
                cur_max = tl.max(qk, 1)
                n_e_max = tl.maximum(e_max, cur_max)
                re_scale = tl.math.exp2(e_max - n_e_max)
                p = tl.math.exp2(qk - n_e_max[:, None])

                v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
                v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]

                acc = acc * re_scale[:, None]
                acc = tl.dot(p.to(v.dtype), v, acc)

                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max

    # 2) attend over appended K_app/V_app (causal)
    # appended tokens for batch b are in [off_b, off_b1)
    # query tile is [q_start, q_end)
    # for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
    if q_end > off_b:
        # exactly one appended token
        if seq_len_append == 1:
            # Single key/value vector: use an elementwise product + reduction
            # instead of tl.dot; no causal mask is needed for a single token.
            ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
            k = tl.load(ka_ptrs)  # [D]
            qk = tl.sum(q * k[None, :], 1) * qk_scale
            qk = tl.where(row_mask, qk, -1.0e6)
            n_e_max = tl.maximum(e_max, qk)
            re_scale = tl.math.exp2(e_max - n_e_max)
            p = tl.math.exp2(qk - n_e_max)
            va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
            v = tl.load(va_ptrs)  # [D]
            acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
            e_sum = e_sum * re_scale + p
        else:
            # off-band: k in [off_b, q_start)
            # for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
            # so no causal mask needed.
            off_band_start = off_b
            off_band_end = q_start

            if off_band_end > off_band_start:
                for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
                    offs_n = ks + offs_block_n
                    mask_n = offs_n < off_band_end

                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)

                    qk = tl.dot(q, k.T) * qk_scale
                    qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)

                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)

                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])

                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)

                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)

                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max

            # on-band remaining k
            # Keys in [q_start, q_end) overlap the query tile, so each row
            # additionally needs the causal mask key_index <= query_index.
            on_band_start = tl.maximum(q_start, off_b)
            if on_band_start < q_end:
                for ks in tl.range(on_band_start, q_end, BLOCK_N):
                    offs_n = ks + tl.arange(0, BLOCK_N)
                    mask_n = offs_n < q_end

                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )

                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)

                    qk = tl.dot(q, k.T) * qk_scale

                    caus_mask = offs_n[None, :] <= q_idx[:, None]
                    full_mask = row_mask[:, None] & mask_n[None, :] & caus_mask

                    qk = tl.where(full_mask, qk, -1.0e6)

                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])

                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)

                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)

                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max

    # 3) write outputs
    # Normalise by the softmax denominator; rows outside the tile are not stored.
    o = (acc / e_sum[:, None]).to(q.dtype)
    out_ptrs = (
        OUT
        + pid_g * STRIDE_OUT_G
        + q_idx[:, None] * STRIDE_OUT_N
        + row_h[:, None] * STRIDE_OUT_H
        + offs_d[None, :]
    )
    tl.store(out_ptrs, o, mask=row_mask[:, None])
diff --git a/vllm/kvprune_legacy_save/benchmark/__init__.py b/vllm/kvprune_legacy_save/benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f8699a480b2a9ab11562d2e9fcb7f546eb8f9b4
--- /dev/null
+++ b/vllm/kvprune_legacy_save/benchmark/__init__.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Benchmark helpers for kv-prune / compactor kernels.
+
+Upstream snapshot (``compactor-vllm/src/compactor_vllm/benchmark``) contained **only**
+an empty ``__init__.py`` — no additional ``.py`` scripts. Those files are merged here
+as-is; there is nothing else to list under that directory in upstream.
+
+Use :data:`BENCHMARK_REGISTRY` to register microbenchmarks or CLI entrypoints you
+add under ``vllm.kvprune.benchmark``.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
# File names that existed in upstream ``compactor_vllm/benchmark/``, given
# relative to that directory.
UPSTREAM_BENCHMARK_FILES: tuple[str, ...] = ("__init__.py",)

# Registry of benchmarks: name -> callable, or an import-path string such as
# "mymod:main". It starts empty and is filled through ``register_benchmark``.
BENCHMARK_REGISTRY: dict[str, Callable[..., Any] | str] = {}


def list_upstream_benchmark_files() -> tuple[str, ...]:
    """Expose :data:`UPSTREAM_BENCHMARK_FILES` (the upstream ``benchmark/`` file names)."""
    upstream = UPSTREAM_BENCHMARK_FILES
    return upstream


def register_benchmark(name: str, target: Callable[..., Any] | str) -> None:
    """Store ``target`` under ``name`` in :data:`BENCHMARK_REGISTRY`.

    ``target`` is either a callable or a ``"module:attr"`` import path. An
    existing entry with the same name is replaced.
    """
    BENCHMARK_REGISTRY.update({name: target})


def iter_registered_benchmarks() -> list[tuple[str, Callable[..., Any] | str]]:
    """Snapshot the registry as a list of ``(name, target)`` pairs, in insertion order."""
    return [(name, target) for name, target in BENCHMARK_REGISTRY.items()]


__all__ = [
    "BENCHMARK_REGISTRY",
    "UPSTREAM_BENCHMARK_FILES",
    "iter_registered_benchmarks",
    "list_upstream_benchmark_files",
    "register_benchmark",
]
diff --git a/vllm/kvprune_legacy_save/compactor_porting_status.py b/vllm/kvprune_legacy_save/compactor_porting_status.py
new file mode 100644
index 0000000000000000000000000000000000000000..92333df8a47040cdc972daba2784ba499d572761
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compactor_porting_status.py
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Layout notes: ``vllm/compactor-vllm/src/compactor_vllm`` (or sibling tree) →
+``vllm.kvprune.<subpackage>``.
+
+The upstream tree is merged into parallel subpackages under ``vllm/kvprune/``
+(``attention``, ``kv_cache``, ``compression``, ``config``, ``core``, ``layers``,
+``models``, ``triton_kernels``, ``utils``, ``benchmark``). Imports use
+``from vllm.kvprune.<subpackage> import ...``.
+
+v1 integration (FlashAttention, ``gpu_model_runner``) lives in
+``core.runtime``, ``core.flash_integration``, and ``compression/prefill.py``.
+
+**Note:** filenames with hyphens under ``compression/`` are not importable as
+Python modules; rename or load via ``importlib`` if needed.
+
+**TP / embedding in vLLM workers:** upstream compactor-vllm used only
+``vllm.kvprune`` ``ParallelLMHead`` + ``dist.gather``. When embedded in v1 workers,
+prefer ``delegate_kvprune_embed_tokens_to_vllm`` and
+``delegate_kvprune_compute_logits_to_vllm`` so token masking and logits match
+``vocab_parallel_embedding`` + ``LogitsProcessor`` (garbled text often came from
+TP gather / padded-vocab handling, not from the transformer body).
+"""
+
+from __future__ import annotations
+
+import pathlib
+
+
def kvprune_root() -> pathlib.Path:
    """Return the absolute, resolved directory that contains this module."""
    this_file = pathlib.Path(__file__).resolve()
    return this_file.parent


def list_py_files() -> list[str]:
    """List every ``.py`` file below :func:`kvprune_root`, sorted.

    Paths are relative to the root, use ``/`` as separator, and skip anything
    whose path contains a ``__pycache__`` component.
    """
    base = kvprune_root()
    found = []
    for path in base.rglob("*.py"):
        if "__pycache__" in path.parts:
            continue
        found.append(str(path.relative_to(base)).replace("\\", "/"))
    found.sort()
    return found


def format_layout_report() -> str:
    """Build a plain-text report: a header, the file count, then up to 250 paths.

    When more than 250 files exist, a final line states how many were omitted.
    """
    files = list_py_files()
    report = [
        "vllm.kvprune — merged compactor layout",
        f"python file count: {len(files)}",
        "=" * 50,
    ]
    report.extend(files[:250])
    if len(files) > 250:
        report.append(f"... and {len(files) - 250} more")
    return "\n".join(report)
diff --git a/vllm/kvprune_legacy_save/compression/__init__.py b/vllm/kvprune_legacy_save/compression/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aee2d4640f8b46955788ac25c6987b4868afde35
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/__init__.py
@@ -0,0 +1,41 @@
+from vllm.kvprune.compression.common import (
+ BaseCompressionMethod,
+ NoCompression,
+)
+from vllm.kvprune.compression.criticalkv import CriticalAdaKVCompression
+from vllm.kvprune.compression.compactor import CompactorCompression
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ CompressionMethod,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.compression.snapkv import SnapKVCompression
+
# Dispatch table from the configured ``CompressionMethod`` value to the class
# that implements its pre-/post-RoPE scoring; used by the ``apply_*`` helpers
# in this module.
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
    CompressionMethod.COMPACTOR: CompactorCompression,
    CompressionMethod.SNAPKV: SnapKVCompression,
    CompressionMethod.NONE: NoCompression,
}
+
+
def apply_prerope_compression(q, k, v, context):
    """Run the pre-RoPE scoring phase of the compression method selected in ``context``.

    The method is read from ``context.compression_context.compression_method``
    and looked up in ``COMPRESSION_REGISTRY``; the result of its
    ``pre_rope_scoring`` is returned unchanged.
    """
    selected = context.compression_context.compression_method
    compressor = COMPRESSION_REGISTRY[selected]
    return compressor.pre_rope_scoring(q, k, v, context=context)
+
+
def apply_postrope_compression(q, k, v, prerope_scores, context):
    """Run the post-RoPE scoring phase of the compression method selected in ``context``.

    ``prerope_scores`` is forwarded as-is to the method's
    ``post_rope_scoring`` so it can refine or pass through the earlier scores.
    """
    selected = context.compression_context.compression_method
    compressor = COMPRESSION_REGISTRY[selected]
    return compressor.post_rope_scoring(q, k, v, prerope_scores, context=context)
+
+
# Names exported by ``from <this package> import *``.
__all__ = [
    "apply_prerope_compression",
    "apply_postrope_compression",
    "CompressionMethod",
    "BatchCompressionParams",
    "SequenceCompressionParams",
    "COMPRESSION_REGISTRY"
]
diff --git a/vllm/kvprune_legacy_save/compression/common.py b/vllm/kvprune_legacy_save/compression/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b157fe0d3cf28a1b96b0401d8e06bc6f4f3975
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/common.py
@@ -0,0 +1,243 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
+
+
class BaseCompressionMethod(ABC):
    """
    Abstract interface for KV cache compression methods.

    A compression method is implemented as a pair of optional scoring phases
    that run before and after rotary position embedding (RoPE) is applied:

    1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.

    2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
         - refine / reweight the pre-RoPE scores, or
         - compute new, potentially position-aware scores.

    Concrete subclasses are expected to implement both
    static methods and return a single tensor of scores (or ``None`` if the
    phase is a no-op), which the caller can then feed into the shared
    “scores → top-k indices → KV extraction” pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute per-token importance scores from pre-RoPE queries/keys.

        Args:
            :param q:
                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                ``compactor_vllm.utils.context.Context`` object carrying additional metadata,
                such as batch mappings or temporary buffers

        Returns:
            :return Optional[torch.Tensor]:
                A tensor of scores (e.g. per-token, per-head importance values)
                to be passed to ``post_rope_scoring`` or directly into the
                top-k selection step. If this phase is a no-op, implementations
                should return ``None``. Shape ``[total_tokens, HKV]``.
        """
        pass

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute or refine importance scores from post-RoPE queries/keys.

        This method is called after rotary embeddings have been applied. It can
        optionally use both the post-RoPE Q/K and any scores produced by
        ``pre_rope_scoring`` to produce final scores used for token selection.

        Common patterns include:
          * Using ``pre_rope_scores`` as a base signal and applying a
            position-aware correction.
          * Only computing scores that depend on absolute or relative positions.
          * Simply passing through ``pre_rope_scores`` unchanged.

        Args:
            :param q:
                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores:
                Optional scores returned by ``pre_rope_scoring``. May be
                ``None`` if the pre-RoPE phase returned None.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                ``compactor_vllm.utils.context.Context`` object carrying additional metadata,
                such as batch mappings or temporary buffers
        Returns:
            :return Optional[torch.Tensor]:
                Final importance scores to be consumed by the compression
                pipeline (for top-k token selection). If this phase is a
                no-op, implementations may return ``pre_rope_scores``. If
                None is returned, no compression will be applied.
        """
        pass
+
+
class NoCompression(BaseCompressionMethod):
    """
    Trivial compression method that disables KV cache compression.

    ``pre_rope_scoring`` always yields ``None`` and ``post_rope_scoring``
    returns whatever pre-RoPE scores it is handed, without inspecting Q/K/V.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        # No scoring is performed in this phase.
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        # Pass-through: the incoming scores (possibly None) are returned untouched.
        return pre_rope_scores
+
+
def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
):
    """Select the top-k (token, head) entries from ``scores`` and store them in the KV cache.

    Two steps are chained so they can be issued together (e.g. on one stream):
    ``scores_to_retain_indices`` turns the packed scores into per-sequence
    top-k indices, and ``prefill_store_topk_kv`` receives those indices along
    with the new keys/values, page table and cache tensors. Nothing is
    returned.
    """
    # Step 1: packed scores -> global top-k indices per batch element.
    selection_args = dict(
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        padding=padding,
    )
    retained_indices = scores_to_retain_indices(scores, **selection_args)

    # Step 2: hand the selected indices and all cache metadata to the store routine.
    prefill_store_topk_kv(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=retained_indices,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )
+
+
def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    padding: float = -float("inf"),
) -> torch.Tensor:
    """
    Select global top-k token–head indices per sequence from packed scores.

    This helper takes per-token, per-head scores in packed varlen form and
    returns, for each batch element, the indices of the top-k (token, head)
    pairs in the flattened global layout.
    Inputs are assumed to follow the usual packed varlen convention:
        • ``scores`` is laid out as ``[N_total, H]``, where:
              ``N_total = sum_b seqlen_k[b]``
          and ``H`` is the number of KV heads.

        • ``cu_seqlens_k`` is ``[B + 1]`` (int32), giving cumulative lengths
          for the keys per batch:
              ``seqlen_k[b] = cu_seqlens_k[b + 1] - cu_seqlens_k[b]``.

        • ``max_k_len`` is an upper bound on ``seqlen_k[b]`` across the batch.

    The function pads each sequence to length ``max_k_len`` with ``padding``
    (default: ``-inf``), flattens the per-sequence scores into shape
    ``[B, max_k_len * H]``, and runs a per-batch top-k. The returned indices
    are shifted so that they directly index into the flattened global
    score layout of shape ``[N_total * H]``:
        global_index = (token_global_offset * H) + head_index

    Note: when a sequence holds fewer than ``k_eff`` real (token, head)
    entries, the tail of its row is filled from padded slots, so those
    indices point past that sequence's own range. Callers must bound how many
    entries they consume per sequence.

    Args:
        :param scores:
            Tensor of shape ``[N_total, H]`` containing scores for each
            (token, head) pair in packed varlen format.
        :param cu_seqlens_k:
            Tensor of shape ``[B + 1]`` (int32) with cumulative key sequence
            lengths for each batch element. The total number of tokens
            satisfies ``N_total = cu_seqlens_k[-1]``.
        :param max_k_len:
            Maximum key sequence length across the batch (i.e.
            ``max_b seqlen_k[b]``). Used to allocate the padded buffer.
        :param top_k:
            Number of (token, head) entries to retain **per batch element**.
            If ``top_k > max_k_len * H``, it is clamped to ``max_k_len * H``.
        :param H:
            Number of key heads; must match ``scores.shape[1]``.
        :param padding:
            Padding value used when extending sequences shorter than
            ``max_k_len``. Defaults to ``-inf``, so that padded positions
            rank last in the top-k.

    Returns:
        :return torch.Tensor:
            Tensor of shape ``[B, k_eff]`` (int64) where
            ``k_eff = min(top_k, max_k_len * H)``. Each entry is a global
            index into the flattened score array of shape ``[N_total * H]``
            (i.e. scores viewed as ``scores.view(-1)``).
    """
    # idea: pad and then select top-k.
    B, device = cu_seqlens_k.numel() - 1, scores.device
    padded = torch.full(
        (B, max_k_len, H), fill_value=padding, dtype=scores.dtype, device=device
    )
    # Fetch all sequence boundaries with a single device->host transfer rather
    # than converting two tensor elements per sequence inside the loop.
    bounds = cu_seqlens_k.tolist()
    for b, (s, e) in enumerate(zip(bounds, bounds[1:])):
        padded[b, : e - s, :].copy_(scores[s:e, :])
    flat = padded.view(B, max_k_len * H)
    idx = torch.topk(
        flat, k=min(top_k, max_k_len * H), dim=1, largest=True, sorted=True
    ).indices
    # Shift per-sequence flat indices (token * H + head) by each sequence's
    # starting token offset so they address the packed [N_total * H] layout.
    return idx + (cu_seqlens_k[:-1] * H).unsqueeze(-1)
diff --git a/vllm/kvprune_legacy_save/compression/compactor.py b/vllm/kvprune_legacy_save/compression/compactor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8e5db1656ea8b53bae09b6cde154b0e3aa4ef7
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/compactor.py
@@ -0,0 +1,739 @@
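+"""
+Compactor compression, aligned with the kvpress ``CompactorPress`` / ``LeverageScorePress`` /
+``NonCausalAttnPress`` algorithms (Cholesky leverage scores, right Gaussian sketch, non-causal
+chunked attention without the 1/sqrt(d) scaling, x||V||, avg_pool, global z-score, blending,
+and head/tail sink padding).
+
+Non-causal chunked attention and ``x||V||`` + ``avg_pool1d(k=3)`` run as Triton kernels on
+CUDA; non-CUDA tensors fall back to PyTorch.
+"""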
+
+from __future__ import annotations
+
+import math
+from typing import List, Optional
+
+import torch
+import triton
+import triton.language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.context import get_context
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+
+
+def resolve_kvpress_compactor_blending(compression_context) -> float:
+    """Pick the blending weight, following kvpress ``CompactorPress.score``.
+
+    The first attribute of ``compression_context`` that is not ``None`` wins, in
+    the order ``compactor_blending`` then ``compression_ratio``. When neither is
+    set, or the context itself is ``None``, the result is ``0.35``.
+    """
+    # ``getattr`` with a default also works on ``None``, so a missing context
+    # needs no separate branch.
+    for attr_name in ("compactor_blending", "compression_ratio"):
+        candidate = getattr(compression_context, attr_name, None)
+        if candidate is not None:
+            return float(candidate)
+    return 0.35
+
+
+class CompactorCompression(BaseCompressionMethod):
+    """Compactor scoring hooks.
+
+    ``chunk_size=256`` follows the kvpress default used by ``CompactorPress`` /
+    ``NonCausalAttnPress``.
+    """
+
+    chunk_size: int = 256
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Run ``kvpress_leverage_scores_packed`` on the pre-RoPE keys ``k``.
+
+        ``q`` and ``v`` are accepted for interface symmetry but are not used.
+        """
+        compression_ctx = context.compression_context
+        # Key rows are indexed by the packed K layout (matches master/peer
+        # packed buffers). The fallbacks use ``is None`` rather than ``or``
+        # because cu_seqlens_* are tensors and ``bool(tensor)`` is invalid.
+        device_bounds = getattr(context, "cu_seqlens_k", None)
+        if device_bounds is None:
+            device_bounds = context.cu_seqlens_q
+        global_ctx = get_context()
+        host_bounds = global_ctx.cu_seqlens_k_host
+        if host_bounds is None:
+            host_bounds = global_ctx.cu_seqlens_q_host
+        return maybe_execute_in_stream(
+            kvpress_leverage_scores_packed,
+            k,
+            device_bounds,
+            compression_ctx,
+            host_bounds,
+            STORE_STREAM=None,
+        )
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Run ``kvpress_compactor_post_rope`` to blend attention and leverage scores.
+
+        The blending weight comes from ``resolve_kvpress_compactor_blending``.
+        """
+        compression_ctx = context.compression_context
+        blend_weight = resolve_kvpress_compactor_blending(compression_ctx)
+        return maybe_execute_in_stream(
+            kvpress_compactor_post_rope,
+            q,
+            k,
+            v,
+            context.cu_seqlens_q,
+            pre_rope_scores,
+            compression_ctx,
+            context.max_seqlen_q,
+            chunk_size=CompactorCompression.chunk_size,
+            blending=float(blend_weight),
+            STORE_STREAM=context.STORE_STREAM,
+        )
+
+
+# ---------------------------------------------------------------------------
+# Cholesky 杠杆分(kvpress ``LeverageScorePress``)
+# ---------------------------------------------------------------------------
+
+
+def chol_with_jitter(
+    G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5
+) -> torch.Tensor:
+    """Lower Cholesky factor of ``G + jitter * I``, raising the jitter on failure.
+
+    Each failed attempt replaces the diagonal shift by ``1e-2`` (if it was
+    zero) or ten times its previous value, never less than ``1e-8``.
+
+    Raises:
+        RuntimeError: if no attempt out of ``max_tries`` succeeds for every
+            matrix in the batch.
+    """
+    eye = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
+    shift = float(jitter)
+    attempts_left = max_tries
+    while attempts_left > 0:
+        factor, status = torch.linalg.cholesky_ex(G + shift * eye, upper=False)
+        # ``status`` is zero for every batch entry that factorized cleanly.
+        if bool((status == 0).all()):
+            return factor
+        shift = max(1e-8, 1e-2 if shift == 0.0 else 10.0 * shift)
+        attempts_left -= 1
+    raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
+
+
+def compute_leverage_scores_mid(
+    key_states: torch.Tensor, sketch_dimension: int
+) -> torch.Tensor:
+    """
+    Sketched leverage scores for one block of keys (no z-score applied).
+
+    According to the original author this matches kvpress
+    ``LeverageScorePress.compute_leverage_scores``. Input is ``[L, H, D]`` and
+    the result is ``[L, H]``.
+
+    The dimension order follows kvpress's ``(B, H, S, D)``: the keys are
+    reshaped to ``[1, H, L, D]``, centered along the sequence dimension
+    (``dim=-2``), and multiplied by a random Gaussian ``Phi`` of shape
+    ``(1, H, D, K)`` scaled by ``1/sqrt(K)``, giving ``[1, H, L, K]``. With
+    ``G = X^T X`` (symmetrized and regularized), the score of each row is the
+    row-wise sum of ``X * (G^-1 X^T)^T``, clamped at zero.
+
+    ``Phi`` is drawn with ``torch.randn`` on every call, so results depend on
+    the global RNG state.
+    """
+    d, k = key_states.shape[-1], sketch_dimension
+    device, dtype = key_states.device, key_states.dtype
+    H = key_states.shape[1]
+    Phi = torch.randn(1, H, d, k, device=device, dtype=dtype) * (1.0 / math.sqrt(k))
+    # [L, H, d] -> [1, H, L, d], the same layout as kvpress (B, H, S, d)
+    X0 = key_states.transpose(0, 1).unsqueeze(0).contiguous()
+    # ROCm batched GEMM is sensitive to non-contiguous strides after transpose/mean.
+    X = (X0 - X0.mean(dim=-2, keepdim=True)).contiguous()
+    # The sketch is computed in the key dtype, then promoted to float32 for
+    # the Gram matrix and the solve.
+    X = torch.matmul(X, Phi).to(torch.float32).contiguous()
+    XT = X.transpose(-2, -1).contiguous()
+    G = (XT @ X).contiguous()
+    # Average with the transpose to remove rounding asymmetry.
+    G_sym = 0.5 * (G + G.transpose(-2, -1)).contiguous()
+    # HIP/ROCm: rocBLAS TRSM (used by cholesky_solve and often by linalg.solve for
+    # triangular solves) can launch blocks (e.g. 16x64x1) > __launch_bounds__(256).
+    # Small sketch_dim k: inv(G) @ XT avoids TRSM; k is typically <= 128.
+    if torch.version.hip is not None:
+        kk = G_sym.shape[-1]
+        eye = torch.eye(
+            kk, device=G_sym.device, dtype=G_sym.dtype, requires_grad=False
+        )
+        # Fixed 1e-2 diagonal regularization, the same starting jitter the
+        # Cholesky branch below uses.
+        G_reg = G_sym + 1e-2 * eye
+        inv_Xt = torch.linalg.inv(G_reg) @ XT
+    else:
+        L = chol_with_jitter(G_sym, jitter=1e-2, max_tries=5)
+        inv_Xt = torch.cholesky_solve(XT, L, upper=False)
+    scores = (X * inv_Xt.transpose(-2, -1)).sum(dim=-1).clamp_min(0)
+    # [1, H, L] -> [L, H]
+    return scores.squeeze(0).transpose(0, 1).contiguous()
+
+
+def kvpress_leverage_scores_packed(
+    key_states: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    compression_ctx,
+    cu_seqlens_host: tuple[int, ...] | None = None,
+) -> torch.Tensor:
+    """Globally z-scored leverage scores for the non-sink tokens of a packed batch.
+
+    For every sequence delimited by the cumulative boundaries, the first
+    ``sink_size_start`` and last ``sink_size_end`` tokens (defaults 8 and 4,
+    read from ``compression_ctx``) are excluded. The remaining "middle" rows
+    are scored with ``compute_leverage_scores_mid``; all middle scores of the
+    batch are concatenated, z-scored together, and written back.
+
+    Args:
+        key_states: packed keys of shape ``[N, Hkv, D]``.
+        cu_seqlens: cumulative sequence boundaries; only read when
+            ``cu_seqlens_host`` is ``None``.
+        compression_ctx: source of ``sketch_dimension`` (default 48) and the
+            sink sizes.
+        cu_seqlens_host: optional host-side copy of the boundaries.
+
+    Returns:
+        ``[N, Hkv]`` float32 tensor; rows outside any middle range stay 0.
+
+    Raises:
+        RuntimeError: if the last boundary differs from ``N``.
+    """
+    num_rows, num_kv_heads, _ = key_states.shape
+    sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48))
+    head_sink = int(getattr(compression_ctx, "sink_size_start", 8))
+    tail_sink = int(getattr(compression_ctx, "sink_size_end", 4))
+
+    # Use the host copy when given; otherwise bring the boundaries to the CPU.
+    if cu_seqlens_host is None:
+        host_view = cu_seqlens.detach().cpu().view(-1)
+        total = int(host_view[-1])
+        bounds = host_view.tolist()
+    else:
+        bounds = list(cu_seqlens_host)
+        total = bounds[-1]
+    if total != num_rows:
+        raise RuntimeError(
+            f"kvpress_leverage_scores_packed: cu_seqlens[-1]={total} != key_states "
+            f"num_rows={num_rows} (check packed prefill / TP broadcast)."
+        )
+
+    scores = torch.zeros(
+        num_rows, num_kv_heads, device=key_states.device, dtype=torch.float32
+    )
+    spans: list[tuple[int, int]] = []
+    raw_pieces: list[torch.Tensor] = []
+    for seq_beg, seq_end in zip(bounds[:-1], bounds[1:]):
+        seq_beg, seq_end = int(seq_beg), int(seq_end)
+        seq_len = seq_end - seq_beg
+        if seq_len == 0:
+            continue
+        kept_head = min(head_sink, seq_len)
+        kept_tail = min(tail_sink, max(0, seq_len - kept_head))
+        mid_beg, mid_end = seq_beg + kept_head, seq_end - kept_tail
+        if mid_beg >= mid_end:
+            # The sinks cover the whole sequence; nothing to score.
+            continue
+        mid_keys = key_states[mid_beg:mid_end, :, :].contiguous()
+        raw_pieces.append(
+            compute_leverage_scores_mid(mid_keys, sketch_dim).reshape(-1)
+        )
+        spans.append((mid_beg, mid_end))
+
+    if not raw_pieces:
+        return scores
+
+    # One z-score over every middle token of the batch, then scatter back.
+    normalized = _zscore_flat_f32_global(torch.cat(raw_pieces, dim=0))
+    cursor = 0
+    for (mid_beg, mid_end), piece in zip(spans, raw_pieces):
+        count = piece.numel()
+        scores[mid_beg:mid_end, :] = normalized[cursor : cursor + count].view(
+            mid_end - mid_beg, num_kv_heads
+        )
+        cursor += count
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# 非因果分块注意力(kvpress ``NonCausalAttnPress.non_causal_chunked_attn``)— Triton
+# ---------------------------------------------------------------------------
+
+
+def _non_causal_chunked_attn_pytorch(
+    q: torch.Tensor, k: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+    """Reference PyTorch implementation of non-causal chunked attention.
+
+    ``q`` and ``k`` are ``[L, H, d]``. The sequence is zero-padded to a
+    multiple of ``chunk_size``. Within each chunk the unscaled ``q @ k^T``
+    logits are softmaxed over keys and then summed over queries, so the result
+    ``[L, H]`` holds the attention mass each key receives inside its chunk.
+    Per the original author this matches kvpress operator for operator.
+    """
+    assert chunk_size > 0 and q.shape == k.shape
+    seq_len, num_heads, head_dim = q.shape
+    # [L, H, d] -> [1, H, L, d]
+    q4 = q.permute(1, 0, 2).unsqueeze(0).contiguous()
+    k4 = k.permute(1, 0, 2).unsqueeze(0).contiguous()
+
+    padded_len = math.ceil(seq_len / chunk_size) * chunk_size
+    tail = padded_len - seq_len
+    if tail > 0:
+        q4 = torch.cat([q4, q4.new_zeros(1, num_heads, tail, head_dim)], dim=2)
+        k4 = torch.cat([k4, k4.new_zeros(1, num_heads, tail, head_dim)], dim=2)
+
+    # Positions covered by the last chunk; True marks padded slots. When the
+    # length is already a multiple of chunk_size the mask is all False.
+    last_positions = torch.arange(
+        padded_len - chunk_size, padded_len, device=q.device
+    )
+    pad_mask = (
+        (last_positions >= seq_len).view(1, 1, chunk_size).expand(1, num_heads, chunk_size)
+    )
+
+    n_chunks = padded_len // chunk_size
+    q_blocks = q4.view(1, num_heads, n_chunks, chunk_size, head_dim)
+    k_blocks = k4.view(1, num_heads, n_chunks, chunk_size, head_dim)
+    logits = torch.matmul(q_blocks, k_blocks.transpose(-2, -1))
+
+    # Only the last chunk can contain padding: padded query rows are zeroed,
+    # then padded key columns are set to -1e-9 (the value is kept as-is).
+    last_chunk = logits[:, :, -1]
+    last_chunk.masked_fill_(pad_mask.unsqueeze(-1), 0)
+    last_chunk.masked_fill_(pad_mask.unsqueeze(-2), -1e-9)
+
+    probs = torch.softmax(logits.to(torch.float32), dim=-1)
+    key_mass = probs.sum(dim=-2).view(1, num_heads, padded_len)[..., :seq_len]
+    return key_mass.squeeze(0).transpose(0, 1).contiguous()
+
+
+@triton.jit
+def _non_causal_chunk_row_kernel(
+    Q_ptr,
+    K_ptr,
+    Out_ptr,
+    stride_qh,
+    stride_qs,
+    stride_qd,
+    stride_kh,
+    stride_ks,
+    stride_kd,
+    stride_oh,
+    stride_os,
+    S,
+    S_pad,
+    num_chunks,
+    CHUNK_SIZE: tl.constexpr,
+    D: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    ND: tl.constexpr,
+):
+    """
+    One program handles one head, one chunk and one query row.
+
+    The logits row is softmaxed over the keys of the chunk (dim=-1) and the
+    resulting probabilities are atomically added to the output at each key
+    column j, which is equivalent to summing over the query rows.
+
+    ``S`` is the unpadded sequence length. ``S_pad`` and ``num_chunks`` are
+    accepted but not referenced in the body.
+    """
+    h = tl.program_id(0)
+    c = tl.program_id(1)
+    iq = tl.program_id(2)
+
+    # Global index of the query row handled by this program.
+    g_i = c * CHUNK_SIZE + iq
+
+    offs_j = tl.arange(0, CHUNK_SIZE)
+    logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
+
+    # Accumulate q . k over the head dimension in ND slices of BLOCK_D.
+    for db in range(ND):
+        offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
+        mask_d = offs_d < D
+        q_off = (
+            h * stride_qh + g_i * stride_qs + offs_d * stride_qd
+        )
+        qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
+
+        g_j = c * CHUNK_SIZE + offs_j
+        k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
+        # Only the head dimension is masked; rows up to the chunk end are read
+        # unconditionally (assumes K is padded to a chunk multiple - TODO confirm
+        # against every caller).
+        kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
+        logits += tl.sum(qd[None, :] * kj, axis=1)
+
+    row_invalid = g_i >= S
+    g_j_all = c * CHUNK_SIZE + offs_j
+    col_invalid = g_j_all >= S
+
+    # Padded query rows get an all-zero logits row; for real rows, padded key
+    # columns are set to -1e-9, the same constant the PyTorch reference uses.
+    # NOTE(review): -1e-9 is a tiny negative number, so padded columns are not
+    # actually suppressed by the softmax - confirm whether -1e9 was intended.
+    logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
+    logits = tl.where(
+        row_invalid,
+        logits,
+        tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
+    )
+
+    # Softmax over the key columns, with the row maximum subtracted first.
+    m = tl.max(logits)
+    logits = logits - m
+    exp_v = tl.exp(logits)
+    denom = tl.sum(exp_v)
+    p = exp_v / denom
+
+    # Scatter the probabilities to their key columns; padded columns are dropped.
+    out_base = h * stride_oh + g_j_all * stride_os
+    tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)
+
+
+def _non_causal_chunked_attn_triton(
+    q: torch.Tensor, k: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+    """Triton launcher for the algorithm in ``_non_causal_chunked_attn_pytorch``.
+
+    ``q`` and ``k`` are CUDA tensors of shape ``[L, H, d]``. They are
+    zero-padded to a multiple of ``chunk_size``, laid out head-major in
+    float32, and handed to ``_non_causal_chunk_row_kernel`` with one program
+    per (head, chunk, query row). Returns ``[L, H]`` float32.
+    """
+    assert q.is_cuda and k.is_cuda and q.shape == k.shape
+    seq_len, num_heads, head_dim = q.shape
+    assert chunk_size > 0
+
+    padded_len = math.ceil(seq_len / chunk_size) * chunk_size
+    tail = padded_len - seq_len
+    if tail > 0:
+        q = torch.cat([q, q.new_zeros(tail, num_heads, head_dim)], dim=0)
+        k = torch.cat([k, k.new_zeros(tail, num_heads, head_dim)], dim=0)
+
+    # [S_pad, H, d] -> [H, S_pad, d], float32 for the kernel.
+    q_heads = q.transpose(0, 1).contiguous().to(dtype=torch.float32)
+    k_heads = k.transpose(0, 1).contiguous().to(dtype=torch.float32)
+
+    n_chunks = padded_len // chunk_size
+    # Must start at zero: the kernel accumulates with atomic adds.
+    key_mass = torch.zeros(
+        num_heads, padded_len, device=q.device, dtype=torch.float32
+    )
+
+    block_d = 32 if head_dim <= 128 else 64
+    d_blocks = (head_dim + block_d - 1) // block_d
+    launch_grid = (num_heads, n_chunks, chunk_size)
+    _non_causal_chunk_row_kernel[launch_grid](
+        q_heads,
+        k_heads,
+        key_mass,
+        *q_heads.stride(),
+        *k_heads.stride(),
+        *key_mass.stride(),
+        int(seq_len),
+        padded_len,
+        int(n_chunks),
+        CHUNK_SIZE=chunk_size,
+        D=head_dim,
+        BLOCK_D=block_d,
+        ND=d_blocks,
+        num_warps=4,
+    )
+    return key_mass[:, :seq_len].transpose(0, 1).contiguous()
+
+
+def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
+    """Non-causal chunked attention mass: ``q``, ``k`` of ``[L, H, d]`` -> ``[L, H]``.
+
+    Logits are **not** scaled by ``1/sqrt(d)``. The Triton path is taken when
+    both tensors are on CUDA; otherwise the PyTorch reference is used.
+    """
+    both_on_cuda = q.is_cuda and k.is_cuda
+    backend = (
+        _non_causal_chunked_attn_triton
+        if both_on_cuda
+        else _non_causal_chunked_attn_pytorch
+    )
+    return backend(q, k, chunk_size)
+
+
+# ---------------------------------------------------------------------------
+# ×||V|| + avg_pool1d(k=3) — Triton(CUDA)
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _mul_vnorm_avgpool3_kernel(
+    A_ptr,
+    V_ptr,
+    OUT_ptr,
+    stride_al,
+    stride_ah,
+    stride_vl,
+    stride_vh,
+    stride_vd,
+    stride_ol,
+    stride_oh,
+    L,
+    D: tl.constexpr,
+):
+    """Fused ``a * ||v||`` followed by a width-3 average over the sequence axis.
+
+    One program computes one output element ``(l, h)``. For each of the three
+    taps ``l-1``, ``l``, ``l+1`` it forms ``a[pos, h] * sqrt(sum(v[pos, h, :]^2))``;
+    taps outside ``[0, L)`` contribute 0, and the sum is always divided by 3.
+
+    Triton does not support nested ``def``, so the per-tap logic is written
+    out three times instead of living in a helper.
+
+    NOTE(review): ``tl.arange(0, D)`` is used without a mask, so ``D`` is
+    presumably expected to be a power of two - confirm against callers.
+    """
+    l = tl.program_id(0)
+    h = tl.program_id(1)
+    offs = tl.arange(0, D)
+
+    # --- tap l-1 ---
+    pos_m1 = l - 1
+    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
+    # Clamp the address to row 0 when out of bounds; the masked load and the
+    # final tl.where make the value irrelevant in that case.
+    ps_m1 = tl.where(inb_m1, pos_m1, 0)
+    a_m1 = tl.load(
+        A_ptr + ps_m1 * stride_al + h * stride_ah,
+        mask=inb_m1,
+        other=0.0,
+    ).to(tl.float32)
+    v_m1 = tl.load(
+        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
+        mask=inb_m1,
+        other=0.0,
+    ).to(tl.float32)
+    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)
+
+    # --- tap l ---
+    inb_0 = (l >= 0) & (l < L)
+    ps0 = tl.where(inb_0, l, 0)
+    a0 = tl.load(
+        A_ptr + ps0 * stride_al + h * stride_ah,
+        mask=inb_0,
+        other=0.0,
+    ).to(tl.float32)
+    v0 = tl.load(
+        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
+        mask=inb_0,
+        other=0.0,
+    ).to(tl.float32)
+    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)
+
+    # --- tap l+1 ---
+    pos_p1 = l + 1
+    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
+    ps_p1 = tl.where(inb_p1, pos_p1, 0)
+    a_p1 = tl.load(
+        A_ptr + ps_p1 * stride_al + h * stride_ah,
+        mask=inb_p1,
+        other=0.0,
+    ).to(tl.float32)
+    v_p1 = tl.load(
+        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
+        mask=inb_p1,
+        other=0.0,
+    ).to(tl.float32)
+    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)
+
+    # Divide by 3 even at the sequence edges, where a missing tap counts as 0.
+    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
+    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
+
+
+def _mul_vnorm_avgpool3_fused(
+    a: torch.Tensor, v: torch.Tensor, out: torch.Tensor | None = None
+) -> torch.Tensor:
+    """Launch ``_mul_vnorm_avgpool3_kernel`` over every ``(l, h)`` position.
+
+    Args:
+        a: ``[L, H]`` weights; converted to contiguous float32.
+        v: ``[L, H, D]`` values; made contiguous.
+        out: optional ``[L, H]`` destination; a float32 buffer on ``v``'s
+            device is allocated when omitted.
+
+    Returns:
+        ``out``. It is returned without launching the kernel when ``L`` or
+        ``H`` is zero.
+    """
+    assert a.dim() == 2 and v.dim() == 3 and a.shape[0] == v.shape[0] and a.shape[1] == v.shape[1]
+    seq_len, num_heads, head_dim = v.shape
+    # ``.float()`` is a no-op for tensors that are already float32.
+    weights = a.contiguous().float()
+    values = v.contiguous()
+    if out is None:
+        out = torch.empty((seq_len, num_heads), device=values.device, dtype=torch.float32)
+    if seq_len == 0 or num_heads == 0:
+        return out
+    _mul_vnorm_avgpool3_kernel[(seq_len, num_heads)](
+        weights,
+        values,
+        out,
+        *weights.stride(),
+        *values.stride(),
+        *out.stride(),
+        seq_len,
+        D=head_dim,
+        num_warps=4,
+    )
+    return out
+
+
+def _maybe_mul_vnorm_avgpool3_fused(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    """Compute ``avg_pool1d(a * ||v||, kernel=3)`` along the sequence axis.
+
+    ``a`` is ``[L, H]`` and ``v`` is ``[L, H, D]``. When both are on CUDA the
+    fused Triton launcher is used; otherwise the same computation is done with
+    PyTorch ops. The result is ``[L, H]``.
+    """
+    if a.is_cuda and v.is_cuda:
+        return _mul_vnorm_avgpool3_fused(a, v)
+
+    import torch.nn.functional as F
+
+    weighted = a * v.norm(dim=-1)  # [L, H]
+    # avg_pool1d pools over the last axis, so move the sequence there: [1, H, L].
+    pooled = F.avg_pool1d(
+        weighted.transpose(0, 1).unsqueeze(0), kernel_size=3, padding=1, stride=1
+    )
+    return pooled.squeeze(0).transpose(0, 1)
+
+
+@triton.jit
+def _zscore_elem_1d_kernel(
+    X_ptr,
+    OUT_ptr,
+    n,
+    mean,
+    inv_std,
+    BLOCK: tl.constexpr,
+):
+    # Elementwise ``(x - mean) * inv_std`` over a flat buffer of ``n`` values.
+    # Each program handles one BLOCK-sized slice; ``mean`` and ``inv_std`` are
+    # scalars supplied by the caller.
+    pid = tl.program_id(0)
+    offs = pid * BLOCK + tl.arange(0, BLOCK)
+    # Guard the final, partially filled block.
+    mask = offs < n
+    x = tl.load(X_ptr + offs, mask=mask, other=0.0)
+    tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)
+
+
+def _zscore_flat_f32_global(x: torch.Tensor) -> torch.Tensor:
+    """
+    One-dimensional global z-score, ``(t - t.mean()) / t.std()`` as in kvpress.
+
+    ``mean`` and ``std`` are computed with PyTorch; the standard deviation is
+    floored at ``1e-6`` before inversion. On CUDA the scaling itself is done
+    elementwise by the Triton kernel ``_zscore_elem_1d_kernel``.
+
+    Edge cases:
+        * empty input is returned unchanged;
+        * a single-element input returns zero. Its sample standard deviation
+          is NaN, which the ``clamp_min`` floor would not remove, so the
+          z-score is defined as 0 instead of leaking NaN into the scores.
+    """
+    if x.numel() == 0:
+        return x
+    if x.numel() == 1:
+        # x.std() is NaN for one sample (Bessel correction divides by zero).
+        return torch.zeros_like(x)
+    mu = x.mean()
+    sig = x.std().clamp_min(1e-6)
+    inv = 1.0 / sig
+    if not x.is_cuda:
+        return (x - mu) * inv
+    x = x.contiguous()
+    out = torch.empty_like(x)
+    n = x.numel()
+    BLOCK = 1024
+    grid = (triton.cdiv(n, BLOCK),)
+    _zscore_elem_1d_kernel[grid](
+        x,
+        out,
+        n,
+        # The kernel takes plain Python scalars for the statistics.
+        float(mu.item()),
+        float(inv.item()),
+        BLOCK=BLOCK,
+        num_warps=4,
+    )
+    return out
+
+
+def _attn_scores_kvpress_middle(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    sink_start: int,
+    sink_end: int,
+    chunk_size: int,
+    do_zscore: bool = True,
+) -> torch.Tensor:
+    """Non-causal attention scores, x||V||, avg-pooled, on middle sub-sequences only.
+
+    For each packed sequence the first ``sink_start`` and last ``sink_end``
+    tokens are excluded. On the remaining middle rows the function computes
+    ``non_causal_chunked_attn`` per query head, averages the query heads of
+    each KV group, and applies ``_maybe_mul_vnorm_avgpool3_fused``. The middle
+    scores of all sequences are concatenated and, if ``do_zscore`` is set,
+    z-scored together.
+
+    Args:
+        q: ``[N, HQ, D]`` packed queries.
+        k: ``[N, Hkv, D]`` packed keys; ``HQ`` must be a multiple of ``Hkv``.
+        v: ``[N, Hkv, D]`` packed values.
+        cu_seqlens: cumulative sequence boundaries, ``[B + 1]``.
+        sink_start, sink_end: tokens to skip at the head / tail of each sequence.
+        chunk_size: chunk length for the non-causal attention.
+        do_zscore: whether to normalize the concatenated middle scores.
+
+    Returns:
+        ``[N, Hkv]`` float32 tensor; rows outside any middle range are 0.
+    """
+    N, HQ, _ = q.shape
+    Hkv = k.shape[1]
+    G = HQ // Hkv
+    attn_out = torch.zeros(N, Hkv, device=q.device, dtype=torch.float32)
+
+    # Read the boundaries once instead of calling ``.item()`` per sequence,
+    # which costs a host sync per call on CUDA tensors.
+    bounds = cu_seqlens.detach().reshape(-1).tolist()
+
+    mid_ranges: list[tuple[int, int]] = []
+    parts: list[torch.Tensor] = []
+    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
+        k_beg, k_end = int(k_beg), int(k_end)
+        L = k_end - k_beg
+        if L == 0:
+            continue
+        left_keep = min(sink_start, L)
+        right_keep = min(sink_end, max(0, L - left_keep))
+        mid_start = k_beg + left_keep
+        mid_end = k_end - right_keep
+        if mid_start >= mid_end:
+            # The sinks cover the whole sequence.
+            continue
+        q_m = q[mid_start:mid_end, :, :].contiguous()
+        k_m = k[mid_start:mid_end, :, :].contiguous()
+        v_m = v[mid_start:mid_end, :, :].contiguous()
+        # HF ``repeat_kv`` expects ``[batch, num_kv_heads, seq_len, head_dim]``.
+        k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous()  # [1, Hkv, Lm, D]
+        k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous()  # [Lm, HQ, D]
+        A = non_causal_chunked_attn(q_m, k_rep, chunk_size)
+        Lm, HQa = A.shape
+        assert HQa == HQ
+        # Average the G query heads that share each KV head.
+        A = A.view(Lm, Hkv, G).mean(dim=-1)
+        scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
+        parts.append(scores.reshape(-1))
+        mid_ranges.append((mid_start, mid_end))
+
+    if not parts:
+        return attn_out
+
+    flat_a = torch.cat(parts, dim=0)
+    z_a = _zscore_flat_f32_global(flat_a) if do_zscore else flat_a
+
+    # Scatter back using the ranges recorded above, so the sink arithmetic is
+    # not repeated.
+    offset = 0
+    for mid_start, mid_end in mid_ranges:
+        n = (mid_end - mid_start) * Hkv
+        attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(
+            mid_end - mid_start, Hkv
+        )
+        offset += n
+    return attn_out
+
+
+def non_causal_attn_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_qk: torch.Tensor,
+    max_seqlen_qk: int,
+    chunk_size: int,
+    sm_scale: float = None,
+    normalize: bool = True,
+    context_lens: Optional[List[int]] = None,
+    protected_first_tokens: Optional[List[int]] = None,
+    protected_last_tokens: Optional[List[int]] = None,
+    *,
+    accum_scores: torch.Tensor = None,
+    accum_blending: float = None,
+) -> torch.Tensor:
+    """
+    Non-causal attention scores with optional blending and protected tokens.
+
+    ``sm_scale`` and ``max_seqlen_qk`` are **ignored** (dot products are not
+    multiplied by ``1/sqrt(d)``), which the original author describes as the
+    kvpress non-causal behavior. Sink sizes are fixed at 8 (head) and 4 (tail).
+
+    Steps:
+        1. ``_attn_scores_kvpress_middle`` computes the scores; with
+           ``normalize=True`` the concatenated middle scores are z-scored.
+        2. If ``accum_scores`` is given, ``accum_blending * accum_scores`` is
+           added (weight 0.5 when ``accum_blending`` is ``None``).
+        3. If both protected lists are given and ``context_lens`` is non-empty,
+           the first / last protected rows of each context are set to ``inf``.
+    """
+    del sm_scale, max_seqlen_qk
+    out = _attn_scores_kvpress_middle(
+        q, k, v, cu_seqlens_qk, 8, 4, chunk_size, do_zscore=normalize
+    )
+
+    if accum_scores is not None:
+        weight = float(accum_blending) if accum_blending is not None else 0.5
+        out = out + weight * accum_scores.to(device=out.device, dtype=out.dtype)
+
+    protection_requested = (
+        protected_first_tokens is not None
+        and protected_last_tokens is not None
+        and context_lens
+    )
+    if not protection_requested:
+        return out
+
+    seq_begin = 0
+    for n_first, n_last, seq_len in zip(
+        protected_first_tokens, protected_last_tokens, context_lens
+    ):
+        seq_end = seq_begin + int(seq_len)
+        out[seq_begin : seq_begin + int(n_first)].fill_(torch.inf)
+        out[seq_end - int(n_last) : seq_end].fill_(torch.inf)
+        seq_begin = seq_end
+    return out
+
+
+def kvpress_compactor_post_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    pre_rope_scores: torch.Tensor,
+    compression_ctx,
+    max_seqlen_q: int,
+    chunk_size: int,
+    blending: float,
+) -> torch.Tensor:
+    """Blend leverage and attention scores, then pad sinks and protected tokens.
+
+    On the middle rows of every sequence (everything except the first
+    ``sink_size_start`` and last ``sink_size_end`` tokens, defaults 8 / 4 from
+    ``compression_ctx``) the result is
+    ``blending * pre_rope_scores + attention_scores``. Sink rows are then
+    filled with the maximum of the blended tensor, or 1.0 if that maximum is 0
+    or not finite. Finally, if ``compression_ctx`` provides
+    ``protected_first_tokens``, ``protected_last_tokens`` and a non-empty
+    ``context_lens``, those rows are set to ``inf``.
+
+    ``max_seqlen_q`` is accepted for interface compatibility and ignored.
+
+    Returns:
+        Float32 tensor with the shape of ``pre_rope_scores``.
+    """
+    del max_seqlen_q
+    device = q.device
+
+    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
+    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
+    context_lens: Optional[List[int]] = getattr(
+        compression_ctx, "context_lens", None
+    )
+    protected_first: Optional[List[int]] = getattr(
+        compression_ctx, "protected_first_tokens", None
+    )
+    protected_last: Optional[List[int]] = getattr(
+        compression_ctx, "protected_last_tokens", None
+    )
+
+    attn_out = _attn_scores_kvpress_middle(
+        q, k, v, cu_seqlens, sink_start, sink_end, chunk_size
+    )
+    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
+    blended = torch.zeros_like(lev)
+
+    # Read the boundaries once (instead of ``.item()`` per sequence, a host
+    # sync each on CUDA) and compute the sink/middle split of every non-empty
+    # sequence a single time.
+    bounds = cu_seqlens.detach().reshape(-1).tolist()
+    spans: list[tuple[int, int, int, int]] = []
+    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
+        k_beg, k_end = int(k_beg), int(k_end)
+        L = k_end - k_beg
+        if L == 0:
+            continue
+        left_keep = min(sink_start, L)
+        right_keep = min(sink_end, max(0, L - left_keep))
+        spans.append((k_beg, k_beg + left_keep, k_end - right_keep, k_end))
+
+    for _, mid_start, mid_end, _ in spans:
+        if mid_start < mid_end:
+            blended[mid_start:mid_end, :] = (
+                blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
+            )
+
+    # With no non-empty sequence there is nothing to pad, and ``max()`` would
+    # raise on an empty tensor.
+    if spans and blended.numel() > 0:
+        pad_val = blended.max()
+        if not torch.isfinite(pad_val) or pad_val == 0:
+            pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)
+        for k_beg, mid_start, mid_end, k_end in spans:
+            # Sinks are padded even when the sequence has no middle part.
+            if mid_start > k_beg:
+                blended[k_beg:mid_start, :] = pad_val
+            if k_end > mid_end:
+                blended[mid_end:k_end, :] = pad_val
+
+    if protected_first is not None and protected_last is not None and context_lens:
+        start = 0
+        for first, last, Lc in zip(
+            protected_first, protected_last, context_lens
+        ):
+            stop = start + int(Lc)
+            blended[start : start + int(first)].fill_(torch.inf)
+            blended[stop - int(last) : stop].fill_(torch.inf)
+            start = stop
+
+    return blended
diff --git a/vllm/kvprune_legacy_save/compression/compactor_origin.py b/vllm/kvprune_legacy_save/compression/compactor_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2dcb6e5012fd062a173dad94deb3462d07650f4
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/compactor_origin.py
@@ -0,0 +1,606 @@
+import logging
+import math
+from typing import List, Optional
+
+import torch
+import triton
+from tqdm.contrib.logging import logging_redirect_tqdm
+from triton import language as tl
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+logger = logging.getLogger(__name__)
+
+
+class CompactorCompression(BaseCompressionMethod):
+    """Compactor scoring hooks built on approximate leverage scores.
+
+    ``chunk_size`` is the chunk length passed to ``non_causal_attn_scores``.
+    """
+
+    chunk_size: int = 128
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Run ``approximate_leverage_scores`` (normalized) on the pre-RoPE keys ``k``.
+
+        The context lengths, sketch matrix ``PHI`` and chunk size are read
+        from ``context.compression_context``. ``q`` and ``v`` are not used.
+        """
+        cctx = context.compression_context
+        return maybe_execute_in_stream(
+            approximate_leverage_scores,
+            k,
+            cctx.context_lens,
+            cctx.PHI,
+            normalize=True,
+            chunk_size=cctx.compression_chunk_size,
+            STORE_STREAM=context.STORE_STREAM,
+        )
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Run ``non_causal_attn_scores``, blending in ``pre_rope_scores`` at weight 0.5."""
+        cctx = context.compression_context
+        scoring_kwargs = dict(
+            chunk_size=CompactorCompression.chunk_size,
+            sm_scale=1.0,
+            normalize=True,
+            accum_scores=pre_rope_scores,
+            context_lens=cctx.context_lens,
+            protected_first_tokens=cctx.protected_first_tokens,
+            protected_last_tokens=cctx.protected_last_tokens,
+            accum_blending=0.5,
+        )
+        return maybe_execute_in_stream(
+            non_causal_attn_scores,
+            q,
+            k,
+            v,
+            context.cu_seqlens_q,
+            context.max_seqlen_q,
+            **scoring_kwargs,
+        )
+
+
+def split_into_chunks(xs, chunk_size):
+    """
+    Decompose sequence lengths into fixed-size chunks plus optional tails.
+
+    Every length ``n`` in ``xs`` is split into ``n // chunk_size`` full chunks
+    (the "prologue") and, if anything remains, one shorter tail of
+    ``n % chunk_size`` tokens (the "epilogue"). Two parallel views of that
+    decomposition are returned:
+
+    * ``coalesced_chunks`` - segment lengths in the **concatenated** sequence
+      space, where all full chunks of one sequence are merged into a single
+      segment and a tail is its own segment.
+
+    * ``chunks`` - the individual chunk lengths: ``chunk_size`` repeated for
+      each full chunk, followed by the tail length if there is one.
+
+    Example:
+        xs = [257, 127], chunk_size = 128
+        coalesced_chunks = [256, 1, 127]
+        chunks = [128, 128, 1, 127]
+
+    Args:
+        :param xs:
+            Iterable of non-negative integers
+        :param chunk_size:
+            Target chunk size (must be non-zero)
+
+    Returns:
+        :return Tuple[List[int], List[int]]:
+            ``(coalesced_chunks, chunks)`` as described above.
+    """
+    merged_lens, piece_lens = [], []
+    for length in xs:
+        full_count, tail_len = divmod(length, chunk_size)
+        body_len = full_count * chunk_size
+        if body_len > 0:
+            merged_lens.append(body_len)
+            piece_lens += [chunk_size] * full_count
+        if tail_len > 0:
+            merged_lens.append(tail_len)
+            piece_lens.append(tail_len)
+    return merged_lens, piece_lens
+
+
+def approximate_leverage_scores(
+    key_states: torch.Tensor,  # [N, H, D]
+    context_lens: List[int],  # [B]
+    PHI: torch.Tensor,  # [D, k]
+    regularizer: float = 5e-3,
+    normalize: bool = False,
+    chunk_size: int = 512,
+) -> torch.Tensor:  # returns [N, H]
+    """
+    Approximate leverage scores for keys via randomized sketching.
+
+    This implements a randomized approximation to per-token leverage scores for
+    the key matrix, as described in Compactor: Calibrated Query-Agnostic KV Cache
+    Compression with Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).
+
+    The keys are sketched with ``PHI``, split into chunks, centered per chunk,
+    and a regularized Gram matrix is formed for each chunk. With the SVD
+    ``G = V diag(S) V^T``, the score of a row ``x`` is ``||x V S^{-1/2}||^2``.
+    If the ``gesvda`` SVD raises ``RuntimeError`` twice (the second attempt
+    adds a further ``10 * regularizer`` to the diagonal), the computation is
+    handed to ``_approximate_leverage_scores_qr_fallback``.
+
+    Args:
+        :param key_states:
+            Tensor of shape ``[N, H, D]`` containing pre-RoPE key states for
+            all tokens across the batch, packed along the sequence dimension.
+            ``N = sum(context_lens)``.
+        :param context_lens:
+            List of per-sequence context lengths, length ``B``.
+        :param PHI:
+            Random projection matrix of shape ``[D, k]`` used to sketch the
+            keys into a lower-dimensional subspace (k < D).
+        :param regularizer:
+            Small positive scalar added to the diagonal of each Gram matrix
+            before SVD to improve numerical stability. Defaults to ``5e-3``.
+        :param normalize:
+            If True, z-score the scores in place with
+            ``_zscore_per_batch_epilogue_no_window``. The kernel is launched
+            with one program per entry of the chunk list, so statistics are
+            taken over all heads and tokens of each chunk (of each whole
+            sequence when ``chunk_size <= 0``).
+        :param chunk_size:
+            Target chunk size along the sequence dimension. If > 0, the
+            concatenated sequence is split into chunks of at most this size
+            before forming Gram matrices and SVD. If <= 0, the entire sequence
+            for each context is treated as a single chunk.
+    Returns:
+        :return torch.Tensor:
+            Approximate leverage scores of shape ``[N, H]``, where each row
+            corresponds to a token and each column to a head.
+    """
+    if chunk_size > 0:
+        coalesced_chunk_lens, chunks_lens = split_into_chunks(context_lens, chunk_size)
+    else:
+        coalesced_chunk_lens, chunks_lens = context_lens, context_lens
+    # Same device as key_states (avoid bare .cuda() → wrong GPU in multi-device
+    # processes); int32 matches Triton zscore kernel expectations for cu_k.
+    chunk_lens_cuda = torch.tensor(
+        [0] + chunks_lens,
+        device=key_states.device,
+        dtype=torch.int32,
+    )
+    # Sketch: [H, N, D] @ [D, k] -> [H, N, k]. X is a fresh tensor, so the
+    # in-place centering below does not touch key_states.
+    X = torch.matmul(key_states.transpose(0, 1), PHI)
+    H, N, k = X.shape
+    # torch.split returns views of X, so ``sub_`` below centers X in place.
+    chunks = torch.split(X, coalesced_chunk_lens, dim=-2)
+    gram_matrices = []
+    for i, L in enumerate(coalesced_chunk_lens):
+        chunk = chunks[i]
+        if chunk_size <= 0 or L % chunk_size != 0:
+            # Tail (or un-chunked sequence): a single Gram matrix per head.
+            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
+            g = torch.matmul(chunk.transpose(-1, -2), chunk)  # [H, k, k]
+            g = g.unsqueeze(1)
+        else:
+            # Run of full chunks: center and form a Gram matrix per chunk.
+            chunk = chunk.view(H, -1, chunk_size, k)  # [H, num_chunks, chunk_size, k]
+            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
+            g = torch.matmul(chunk.transpose(-1, -2), chunk)  # [H, num_chunks, k, k]
+        gram_matrices.append(g)
+    G = torch.cat(gram_matrices, dim=1).to(torch.float32)
+    # ``diagonal`` is a view, so this regularizes G in place.
+    diag = G.diagonal(dim1=-2, dim2=-1)
+    diag.add_(regularizer)
+    try:
+        V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
+    except RuntimeError:
+        try:
+            # Retry with a larger diagonal shift (added on top of the first one).
+            diag = G.diagonal(dim1=-2, dim2=-1)
+            diag.add_(regularizer * 10)
+            V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
+        except RuntimeError:
+            with logging_redirect_tqdm():
+                logger.warning(
+                    "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
+                    "Try increasing chunk_size if this issue persists."
+                )
+            # this is over 50 times slower than using GESVDA
+            return _approximate_leverage_scores_qr_fallback(
+                X=X,
+                chunks_lens=chunks_lens,
+                chunk_lens_cuda=chunk_lens_cuda,
+                normalize=normalize,
+                chunk_size=chunk_size,
+            )
+    # Scale each singular vector by S^{-1/2}: [H, num_chunks, k, k].
+    SV = (V * S.rsqrt().unsqueeze(-2)).to(X.dtype)
+    # ``start`` indexes the per-chunk Gram matrices along dim 1 of SV.
+    start = 0
+    all_scores = []
+    for i, L in enumerate(coalesced_chunk_lens):
+        chunk = chunks[i]
+        if chunk_size <= 0 or L % chunk_size != 0:
+            num_chunks = 1
+            sv = SV[:, start]
+        else:
+            num_chunks = L // chunk_size
+            chunk = chunk.view(H, -1, chunk_size, k)  # [H, NC, CS, k]
+            sv = SV[:, start : start + num_chunks]
+        U = torch.matmul(chunk, sv)
+        scores = (U * U).sum(dim=-1).clamp_min_(0.0).view(H, -1)
+        # [H, L] -> [L, H]
+        all_scores.append(scores.transpose(-1, -2))
+        start += num_chunks
+
+    scores = torch.cat(all_scores, dim=0)
+    if normalize:
+        # One program per chunk; cu_k holds the cumulative chunk boundaries.
+        grid = (len(chunks_lens),)
+        cu_k = chunk_lens_cuda.cumsum(dim=0)
+        _zscore_per_batch_epilogue_no_window[grid](
+            scores, cu_k, scores.stride(0), scores.stride(1), H
+        )
+    return scores
+
+
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+    key=["HK"],
+    cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue_no_window(
+    OUT,  # [Nk, Hk], float32
+    cu_k,  # [B+1] int32
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    HK: tl.constexpr,  # Hk
+    BLOCK_K: tl.constexpr,  # e.g., 128
+):
+    """
+    In-place z-score of ``OUT`` over one segment per program.
+
+    Program ``b`` handles rows ``cu_k[b] .. cu_k[b + 1]`` across all ``HK``
+    heads. A first pass accumulates the sum and sum of squares; a second pass
+    rewrites every value as ``(value - mean) / std``, using the population
+    variance ``E[x^2] - E[x]^2``. Empty segments are left untouched.
+
+    The variance is floored at ``1e-12`` (std >= ``1e-6``) so that a constant
+    segment, e.g. a length-1 chunk whose scores are all equal, normalizes to
+    0 instead of ``0 * inf = NaN``.
+    """
+    b = tl.program_id(0)
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+    if k_end <= k_beg:
+        return
+
+    sumv = tl.zeros([], dtype=tl.float32)
+    sumsq = tl.zeros([], dtype=tl.float32)
+    count = ((k_end - k_beg) * HK).to(tl.float32)
+
+    # Pass 1: statistics over every (row, head) entry of the segment.
+    for ks in tl.range(k_beg, k_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_end
+        for h in tl.range(0, HK):
+            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+            sumv += tl.sum(vals, 0)
+            sumsq += tl.sum(vals * vals, 0)
+
+    mean = sumv / count
+    # The floor also absorbs small negative values caused by cancellation.
+    var = tl.maximum(sumsq / count - mean * mean, 1e-12)
+    invstd = 1.0 / tl.sqrt(var)
+
+    # Pass 2: normalize in place.
+    for ks in tl.range(k_beg, k_end, BLOCK_K):
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_end
+        for h in tl.range(0, HK):
+            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+            vals = (vals - mean) * invstd
+            tl.store(ptrs, vals, mask=kmask)
+
+
+def _approximate_leverage_scores_qr_fallback(
+    X: torch.Tensor,  # [H, N, k], already sketched (KΦ) and centered in-place
+    chunks_lens: List[int],  # [num_chunks]
+    chunk_lens_cuda: torch.Tensor,  # [num_chunks + 1] (prefix base)
+    normalize: bool,
+    chunk_size: int,
+) -> torch.Tensor:
+    """Leverage scores from a reduced QR of each chunk (slow fallback path).
+
+    For every chunk of ``X`` the reduced QR factor ``Q`` is computed in
+    float32 and the score of a row is the squared norm of its row in ``Q``.
+    Chunks whose length equals ``chunk_size`` are stacked and factorized in a
+    single batched call; all other chunks are factorized one at a time.
+
+    Args:
+        X: sketched keys of shape ``[H, N, k]``.
+        chunks_lens: chunk lengths along ``N``; they must sum to ``N``.
+        chunk_lens_cuda: chunk lengths with a leading 0; only used when
+            ``normalize`` is set, to build the boundaries for the z-score kernel.
+        normalize: z-score the scores in place, one segment per chunk.
+        chunk_size: nominal chunk length; ``<= 0`` disables the batched path.
+
+    Returns:
+        ``[N, H]`` scores in the dtype of ``X``.
+
+    Raises:
+        RuntimeError: if ``sum(chunks_lens) != N``.
+    """
+    H, N, k = X.shape
+    device, dtype = X.device, X.dtype
+
+    # Start row of every chunk in the packed sequence.
+    starts: List[int] = []
+    running = 0
+    for length in chunks_lens:
+        starts.append(running)
+        running += length
+    if running != N:
+        raise RuntimeError(
+            f"QR fallback: sum(chunks_lens)={running} does not match N={N}"
+        )
+
+    blocks = torch.split(X, chunks_lens, dim=-2)
+    scores = torch.empty(N, H, device=device, dtype=dtype)
+
+    if chunk_size > 0:
+        full_ids = [i for i, length in enumerate(chunks_lens) if length == chunk_size]
+        tail_ids = [i for i, length in enumerate(chunks_lens) if length != chunk_size]
+
+        if full_ids:
+            stacked = torch.stack([blocks[i] for i in full_ids], dim=0)  # [M, H, CS, k]
+            M, Hf, Lf, kf = stacked.shape
+            assert Lf == chunk_size
+
+            # Fold (M, H) into one batch dimension for torch.linalg.qr.
+            batched = stacked.view(M * Hf, Lf, kf).to(torch.float32)
+            q_full, _ = torch.linalg.qr(batched, mode="reduced")
+            q_full = q_full.to(dtype)
+            lev_full = (q_full * q_full).sum(dim=-1).clamp_min(0.0)  # [M * Hf, Lf]
+            lev_full = lev_full.view(M, Hf, Lf).transpose(-1, -2)  # [M, CS, H]
+            for row, idx in enumerate(full_ids):
+                begin = starts[idx]
+                scores[begin : begin + chunks_lens[idx]].copy_(lev_full[row])
+    else:
+        tail_ids = list(range(len(chunks_lens)))
+
+    for idx in tail_ids:
+        block = blocks[idx]
+        block_len = block.shape[1]
+        if block_len == 0:
+            continue
+        q_tail, _ = torch.linalg.qr(block.to(torch.float32), mode="reduced")
+        lev_tail = (q_tail * q_tail).sum(dim=-1).to(dtype)  # [H, Lc]
+        begin = starts[idx]
+        scores[begin : begin + block_len] = lev_tail.transpose(0, 1)  # [Lc, H]
+
+    if normalize:
+        cu_k = chunk_lens_cuda.cumsum(dim=0)
+        _zscore_per_batch_epilogue_no_window[(len(chunks_lens),)](
+            scores, cu_k, scores.stride(0), scores.stride(1), H
+        )
+    return scores
+
+
# Chunked, non-causal attention-mass kernel.
#
# Launch grid: (KV head, batch id, chunk id within the batch). Each program
# handles one chunk of at most CHUNK_SIZE consecutive tokens of one sequence.
# Every query row of the chunk (all QUERY_GROUP_SIZE query heads that share
# this KV head) is softmax-normalized over the keys of the SAME chunk, and
# the resulting probabilities are summed per key and added to
# ``accum_scores[key, kv_head]``.
#
# The softmax is computed in two passes over the keys: pass 1 accumulates the
# running row maximum and denominator, pass 2 recomputes the logits and adds
# the normalized probabilities to the output.
#
# NOTE(review): ``V``, the ``STRIDE_V_*`` arguments and ``WARPSPEC`` are
# accepted but never read inside this kernel body.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s
        )
        for BM in [64]
        for BK in [64]
        for w in [4]
        for s in [2]
    ],
    key=[
        "QUERY_GROUP_SIZE",
        "D",
        "CHUNK_SIZE",
    ],
    cache_results=True,
)
@triton.jit
def _non_causal_attn_kernel(
    Q,
    K,
    V,
    accum_scores,
    cu_seqlens_qk,
    #
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_Q_D,
    STRIDE_K_G,
    STRIDE_K_N,
    STRIDE_K_D,
    STRIDE_V_G,
    STRIDE_V_N,
    STRIDE_V_D,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    CHUNK_SIZE: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    D: tl.constexpr,
    WARPSPEC: tl.constexpr,
):
    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    # Value written for padding query rows in pass 2 (see below).
    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE
    pid_g = tl.program_id(0)  # KV head in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # chunk id within batch

    # Token range [off_b, off_b1) of this sequence in the packed layout.
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)

    chunk_start = off_b + pid_m * CHUNK_SIZE
    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
    M = chunk_end - chunk_start
    # The grid is sized for the longest sequence, so shorter sequences get
    # programs whose chunk lies past their end; those exit immediately.
    if M <= 0:
        return

    offs_d = tl.arange(0, D)
    offs_k = tl.arange(0, BLOCK_K)

    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
    row_m = offs_q % BLOCK_M  # token offset in this tile
    row_h = offs_q // BLOCK_M  # query-group index

    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
    # Large finite negative used instead of -inf for masked logits.
    NEG_INF = -1.0e9

    # Iterate over query tiles within this chunk
    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
        # Global query indices for rows in this tile
        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
        q_mask = q_idx < chunk_end  # mask for valid rows in this tile

        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
        q_ptrs = (
            Q
            + pid_g * STRIDE_Q_G
            + q_idx[:, None] * STRIDE_Q_N
            + row_h[:, None] * STRIDE_Q_H
            + offs_d[None, :] * STRIDE_Q_D
        )
        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)

        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)

        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k  # [BLOCK_K]
            k_mask = k_idx < chunk_end  # which keys are valid in this tile

            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]

            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            cur_max = tl.max(qk, 1)
            new_max = tl.maximum(row_max, cur_max)

            # rescale previous sum to new_max (base 2)
            rescale = tl.math.exp2(row_max - new_max)
            p = tl.math.exp2(qk - new_max[:, None])

            row_sum = row_sum * rescale + tl.sum(p, 1)
            row_max = new_max

        # Avoid division by zero for inactive rows
        denom = tl.where(q_mask, row_sum, 1.0)

        # ---- Pass 2: recompute logits, normalize, accumulate per-key mass ----
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k
            k_mask = k_idx < chunk_end

            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)

            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
            # Padding query rows are NOT zeroed: they are set to the constant
            # 1 / CHUNK_SIZE for every key column. Masked key columns are
            # dropped later by the masked store.
            p = tl.where(
                q_mask[:, None], p, INVERSE_CHUNK
            )  # preserve attention mass in shorter chunks

            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups

            # Read-modify-write of this program's own (key, KV head) cells.
            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
            new = old + contrib.to(old.dtype)
            tl.store(out_ptrs, new, mask=k_mask)
+
+
def non_causal_attn_scores(
    q: torch.Tensor,  # [N, HQ, D]
    k: torch.Tensor,  # [N, HKV, D]
    v: torch.Tensor,  # [N, HKV, D]
    cu_seqlens_qk: torch.Tensor,  # [B + 1]
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: float = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: torch.Tensor = None,  # [N, HKV] (float32)
    accum_blending: float = None,
) -> torch.Tensor:
    """
    Score every key by the chunked, non-causal attention mass it receives.

    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
    :param v: Tensor of shape ``[N, HKV, D]`` containing values (only its
        strides are forwarded to the kernel)
    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
    :param max_seqlen_qk: int containing the maximum sequence length
    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
    :param normalize: bool specifying whether to z-score normalize final attention scores
    :param context_lens: List[int] specifying the context lengths. CPU version of
        ``cu_seqlens_qk.diff()``; derived from ``cu_seqlens_qk`` when omitted
    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
        start of each sequence (treated as 0 for every sequence when omitted)
    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
        end of each sequence (treated as 0 for every sequence when omitted)
    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that should be accumulated into
    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding the new
        non-causal attention scores. Final output is equivalent to
        ``out + accum_blending * accum_scores``
    :returns: float32 tensor of shape ``[N, HKV]``; protected positions hold ``inf``
    """
    assert q.ndim == 3 and k.ndim == 3
    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
    N, HQ, D = q.shape
    HKV = k.shape[1]
    assert HQ % HKV == 0, "Number of KV heads must divide number of query heads"
    assert (D & (D - 1)) == 0, "D must be a power of two"

    B = cu_seqlens_qk.numel() - 1
    H_g = HQ // HKV  # query-group size per KV head

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)
    # Kernel layout: KV head first, then token, then (query-in-group,) feature.
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    k = k.view(N, HKV, D).permute(1, 0, 2)
    # v = v.view(N, HKV, D).permute(1, 0, 2)

    if cu_seqlens_qk.device != q.device:
        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)

    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()

    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"

    def grid(_):
        return (
            HKV,
            B,
            triton.cdiv(max_seqlen_qk, chunk_size),
        )

    _non_causal_attn_kernel[grid](
        q,
        k,
        v,
        out,
        cu_seqlens_qk,
        STRIDE_Q_G,
        STRIDE_Q_N,
        STRIDE_Q_H,
        STRIDE_Q_D,
        STRIDE_K_G,
        STRIDE_K_N,
        STRIDE_K_D,
        STRIDE_V_G,
        STRIDE_V_N,
        STRIDE_V_D,
        STRIDE_OUT_N,
        STRIDE_OUT_H,
        sm_scale,
        CHUNK_SIZE=chunk_size,
        QUERY_GROUP_SIZE=H_g,
        D=D,
    )
    if normalize:
        grid = (B,)
        _zscore_per_batch_epilogue_no_window[grid](
            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
        )
    if accum_scores is not None:
        if accum_blending is not None:
            out += accum_scores * accum_blending
        else:
            out += accum_scores
    if protected_first_tokens is not None or protected_last_tokens is not None:
        if context_lens is None:
            context_lens = (cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).tolist()
        # Either list may be given on its own; the missing side protects nothing.
        firsts = (
            protected_first_tokens
            if protected_first_tokens is not None
            else [0] * len(context_lens)
        )
        lasts = (
            protected_last_tokens
            if protected_last_tokens is not None
            else [0] * len(context_lens)
        )
        start = 0
        for first, last, L in zip(firsts, lasts, context_lens):
            # Clamp to the sequence length: an unclamped ``last > L`` makes the
            # slice start negative or land inside the previous sequence, and
            # an unclamped ``first > L`` spills into the next sequence.
            first = max(0, min(int(first), L))
            last = max(0, min(int(last), L))
            out[start : start + first].fill_(torch.inf)
            out[start + L - last : start + L].fill_(torch.inf)
            start += L
    return out
diff --git a/vllm/kvprune_legacy_save/compression/compression_config.py b/vllm/kvprune_legacy_save/compression/compression_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e861e663644b0ff6e9d0d2641e6940e6514bf6f3
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/compression_config.py
@@ -0,0 +1,45 @@
+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
+
+logger = logging.getLogger(__name__)
+
+
class CompressionMethod(Enum):
    """KV-cache compression method selected by ``BatchCompressionParams``.

    Values come from ``auto()``, so they follow declaration order; reordering
    the members changes their numeric values.
    """

    CRITICALADAKV = auto()
    COMPACTOR = auto()
    SNAPKV = auto()
    NONE = auto()
+
+
+# class CachingPolicy(Enum):
+# CACHE_PROMPT = auto()
+# DONT_CACHE = auto()
+
+
+# class CompressionType(Enum):
+# QUERY_AWARE = auto()
+# QUERY_AGNOSTIC = auto()
+
+
@dataclass
class SequenceCompressionParams:
    """Compression settings; presumably applied per sequence (going by the name) — confirm."""

    # NOTE(review): whether this is the fraction kept or the fraction removed
    # is not visible in this file; the default 1.0 presumably means
    # "no compression" — confirm against the caller.
    compression_ratio: float = 1.0
    # Token counts at the start / end of a sequence; presumably these tokens
    # are exempt from pruning — confirm against the scoring code.
    protected_first_tokens: int = 16
    protected_last_tokens: int = 64
+
+
@dataclass
class BatchCompressionParams:
    """Batch-level compression settings.

    ``do_chunked_compression`` is forced to ``False`` when the method is
    ``CompressionMethod.SNAPKV``. A warning is logged only when that actually
    overrides a ``True`` value, so callers that already disabled chunking do
    not get a spurious message.
    """

    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
    compression_method: CompressionMethod = CompressionMethod.COMPACTOR

    do_chunked_compression: bool = True
    chunk_size: int = 512

    def __post_init__(self):
        if (
            self.compression_method == CompressionMethod.SNAPKV
            and self.do_chunked_compression
        ):
            self.do_chunked_compression = False
            logger.warning(
                "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
            )
diff --git a/vllm/kvprune_legacy_save/compression/criticalkv-cursor.py b/vllm/kvprune_legacy_save/compression/criticalkv-cursor.py
new file mode 100644
index 0000000000000000000000000000000000000000..20aaec214a77030ced076bb4bb40c7c2c03c1210
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/criticalkv-cursor.py
@@ -0,0 +1,459 @@
+"""
+CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
+用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
+
+预算与 compactor_vllm 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
+(token, head) 对数量)及首/尾保护段长度。
+
+注意:不得在 import 时加载 ``compactor_vllm.utils.context``(其会再 import ``CompressionMethod``,
+与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.compression.compactor import (
+ CompactorCompression,
+ non_causal_attn_scores,
+)
+from compactor_vllm.compression.snapkv import SnapKVCompression
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+
+
# ============================================================================
# Triton Kernel 1: compute ||Wo @ V||_1 (L1 norm)
# ============================================================================
# Launch grid: (batch, KV head, key tile). Each program handles BLOCK_K
# consecutive tokens of one sequence for one KV head. For every query head of
# that KV head's group it multiplies the value rows by that head's Wo slice,
# sums absolute values over the hidden dimension, and finally divides the
# accumulated sum by QUERY_GROUP_SIZE before storing to OUT[token, kv_head].
#
# NOTE(review): ``Hq`` is accepted but not read in the body, and the V tile is
# reloaded on every iteration of the query-group loop although its pointers
# do not depend on ``g``.
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
        for bk in [32, 64, 128]
        for bd in [32, 64]
        for nw in [4, 8]
        for ns in [3, 4]
    ],
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)

    # Token range [k_beg, k_end) of sequence ``b`` in the packed layout.
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)

    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
    nk = k_beg + nk_off
    # Tiles past the end of a short sequence are fully masked out.
    k_mask = nk < k_end

    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)

    for g in range(QUERY_GROUP_SIZE):
        # Query head ``g`` of the group that shares KV head ``hk``.
        hq = hk * QUERY_GROUP_SIZE + g

        v_ptrs = (
            V
            + nk[:, None] * STRIDE_V_NK
            + hk * STRIDE_V_HK
            + tl.arange(0, D)[None, :] * STRIDE_V_D
        )
        v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)

        # Walk the hidden dimension in BLOCK_D-wide tiles.
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            hid_mask = hid_idx < HIDDEN

            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + tl.arange(0, D)[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)

            # [BLOCK_K, D] @ [D, BLOCK_D]; masked hidden columns contribute 0.
            wov_tile = tl.dot(v_blk, wo_tile)
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)

    # Mean over the query heads of the group.
    l1_sum = l1_sum / QUERY_GROUP_SIZE
    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
# ============================================================================
# Triton Kernel 2: Stage-1 protection + Stage-2 weighted fusion
# ============================================================================
# Launch grid: (batch, KV head). Each program walks the tokens of one
# sequence in BLOCK_K-sized tiles and writes
#     OUT = (BASE_SCORES + EPSILON) * WO_V_NORM
# except where STAGE1_MASK == 1, which is written as +inf.
#
# NOTE: ``EPSILON`` sits between ``OUT`` and the stride arguments in the
# positional order, so a caller passing strides positionally must supply
# ``EPSILON`` positionally as well.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
    key=["Hk"],
    cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
    BASE_SCORES,
    WO_V_NORM,
    STAGE1_MASK,
    cu_k,
    OUT,
    EPSILON: tl.constexpr,
    STRIDE_BS_NK,
    STRIDE_BS_HK,
    STRIDE_WN_NK,
    STRIDE_WN_HK,
    STRIDE_S1_NK,
    STRIDE_S1_HK,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    b = tl.program_id(0)
    hk = tl.program_id(1)

    # Token range [k_beg, k_end) of sequence ``b`` in the packed layout.
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)

    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end

        bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
        wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
        s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK

        base = tl.load(bs_ptrs, mask=kmask, other=0.0)
        wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
        stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)

        # Re-weight the base score by the Wo@V norm; stage-1 picks become +inf.
        fused = (base + EPSILON) * wnorm
        fused = tl.where(stage1_protect == 1, float("inf"), fused)

        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, fused, mask=kmask)
+
+
def critical_ada_key_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    wo_weight: torch.Tensor,
    cu_seqlens: torch.Tensor,
    base_scores: torch.Tensor,
    compression_ctx: Any,
    *,
    store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """
    Re-weight ``base_scores`` with the CriticalAda two-stage scheme.

    Uses the engine's retention budget ``batch_tokens_to_retain`` (number of
    (token, head) pairs per sequence) and, per sequence, stays as close as
    possible to the kvpress CriticalAdaKV semantics:
      1) alpha_safeguard budget (every head keeps at least a share);
      2) head-wise adaptive budget allocation from ``base_scores``
         (``head_budgets``);
      3) Stage-1 protects ``head_budgets * first_stage_ratio`` tokens per head;
      4) Stage-2 computes ``(base + eps) * ||Wo@V||_1`` and then protects the
         per-head top-k according to ``head_budgets``.

    Args:
        q, k, v: packed tensors ``[N, H, D]`` whose last dim is contiguous.
        wo_weight: output projection, either 2-D ``[hidden, Hq * D]`` or
            already shaped ``[Hq, D, hidden]``.
        cu_seqlens: ``[B + 1]`` sequence boundaries; must be the same tensor
            used to produce ``base_scores``.
        base_scores: ``[N, Hk]`` base importance scores.
        compression_ctx: any object with the ``CompressionContext`` fields
            (duck typing). ``batch_tokens_to_retain``,
            ``protected_first_tokens``, ``protected_last_tokens``,
            ``critical_ada_epsilon`` and ``critical_ada_first_stage_ratio``
            are read directly; ``critical_ada_alpha_safeguard`` is optional
            (default 0.2).
        store_stream: optional stream on which ``final_scores`` is recorded.

    Returns:
        ``(final_scores, masked_key_indices)`` where ``masked_key_indices`` is
        ``None`` when nothing is pruned, otherwise a
        ``(batch_idx, head_idx, seq_idx)`` tuple of index tensors.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    device = q.device
    _, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    assert D == Dk and Hq % Hk == 0

    # Use the same cu as non_causal_attn_scores (context.cu_seqlens_q during
    # prefill) so base_scores rows line up with the Triton segmentation; do
    # not mix with cu_seqlens_k.
    B = cu_seqlens.numel() - 1
    G = Hq // Hk

    # One host transfer instead of a device sync per ``.item()`` in the loops.
    bounds = [int(x) for x in cu_seqlens.reshape(-1).tolist()]
    seq_lens = [bounds[b + 1] - bounds[b] for b in range(B)]
    max_k_len = max(seq_lens, default=0)

    btr = compression_ctx.batch_tokens_to_retain
    assert btr is not None and btr.numel() == B
    keep_pairs_by_seq = [
        int(x) for x in btr.to(dtype=torch.int32).reshape(-1).tolist()
    ]

    prot_first = compression_ctx.protected_first_tokens or [0] * B
    prot_last = compression_ctx.protected_last_tokens or [0] * B
    epsilon = compression_ctx.critical_ada_epsilon
    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
    alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2))
    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))

    def compressible_range(b: int) -> Tuple[int, int]:
        """Return ``[lo, hi)`` of sequence ``b`` without its protected ends."""
        s = int(prot_first[b]) if b < len(prot_first) else 0
        e = int(prot_last[b]) if b < len(prot_last) else 0
        return bounds[b] + s, bounds[b + 1] - e

    if wo_weight.dim() == 2:
        hidden_size, _ = wo_weight.shape
        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
    else:
        wo = wo_weight.contiguous()
        hidden_size = wo.size(-1)

    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)

    # Skip the launch for an empty batch / all-empty sequences.
    if max_k_len > 0:

        def grid_wo(META):
            return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))

        _compute_wo_v_l1_kernel[grid_wo](
            v,
            wo,
            cu_seqlens,
            wo_v_norm,
            *v.stride(),
            *wo.stride(),
            *wo_v_norm.stride(),
            Hk=Hk,
            Hq=Hq,
            D=D,
            HIDDEN=hidden_size,
            QUERY_GROUP_SIZE=G,
        )

    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
    # kvpress-style per-head budgets (adaptive per sequence) used by both
    # Stage-1 and Stage-2; ``None`` marks a sequence with nothing to compress.
    head_budgets_by_batch = []

    for b in range(B):
        if seq_lens[b] == 0:
            head_budgets_by_batch.append(None)
            continue
        lo, hi = compressible_range(b)
        compressible = max(0, hi - lo)
        if compressible <= 0:
            head_budgets_by_batch.append(None)
            continue
        # Per-head token budget (kvpress ``n_kept``).
        n_kept_tokens = max(1, keep_pairs_by_seq[b] // Hk)
        n_kept_tokens = min(n_kept_tokens, compressible)
        # Safeguard budget: every head keeps at least ``n_safe`` tokens.
        n_safe = int(n_kept_tokens * alpha_safeguard)
        if n_safe > 0:
            tk_safe = min(n_safe, compressible)
            for hk in range(Hk):
                safe_idx = torch.topk(base_scores[lo:hi, hk], tk_safe, sorted=False).indices
                stage1_mask[lo + safe_idx, hk] = 1

        # Adaptive allocation: take the top ``n_kept_tokens * Hk`` entries of
        # the flattened (token, head) scores and count them per head.
        budget_scores = base_scores[lo:hi, :].clone()
        if n_safe > 0:
            budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf")
        top_pairs = min(n_kept_tokens * Hk, budget_scores.numel())
        if top_pairs <= 0:
            head_budgets_by_batch.append(None)
            continue
        top_idx_flat = torch.topk(
            budget_scores.reshape(-1), top_pairs, sorted=False
        ).indices
        # Row-major flattening of [tokens, Hk]: the head is ``index % Hk``.
        head_budgets = torch.bincount(top_idx_flat % Hk, minlength=Hk).tolist()
        head_budgets_by_batch.append(head_budgets)

        # Stage-1: protect ``first_stage_ratio`` of each head's budget.
        for hk, budget in enumerate(head_budgets):
            phase1_budget = int(budget * first_stage_ratio)
            if phase1_budget <= 0:
                continue
            tk = min(phase1_budget, compressible)
            top_idx = torch.topk(base_scores[lo:hi, hk], tk, sorted=False).indices
            stage1_mask[lo + top_idx, hk] = 1

    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)

    def grid_fuse(_META):
        return (B, Hk)

    if B > 0:
        # EPSILON is the sixth positional parameter of the kernel, directly
        # before the strides. It must be passed positionally: combining
        # ``EPSILON=...`` with star-unpacked strides binds the first stride
        # to EPSILON and raises "multiple values for argument".
        _critical_ada_fuse_kernel[grid_fuse](
            base_scores,
            wo_v_norm,
            stage1_mask,
            cu_seqlens,
            final_scores,
            epsilon,
            *base_scores.stride(),
            *wo_v_norm.stride(),
            *stage1_mask.stride(),
            *final_scores.stride(),
            Hk=Hk,
        )

    # Stage-2 (kvpress semantics): after fusion, protect the per-head top-k
    # once more according to each head's budget.
    for b in range(B):
        hb = head_budgets_by_batch[b]
        if hb is None:
            continue
        lo, hi = compressible_range(b)
        if hi <= lo:
            continue
        region_len = hi - lo
        for hk, budget in enumerate(hb):
            if budget <= 0:
                continue
            tk = min(budget, region_len)
            idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices
            final_scores[lo + idx, hk] = float("inf")

    # Pick the lowest-scoring (token, head) pairs of every sequence to prune.
    batch_parts, head_parts, seq_parts = [], [], []
    for b in range(B):
        k_len = seq_lens[b]
        if k_len == 0:
            continue
        keep_pairs = keep_pairs_by_seq[b]
        total_pairs = k_len * Hk
        if keep_pairs >= total_pairs:
            continue
        k_beg, k_end = bounds[b], bounds[b + 1]
        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
        if n_prune_pairs <= 0:
            continue

        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
        prune_idx = torch.topk(
            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
        ).indices
        batch_parts.append(torch.full_like(prune_idx, b, dtype=torch.int64))
        head_parts.append(prune_idx % Hk)
        seq_parts.append(prune_idx // Hk + k_beg)

    masked_key_indices = None
    if batch_parts:
        masked_key_indices = (
            torch.cat(batch_parts),
            torch.cat(head_parts),
            torch.cat(seq_parts),
        )

    if store_stream is not None:
        final_scores.record_stream(store_stream)

    return final_scores, masked_key_indices
+
+
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    Takes CompactorCompression as the base score (pre-RoPE leverage plus
    post-RoPE non-causal fusion) and then applies the CriticalAda two-stage
    weighting. Attention must inject ``compression_context.wo_weight`` before
    post-RoPE scoring.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Delegate pre-RoPE scoring to the configured base scorer."""
        cc = context.compression_context
        if cc is None:
            base = "compactor"
        else:
            base = getattr(cc, "critical_ada_base_scorer", "compactor")
        scorer = (
            SnapKVCompression if str(base).lower() == "snapkv" else CompactorCompression
        )
        return scorer.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base scores, then re-weight them when ``wo_weight`` is set."""
        compression_context = context.compression_context
        assert compression_context is not None
        scorer_name = getattr(compression_context, "critical_ada_base_scorer", "compactor")
        use_snapkv = str(scorer_name).lower() == "snapkv"

        if use_snapkv:
            base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
        else:
            # Kept verbatim identical to CompactorCompression.post_rope_scoring
            # in compactor.py:
            # maybe_execute_in_stream(non_causal_attn_scores, q,k,v, cu_seqlens_q, max_seqlen_q, ...)
            # Do not swap in another wrapper, otherwise the scores diverge
            # from those produced when COMPACTOR is used on its own.
            store_stream = context.STORE_STREAM
            if store_stream is not None:
                torch.cuda.current_stream().wait_stream(store_stream)

            compactor_kwargs = dict(
                chunk_size=CompactorCompression.chunk_size,
                sm_scale=1.0,
                normalize=True,
                accum_scores=pre_rope_scores,
                context_lens=compression_context.context_lens,
                protected_first_tokens=compression_context.protected_first_tokens,
                protected_last_tokens=compression_context.protected_last_tokens,
                accum_blending=0.5,
            )
            base_scores = maybe_execute_in_stream(
                non_causal_attn_scores,
                q,
                k,
                v,
                context.cu_seqlens_q,
                context.max_seqlen_q,
                **compactor_kwargs,
            )

        wo_weight = compression_context.wo_weight
        if wo_weight is None:
            # Without an output projection only the base scores are available.
            return base_scores

        result = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            compression_context,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        scores, _masked = result
        return scores

    @staticmethod
    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
        """Optional: precompute and cache Wo; real inference relies on the ``cc.wo_weight`` injected in Attention.forward."""
        ready = (
            hasattr(module, "o_proj")
            and module.o_proj.weight is not None
            and hasattr(module, "num_heads")
            and hasattr(module, "head_dim")
        )
        if not ready:
            return
        weight = module.o_proj.weight.data
        hidden_size, _ = weight.shape
        # [hidden, Hq * D] -> [Hq, D, hidden], stored in float32 on ``device``.
        per_head = weight.transpose(0, 1).view(
            module.num_heads, module.head_dim, hidden_size
        )
        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
+
diff --git a/vllm/kvprune_legacy_save/compression/criticalkv.py b/vllm/kvprune_legacy_save/compression/criticalkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..316b1c5d3ed8eea7083b13e9f7f4dcbd8bc1a0c4
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/criticalkv.py
@@ -0,0 +1,471 @@
+"""
+CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
+用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
+
+预算与 vllm.kvprune 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
+(token, head) 对数量)。CriticalAda 主链在 **PyTorch** 中与 kvpress ``CriticalAdaKVPress.compress``
+对齐;``||Wo@V||_1`` 仍默认用 Triton ``_compute_wo_v_l1_kernel``(与 ``CriticalKVPress.vwl1norm`` 同式)。
+将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 可改走 ``_vwl1_norm_kvpress_reference``。
+
+注意:不得在 import 时加载 ``vllm.kvprune.utils.context``(其会再 import ``CompressionMethod``,
+与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.compression.compactor import (
+ CompactorCompression,
+ kvpress_compactor_post_rope,
+ resolve_kvpress_compactor_blending,
+)
+from vllm.kvprune.compression.snapkv import SnapKVCompression
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+
+def _criticalkv_prune_hip_pipeline(configs, _, **kwargs):
+ """HIP: TritonHCUGPUStreamPipelineV2 breaks on nested loops + hid_idx arange (see snapkv)."""
+ if torch.version.hip is None:
+ return list(configs)
+ return [c for c in configs if getattr(c, "num_stages", 1) == 1]
+
+
def _compute_wo_v_l1_autotune_configs():
    """Build the autotune search space for ``_compute_wo_v_l1_kernel``.

    CUDA: full autotune. HIP: single num_stages=1 config (avoids pipeliner +
    long autotune).
    """
    if torch.version.hip is None:
        search_space = []
        # Same enumeration order as a nested comprehension: BLOCK_K outermost,
        # num_stages innermost.
        for block_k in [32, 64, 128]:
            for block_d in [32, 64]:
                for warps in [4, 8]:
                    for stages in [3, 4]:
                        search_space.append(
                            triton.Config(
                                {"BLOCK_K": block_k, "BLOCK_D": block_d},
                                num_warps=warps,
                                num_stages=stages,
                            )
                        )
        return search_space

    hip_only = triton.Config(
        {"BLOCK_K": 64, "BLOCK_D": 64}, num_warps=4, num_stages=1
    )
    return [hip_only]
+
+
# Backend for the L1 norm of Wo@V: False = Triton (default),
# True = PyTorch reference (debugging / alignment checks).
_USE_WO_L1_REFERENCE_BACKEND = False
+
+
+def _vwl1_norm_kvpress_reference(
+ values_seg: torch.Tensor,
+ wo: torch.Tensor,
+ num_kv_heads: int,
+ num_query_groups: int,
+) -> torch.Tensor:
+ """
+ 与 kvpress ``CriticalKVPress.vwl1norm`` 等价的 **可选参考实现**(PyTorch,仅用于核对;
+ 将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 时选用,默认走 Triton)。
+
+ 算法:repeat_kv → 逐 query 头 ``|V @ Wo_h|_1`` → 在 GQA 组上 mean,与 Triton 路径同一公式。
+ """
+ k_len, Hk, D = values_seg.shape
+ Hq, D_wo, hidden = wo.shape
+ assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
+ # [1, Hk, k_len, D] 与 HF repeat_kv 约定一致
+ v_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
+ v_rep = repeat_kv(v_4d, num_query_groups) # [1, Hq, k_len, D]
+ # Wo 在 attention 里注入为 float32,V 常为 bf16/fp16,matmul 前对齐 dtype
+ wo_f = wo
+ head_list = []
+ for head in range(Hq):
+ v_h = v_rep[0, head, :, :].to(dtype=wo_f.dtype)
+ head_wov = v_h.matmul(wo_f[head, :, :])
+ head_wov_norm = torch.norm(head_wov, p=1, dim=-1)
+ head_list.append(head_wov_norm)
+ stacked = torch.stack(head_list, dim=0) # [Hq, k_len]
+ stacked = stacked.view(Hk, num_query_groups, k_len).mean(dim=1)
+ return stacked.transpose(0, 1).contiguous()
+
+
# ============================================================================
# Triton: ||Wo @ V||_1 following the kvpress definition (L1 per query head,
# then mean over the GQA query group)
# ============================================================================
# Launch grid: (batch, KV head, key tile); each program handles BLOCK_K
# consecutive tokens of one sequence for one KV head.
# NOTE(review): ``Hq`` is accepted but not read in the body, and the V tile is
# reloaded on every iteration of the query-group loop although its pointers
# do not depend on ``g``.
@triton_autotune(
    configs=_compute_wo_v_l1_autotune_configs(),
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
    prune_configs_by={"early_config_prune": _criticalkv_prune_hip_pipeline},
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """For each KV head: compute ``sum(|V @ Wo|)`` separately for each of the G query heads, then divide by G (matching the kvpress mean)."""
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)

    # Token range [k_beg, k_end) of sequence ``b`` in the packed layout.
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)

    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
    nk = k_beg + nk_off
    # Tiles past the end of a short sequence are fully masked out.
    k_mask = nk < k_end

    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)

    for g in range(QUERY_GROUP_SIZE):
        # Query head ``g`` of the group that shares KV head ``hk``.
        hq = hk * QUERY_GROUP_SIZE + g

        v_ptrs = (
            V
            + nk[:, None] * STRIDE_V_NK
            + hk * STRIDE_V_HK
            + tl.arange(0, D)[None, :] * STRIDE_V_D
        )
        v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)

        # Walk the hidden dimension in BLOCK_D-wide tiles.
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            hid_mask = hid_idx < HIDDEN

            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + tl.arange(0, D)[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)

            # [BLOCK_K, D] @ [D, BLOCK_D]; masked hidden columns contribute 0.
            wov_tile = tl.dot(v_blk, wo_tile)
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)

    # Mean over the query heads of the group.
    l1_sum = l1_sum / QUERY_GROUP_SIZE
    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+def critical_ada_key_scores(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ wo_weight: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ base_scores: torch.Tensor,
+ compression_ctx: Any,
+ *,
+ store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+ """
+ 使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
+ 按 kvpress ``CriticalAdaKVPress.compress`` 的顺序实现:safeguard scatter →
+ head-major 展平做 head_budgets → Stage1 在 **已抬高** 的分数上 top-k →
+ ``(scores + ε) * ||WoV||₁`` → Stage2 scatter → 最终按 head-major 展平做 bottom-k。
+
+ ``||Wo@V||₁`` 仍用 Triton(``_compute_wo_v_l1_kernel``);中间 CriticalAda 步骤用 PyTorch
+ 与 kvpress 逐句对齐。仅 base 分数来自 Compactor/SnapKV。
+
+ Args:
+ compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
+ ``batch_tokens_to_retain``;可选 ``critical_ada_epsilon``、
+ ``critical_ada_first_stage_ratio``、``critical_ada_alpha_safeguard``。
+ """
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+ device = q.device
+ _, Hq, D = q.shape
+ N_k, Hk, Dk = k.shape
+ assert D == Dk and Hq % Hk == 0
+
+ # 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
+ # 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
+ B = cu_seqlens.numel() - 1
+ G = Hq // Hk
+ k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+ btr = compression_ctx.batch_tokens_to_retain
+ assert btr is not None and btr.numel() == B
+ btr = btr.to(device=device, dtype=torch.int32)
+
+ epsilon = compression_ctx.critical_ada_epsilon
+ first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+ alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
+ alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
+
+ if wo_weight.dim() == 2:
+ hidden_size, _ = wo_weight.shape
+ wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+ else:
+ wo = wo_weight.contiguous()
+ hidden_size = wo.size(-1)
+
+ wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+ if B > 0 and int(k_lengths.max().item()) > 0:
+ if _USE_WO_L1_REFERENCE_BACKEND:
+ for b in range(B):
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ if k_end <= k_beg:
+ continue
+ v_seg = v[k_beg:k_end, :, :].contiguous()
+ wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
+ v_seg, wo, Hk, G
+ )
+ else:
+
+ def grid_wo(META):
+ max_k_len = int(k_lengths.max().item())
+ return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+ _compute_wo_v_l1_kernel[grid_wo](
+ v,
+ wo,
+ cu_seqlens,
+ wo_v_norm,
+ *v.stride(),
+ *wo.stride(),
+ *wo_v_norm.stride(),
+ Hk=Hk,
+ Hq=Hq,
+ D=D,
+ HIDDEN=hidden_size,
+ QUERY_GROUP_SIZE=G,
+ )
+
+ # kvpress 用 finfo.max 抬高分数;与 inf 混用时 topk 行为一致
+ _score_max = float(torch.finfo(torch.float32).max)
+
+ final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+ head_budgets_by_batch: list[Optional[torch.Tensor]] = []
+
+ for b in range(B):
+ k_len = int(k_lengths[b].item())
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ if k_len == 0:
+ head_budgets_by_batch.append(None)
+ continue
+
+ scores_seg = base_scores[k_beg:k_end, :].float()
+ keep_pairs = int(btr[b].item())
+ n_kept_tokens = max(1, keep_pairs // Hk)
+ n_kept_tokens = min(n_kept_tokens, k_len)
+
+ # scores_work: 布局 [k_len, Hk],对应 kvpress [bsz=1, H, k_len] 的 transpose(0,2) 视角下沿 token 维的 topk
+ scores_work = scores_seg.clone()
+
+ # --- Alpha safeguard(kvpress L148–152)---
+ n_safe = int(n_kept_tokens * alpha_safeguard)
+ nk = min(n_safe, k_len) if n_safe > 0 else 0
+ if nk > 0:
+ for hk in range(Hk):
+ top_idx = torch.topk(scores_work[:, hk], nk, dim=0, largest=True).indices
+ scores_work[top_idx, hk] = _score_max
+
+ # --- Head budgets:kvpress L158–164,展平顺序与 [bsz, H, k_len] 一致(head-major:h*K + t)---
+ top_pairs = min(n_kept_tokens * Hk, k_len * Hk)
+ if top_pairs <= 0:
+ head_budgets_by_batch.append(None)
+ wn = wo_v_norm[k_beg:k_end, :]
+ final_scores[k_beg:k_end, :] = (scores_seg + epsilon) * wn
+ continue
+
+ budget_flat = scores_work.permute(1, 0).contiguous().reshape(-1)
+ top_idx_flat = torch.topk(
+ budget_flat, top_pairs, largest=True, sorted=False
+ ).indices
+ top_head_idx = top_idx_flat // k_len
+ head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int64)
+ head_budgets_by_batch.append(head_budgets)
+
+ # --- Stage 1(kvpress L166–171):在已 safeguard 的 scores_work 上沿 token 维 top-k ---
+ head_selection_budget_1st = (
+ (head_budgets.to(torch.float32) * float(first_stage_ratio))
+ .to(torch.int64)
+ .tolist()
+ )
+ M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
+ mk = min(M1, k_len) if M1 > 0 else 0
+ if mk > 0:
+ top_k_index = torch.topk(scores_work, mk, dim=0, largest=True, sorted=True).indices
+ for hk in range(Hk):
+ phase1_budget = int(head_selection_budget_1st[hk])
+ if phase1_budget <= 0:
+ continue
+ take = min(phase1_budget, mk)
+ scores_work[top_k_index[:take, hk], hk] = _score_max
+
+ # --- Stage 2 重加权(kvpress L173–175)---
+ wn = wo_v_norm[k_beg:k_end, :]
+ scores_fused = (scores_work + epsilon) * wn
+
+ # --- Stage 2 scatter(kvpress L176–179)---
+ M2 = int(head_budgets.max().item())
+ mk2 = min(M2, k_len) if M2 > 0 else 0
+ if mk2 > 0:
+ top_k_index2 = torch.topk(
+ scores_fused, mk2, dim=0, largest=True, sorted=True
+ ).indices
+ for hk in range(Hk):
+ budget = int(head_budgets[hk].item())
+ if budget <= 0:
+ continue
+ take = min(budget, mk2)
+ scores_fused[top_k_index2[:take, hk], hk] = _score_max
+
+ final_scores[k_beg:k_end, :] = scores_fused
+
+ masked_key_indices = None
+ for b in range(B):
+ k_len = int(k_lengths[b].item())
+ if k_len == 0:
+ continue
+ keep_pairs = int(btr[b].item())
+ total_pairs = k_len * Hk
+ if keep_pairs >= total_pairs:
+ continue
+ k_beg = int(cu_seqlens[b].item())
+ k_end = int(cu_seqlens[b + 1].item())
+ n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+ if n_prune_pairs <= 0:
+ continue
+
+ # kvpress L187:``scores.reshape(bsz, -1)`` 即 [H, K] 按 head-major 展平(flat = h*K + t)
+ flat_scores = (
+ final_scores[k_beg:k_end, :].permute(1, 0).contiguous().reshape(-1)
+ )
+ prune_idx = torch.topk(
+ -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+ ).indices
+ batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
+ head_idx = prune_idx // k_len
+ seq_idx = prune_idx % k_len + k_beg
+ if masked_key_indices is None:
+ masked_key_indices = (batch_idx, head_idx, seq_idx)
+ else:
+ masked_key_indices = (
+ torch.cat([masked_key_indices[0], batch_idx]),
+ torch.cat([masked_key_indices[1], head_idx]),
+ torch.cat([masked_key_indices[2], seq_idx]),
+ )
+
+ if store_stream is not None:
+ final_scores.record_stream(store_stream)
+
+ return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """
+    Only when ``critical_ada_base_scorer == "compactor"`` does the base score match kvpress ``CompactorPress.score``
+    (``kvpress_compactor_post_rope``: ``blending * l_scores + attn_scores``); any other base (e.g. SnapKV)
+    goes through the corresponding single ScorerPress, then CriticalAda is stacked on top. Attention must inject ``compression_context.wo_weight`` before post-RoPE scoring.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        cc = context.compression_context
+        base = (
+            getattr(cc, "critical_ada_base_scorer", "compactor")
+            if cc is not None
+            else "compactor"
+        )
+        if str(base).lower() == "compactor":
+            return CompactorCompression.pre_rope_scoring(q, k, v, context)
+        return SnapKVCompression.pre_rope_scoring(q, k, v, context)  # every non-"compactor" value uses SnapKV
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        compression_context = context.compression_context
+        assert compression_context is not None
+        base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower()
+
+        if base == "compactor":
+            # Special case: matches ``CompactorPress.score`` / ``CompactorCompression.post_rope_scoring``.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            blending = resolve_kvpress_compactor_blending(compression_context)
+            base_scores = maybe_execute_in_stream(
+                kvpress_compactor_post_rope,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                pre_rope_scores,
+                compression_context,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                blending=float(blending),
+                STORE_STREAM=context.STORE_STREAM,
+            )
+        else:
+            base_scores = SnapKVCompression.post_rope_scoring(
+                q, k, v, pre_rope_scores, context
+            )
+
+        wo_weight = compression_context.wo_weight
+        if wo_weight is None:
+            return base_scores  # no Wo injected: return the base scores without the CriticalAda stages
+
+        scores, _masked = maybe_execute_in_stream(
+            critical_ada_key_scores,
+            q,
+            k,
+            v,
+            wo_weight,
+            context.cu_seqlens_q,
+            base_scores,
+            compression_context,
+            STORE_STREAM=context.STORE_STREAM,
+            store_stream=context.STORE_STREAM,
+        )
+        return scores  # the masked-index tuple (``_masked``) is discarded here
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """Optional: precompute and cache Wo; actual inference relies on the ``cc.wo_weight`` injected in Attention.forward."""
+        if not hasattr(module, "o_proj") or module.o_proj.weight is None:
+            return
+        if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
+            return
+        wo_raw = module.o_proj.weight.data
+        hidden_size, _ = wo_raw.shape
+        Hq = module.num_heads
+        head_dim = module.head_dim
+        wo = (
+            wo_raw.transpose(0, 1)
+            .view(Hq, head_dim, hidden_size)
+            .to(device=device, dtype=torch.float32)  # always float32; the ``dtype`` argument is not used
+        )
+        module._critical_ada_wo_weight = wo
+
+
diff --git a/vllm/kvprune_legacy_save/compression/criticalkv_origin.py b/vllm/kvprune_legacy_save/compression/criticalkv_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5964c95908ddd2529c97e5cb617ec11ebffa878
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/criticalkv_origin.py
@@ -0,0 +1,502 @@
+"""
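+CriticalAdaKV: built on Compactor (pre-RoPE leverage scores + post-RoPE non-causal attention fusion),
+Stage 2 re-weights the scores by the L1 norm of the output projection Wo applied to Value; Stage 1 applies within-budget top-k protection on the Compactor base scores.
+
+The budget matches the compactor_vllm engine: it uses ``compression_context.batch_tokens_to_retain`` (the flattened
+number of (token, head) pairs). Stages 1/2 follow the kvpress paper/implementation; ``||Wo@V||_1`` is **algorithmically**
+the same as ``CriticalKVPress.vwl1norm`` (per-query-head L1 under GQA, then the mean over each group). **Triton is the default**
+(``_compute_wo_v_l1_kernel``); to align line by line with PyTorch, set the module-level ``_USE_WO_L1_REFERENCE_BACKEND`` to ``True`` to use ``_vwl1_norm_kvpress_reference``.
+
+Note: ``compactor_vllm.utils.context`` must not be loaded at import time (it imports ``CompressionMethod`` in turn,
+which forms a cycle with ``compression/__init__.py`` importing this module). At runtime only duck-typed objects with the same fields as ``CompressionContext`` are used.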
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Tuple
+
+import torch
+import triton
+from triton import language as tl
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.compression.compactor import (
+ CompactorCompression,
+ non_causal_attn_scores,
+)
+from compactor_vllm.compression.snapkv import SnapKVCompression
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+# Backend for the L1 norm of Wo@V: False = Triton (default), True = PyTorch reference (debugging / alignment)
+_USE_WO_L1_REFERENCE_BACKEND = False
+
+
+def _vwl1_norm_kvpress_reference(
+    values_seg: torch.Tensor,
+    wo: torch.Tensor,
+    num_kv_heads: int,
+    num_query_groups: int,
+) -> torch.Tensor:
+    """
+    **Optional reference implementation** (PyTorch, for verification only) equivalent to kvpress ``CriticalKVPress.vwl1norm``;
+    selected when ``_USE_WO_L1_REFERENCE_BACKEND`` is ``True`` (the default path is Triton).
+
+    Algorithm: repeat_kv → per-query-head ``|V @ Wo_h|_1`` → mean over each GQA group (same formula as the Triton path). Takes ``values_seg`` [k_len, Hk, D] and ``wo`` [Hq, D, hidden]; returns [k_len, Hk].
+    """
+    k_len, Hk, D = values_seg.shape
+    Hq, D_wo, hidden = wo.shape
+    assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
+    # [1, Hk, k_len, D], following the HF repeat_kv convention
+    v_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
+    v_rep = repeat_kv(v_4d, num_query_groups)  # [1, Hq, k_len, D]
+    # Wo is injected as float32 by attention while V is often bf16/fp16; align dtypes before the matmul
+    wo_f = wo
+    head_list = []
+    for head in range(Hq):
+        v_h = v_rep[0, head, :, :].to(dtype=wo_f.dtype)
+        head_wov = v_h.matmul(wo_f[head, :, :])
+        head_wov_norm = torch.norm(head_wov, p=1, dim=-1)
+        head_list.append(head_wov_norm)
+    stacked = torch.stack(head_list, dim=0)  # [Hq, k_len]
+    stacked = stacked.view(Hk, num_query_groups, k_len).mean(dim=1)
+    return stacked.transpose(0, 1).contiguous()
+
+
+# ============================================================================
+# Triton: ||Wo @ V||₁ as defined by kvpress (per-query-head L1 under GQA, then the mean over the group)
+# ============================================================================
+@triton_autotune(
+    configs=[
+        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
+        for bk in [32, 64, 128]
+        for bd in [32, 64]
+        for nw in [4, 8]
+        for ns in [3, 4]
+    ],
+    key=["Hk", "D", "HIDDEN"],
+    cache_results=True,
+)
+@triton.jit
+def _compute_wo_v_l1_kernel(
+    V,
+    WO,
+    cu_k,
+    OUT,
+    STRIDE_V_NK,
+    STRIDE_V_HK,
+    STRIDE_V_D,
+    STRIDE_WO_HQ,
+    STRIDE_WO_D,
+    STRIDE_WO_HID,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    Hk: tl.constexpr,
+    Hq: tl.constexpr,
+    D: tl.constexpr,
+    HIDDEN: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """For each KV head: compute ``sum(|V @ Wo|)`` for each of the G query heads, then divide by G (matching the kvpress mean)."""
+    b = tl.program_id(0)  # sequence index into cu_k
+    hk = tl.program_id(1)  # KV head index
+    ks = tl.program_id(2)  # BLOCK_K-sized tile index within the sequence
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
+    nk = k_beg + nk_off
+    k_mask = nk < k_end  # rows past the end of this sequence are masked out
+
+    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+    for g in range(QUERY_GROUP_SIZE):
+        hq = hk * QUERY_GROUP_SIZE + g  # query head g of the group that shares KV head hk
+
+        v_ptrs = (
+            V
+            + nk[:, None] * STRIDE_V_NK
+            + hk * STRIDE_V_HK
+            + tl.arange(0, D)[None, :] * STRIDE_V_D
+        )
+        v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)  # NOTE(review): independent of g
+
+        for hid_off in range(0, HIDDEN, BLOCK_D):
+            hid_idx = hid_off + tl.arange(0, BLOCK_D)
+            hid_mask = hid_idx < HIDDEN
+
+            wo_ptrs = (
+                WO
+                + hq * STRIDE_WO_HQ
+                + tl.arange(0, D)[:, None] * STRIDE_WO_D
+                + hid_idx[None, :] * STRIDE_WO_HID
+            )
+            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
+
+            wov_tile = tl.dot(v_blk, wo_tile)
+            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)  # accumulate |V @ Wo| over this hidden tile
+
+    l1_sum = l1_sum / QUERY_GROUP_SIZE
+    tl.store(out_ptrs, l1_sum, mask=k_mask)
+
+
+# ============================================================================
+# Triton: Stage 1 protection + Stage 2 weighted fusion (element-wise)
+# ============================================================================
+@triton_autotune(
+    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
+    key=["Hk"],
+    cache_results=True,
+)
+@triton.jit
+def _critical_ada_fuse_kernel(
+    BASE_SCORES,
+    WO_V_NORM,
+    STAGE1_MASK,
+    cu_k,
+    OUT,
+    STRIDE_BS_NK,
+    STRIDE_BS_HK,
+    STRIDE_WN_NK,
+    STRIDE_WN_HK,
+    STRIDE_S1_NK,
+    STRIDE_S1_HK,
+    STRIDE_OUT_NK,
+    STRIDE_OUT_HK,
+    EPSILON: tl.constexpr,
+    Hk: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    b = tl.program_id(0)  # sequence index into cu_k
+    hk = tl.program_id(1)  # KV head index
+
+    k_beg = tl.load(cu_k + b)
+    k_end = tl.load(cu_k + b + 1)
+
+    for ks in tl.range(k_beg, k_end, BLOCK_K):  # one program walks its whole sequence in BLOCK_K tiles
+        nk = ks + tl.arange(0, BLOCK_K)
+        kmask = nk < k_end
+
+        bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
+        wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
+        s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
+
+        base = tl.load(bs_ptrs, mask=kmask, other=0.0)
+        wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
+        stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
+
+        fused = (base + EPSILON) * wnorm  # Stage 2 re-weighting
+        fused = tl.where(stage1_protect == 1, float("inf"), fused)  # Stage-1 protected entries become +inf
+
+        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+        tl.store(out_ptrs, fused, mask=kmask)
+
+
+def critical_ada_key_scores(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    wo_weight: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    base_scores: torch.Tensor,
+    compression_ctx: Any,
+    *,
+    store_stream: Optional[torch.cuda.Stream] = None,
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
+    """
+    Apply the CriticalAda stages using the engine's retention budget ``batch_tokens_to_retain`` ((token, head) pairs per sequence).
+    Each sequence is processed over its full ``k_len``, aligned with kvpress ``CriticalAdaKVPress.compress`` (same
+    top-k / scatter order as the source); only the base scores come from compactor_vllm's Compactor/SnapKV. Returns ``(final_scores [N_k, Hk], masked_key_indices or None)``.
+
+    Args:
+        compression_ctx: Any object with the same fields as ``CompressionContext`` (duck typing); it must provide
+            ``batch_tokens_to_retain`` plus ``critical_ada_epsilon``,
+            ``critical_ada_first_stage_ratio`` and ``critical_ada_alpha_safeguard``, which are read unconditionally below.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert D == Dk and Hq % Hk == 0
+
+    # Use the same cu as non_causal_attn_scores (context.cu_seqlens_q during prefill) so that
+    # base_scores rows line up with the Triton segments; do not mix it with cu_seqlens_k.
+    B = cu_seqlens.numel() - 1
+    G = Hq // Hk  # query heads per KV head
+    k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    btr = compression_ctx.batch_tokens_to_retain
+    assert btr is not None and btr.numel() == B
+    btr = btr.to(device=device, dtype=torch.int32)
+
+    epsilon = compression_ctx.critical_ada_epsilon
+    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
+    alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
+    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))  # clamp to [0, 1]
+
+    if wo_weight.dim() == 2:
+        hidden_size, _ = wo_weight.shape
+        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
+    else:
+        wo = wo_weight.contiguous()
+        hidden_size = wo.size(-1)
+
+    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+    if B > 0 and int(k_lengths.max().item()) > 0:
+        if _USE_WO_L1_REFERENCE_BACKEND:
+            for b in range(B):
+                k_beg = int(cu_seqlens[b].item())
+                k_end = int(cu_seqlens[b + 1].item())
+                if k_end <= k_beg:
+                    continue
+                v_seg = v[k_beg:k_end, :, :].contiguous()
+                wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
+                    v_seg, wo, Hk, G
+                )
+        else:
+
+            def grid_wo(META):
+                max_k_len = int(k_lengths.max().item())
+                return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
+
+            _compute_wo_v_l1_kernel[grid_wo](
+                v,
+                wo,
+                cu_seqlens,
+                wo_v_norm,
+                *v.stride(),
+                *wo.stride(),
+                *wo_v_norm.stride(),
+                Hk=Hk,
+                Hq=Hq,
+                D=D,
+                HIDDEN=hidden_size,
+                QUERY_GROUP_SIZE=G,
+            )
+
+    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
+    head_budgets_by_batch: list[Optional[torch.Tensor]] = []
+
+    for b in range(B):
+        k_len = int(k_lengths[b].item())
+        if k_len == 0:
+            head_budgets_by_batch.append(None)
+            continue
+        k_beg = int(cu_seqlens[b].item())
+        k_end = int(cu_seqlens[b + 1].item())
+        keep_pairs = int(btr[b].item())
+        scores_seg = base_scores[k_beg:k_end, :]
+        # Same as the kvpress n_kept: each head keeps n_kept tokens
+        n_kept_tokens = max(1, keep_pairs // Hk)
+        n_kept_tokens = min(n_kept_tokens, k_len)
+
+        # kvpress: top-k indices are taken on the unmodified scores; the scatter only writes to a copy used for head_budgets;
+        # Stage 1 still uses the original scores_seg (see below).
+        working = scores_seg.clone()
+        n_safe = int(n_kept_tokens * alpha_safeguard)
+        if n_safe > 0:
+            nk = min(n_safe, k_len)
+            for hk in range(Hk):
+                top_idx = torch.topk(scores_seg[:, hk], nk, sorted=True).indices
+                working[:, hk].scatter_(0, top_idx, float("inf"))
+
+        top_pairs = min(n_kept_tokens * Hk, working.numel())
+        if top_pairs <= 0:
+            head_budgets_by_batch.append(None)
+            continue
+        top_idx_flat = torch.topk(working.reshape(-1), top_pairs, sorted=False).indices
+        top_head_idx = top_idx_flat % Hk  # [k_len, Hk] flattened row-wise: flat = t * Hk + h
+        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
+        head_budgets_by_batch.append(head_budgets)
+
+        # Stage 1: same as kvpress — topk(..., M1, sorted=True) first, then take the first phase1 indices per head
+        head_selection_budget_1st = (
+            (head_budgets.to(torch.float32) * float(first_stage_ratio))
+            .to(torch.int64)
+            .tolist()
+        )
+        M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
+        if M1 > 0:
+            mk = min(M1, k_len)
+            for hk in range(Hk):
+                phase1_budget = int(head_selection_budget_1st[hk])
+                if phase1_budget <= 0:
+                    continue
+                full_idx = torch.topk(scores_seg[:, hk], mk, sorted=True).indices
+                take = min(phase1_budget, mk)
+                stage1_mask[k_beg + full_idx[:take], hk] = 1
+
+    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
+
+    def grid_fuse(_META):
+        return (B, Hk)
+
+    _critical_ada_fuse_kernel[grid_fuse](
+        base_scores,
+        wo_v_norm,
+        stage1_mask,
+        cu_seqlens,
+        final_scores,
+        *base_scores.stride(),
+        *wo_v_norm.stride(),
+        *stage1_mask.stride(),
+        *final_scores.stride(),
+        Hk=Hk,
+        EPSILON=float(epsilon),
+    )
+
+    # Stage 2 (kvpress): topk(..., M2, sorted=True) on the fused scores, then set the first ``budget`` indices per head to inf
+    for b in range(B):
+        hb = head_budgets_by_batch[b]
+        if hb is None:
+            continue
+        k_beg = int(cu_seqlens[b].item())
+        k_end = int(cu_seqlens[b + 1].item())
+        k_len = k_end - k_beg
+        if k_len <= 0:
+            continue
+        fused_seg = final_scores[k_beg:k_end, :]
+        M2 = int(hb.max().item())
+        if M2 <= 0:
+            continue
+        mk = min(M2, k_len)
+        for hk in range(Hk):
+            budget = int(hb[hk].item())
+            if budget <= 0:
+                continue
+            full_idx = torch.topk(fused_seg[:, hk], mk, sorted=True).indices
+            take = min(budget, mk)
+            final_scores[k_beg + full_idx[:take], hk] = float("inf")
+
+    masked_key_indices = None
+    for b in range(B):
+        k_len = int(k_lengths[b].item())
+        if k_len == 0:
+            continue
+        keep_pairs = int(btr[b].item())
+        total_pairs = k_len * Hk
+        if keep_pairs >= total_pairs:  # budget covers every pair: nothing to prune for this sequence
+            continue
+        k_beg = int(cu_seqlens[b].item())
+        k_end = int(cu_seqlens[b + 1].item())
+        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
+        if n_prune_pairs <= 0:
+            continue
+
+        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
+        prune_idx = torch.topk(
+            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
+        ).indices  # the lowest-scoring pairs are the ones pruned
+        batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
+        head_idx = prune_idx % Hk
+        seq_idx = prune_idx // Hk + k_beg  # row index in the packed batch, not sequence-local
+        if masked_key_indices is None:
+            masked_key_indices = (batch_idx, head_idx, seq_idx)
+        else:
+            masked_key_indices = (
+                torch.cat([masked_key_indices[0], batch_idx]),
+                torch.cat([masked_key_indices[1], head_idx]),
+                torch.cat([masked_key_indices[2], seq_idx]),
+            )
+
+    if store_stream is not None:
+        final_scores.record_stream(store_stream)
+
+    return final_scores, masked_key_indices
+
+
+class CriticalAdaKVCompression(BaseCompressionMethod):
+    """
+    Uses CompactorCompression as the base score (pre-RoPE leverage + post-RoPE non-causal fusion) unless ``critical_ada_base_scorer`` is "snapkv",
+    then applies the two-stage CriticalAda weighting; Attention must inject ``compression_context.wo_weight`` before post-RoPE scoring.
+    """
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        # Default must equal the one in post_rope_scoring ("compactor"); getattr also covers a None context.
+        base = getattr(context.compression_context, "critical_ada_base_scorer", "compactor")
+        if str(base).lower() == "snapkv":
+            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
+        return CompactorCompression.pre_rope_scoring(q, k, v, context)
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: Optional[torch.Tensor],
+        context,
+    ) -> Optional[torch.Tensor]:
+        compression_context = context.compression_context
+        assert compression_context is not None
+        base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower()
+
+        if base == "snapkv":
+            base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
+        else:
+            # Must match CompactorCompression.post_rope_scoring in compactor.py verbatim:
+            # maybe_execute_in_stream(non_causal_attn_scores, q,k,v, cu_seqlens_q, max_seqlen_q, ...)
+            # Do not switch to another wrapper, otherwise the scores differ from using COMPACTOR alone.
+            if context.STORE_STREAM is not None:
+                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+
+            base_scores = maybe_execute_in_stream(
+                non_causal_attn_scores,
+                q,
+                k,
+                v,
+                context.cu_seqlens_q,
+                context.max_seqlen_q,
+                chunk_size=CompactorCompression.chunk_size,
+                sm_scale=1.0,
+                normalize=True,
+                accum_scores=pre_rope_scores,
+                context_lens=compression_context.context_lens,
+                protected_first_tokens=compression_context.protected_first_tokens,
+                protected_last_tokens=compression_context.protected_last_tokens,
+                accum_blending=0.5,
+            )
+
+        wo_weight = compression_context.wo_weight
+        if wo_weight is None:
+            return base_scores  # no Wo injected: return the base scores without the CriticalAda stages
+
+        scores, _masked = maybe_execute_in_stream(
+            critical_ada_key_scores,
+            q,
+            k,
+            v,
+            wo_weight,
+            context.cu_seqlens_q,
+            base_scores,
+            compression_context,
+            STORE_STREAM=context.STORE_STREAM,
+            store_stream=context.STORE_STREAM,
+        )
+        return scores  # the masked-index tuple (``_masked``) is discarded here
+
+    @staticmethod
+    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
+        """Optional: precompute and cache Wo; actual inference relies on the ``cc.wo_weight`` injected in Attention.forward."""
+        if not hasattr(module, "o_proj") or module.o_proj.weight is None:
+            return
+        if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
+            return
+        wo_raw = module.o_proj.weight.data
+        hidden_size, _ = wo_raw.shape
+        Hq = module.num_heads
+        head_dim = module.head_dim
+        wo = (
+            wo_raw.transpose(0, 1)
+            .view(Hq, head_dim, hidden_size)
+            .to(device=device, dtype=torch.float32)  # always float32; the ``dtype`` argument is not used
+        )
+        module._critical_ada_wo_weight = wo
+
diff --git a/vllm/kvprune_legacy_save/compression/prefill.py b/vllm/kvprune_legacy_save/compression/prefill.py
new file mode 100644
index 0000000000000000000000000000000000000000..cae7697a569568afe31279f1bf863100acd25e57
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/prefill.py
@@ -0,0 +1,310 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Compactor-style sparse prefill: Triton varlen attention + paged KV store.
+
+Migrated kernels: ``sparse_varlen_kernel.causal_sparse_varlen_with_cache`` and
+``store_kv_cache.prefill_store_topk_kv``.
+
+Layout: MQA uses ``flatten_kv_cache_plane``; GQA/MHA uses head-major flatten
+(see ``layout_bridge``).
+
+Execution order note: vLLM runs ``unified_kv_cache_update`` (writes KV) before
+``unified_attention_with_output``. Compactor's sparse attention kernel assumes
+the paged cache holds only the prefix *before* the current K/V append, while
+K_app carries the new tokens. That differs from vLLM's order (cache already
+contains the current step after reshape). Therefore ``try_sparse_prefill_forward``
+is provided as a reference / future hook and is not invoked from the default
+FlashAttention forward path; prefill KV pruning uses ``prefill_store_topk_kv``
+in ``do_kv_cache_update_kv_prune`` instead.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.forward_context import get_forward_context
+from vllm.kvprune.compression.prefill_registry import try_topk_indices_from_registry
+from vllm.kvprune.core.compression_bridge import compression_method_id_to_enum
+from vllm.kvprune.core.runtime import get_kv_prune_state, layer_index_from_layer_name
+from vllm.kvprune.utils.layout_bridge import (
+ block_table_to_global_page_table,
+ build_batch_mapping,
+ build_page_table_head_major,
+ flatten_kv_cache_head_major,
+ flatten_kv_cache_plane,
+ write_head_major_flat_to_interleaved,
+)
+from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
+from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
+
+if TYPE_CHECKING:
+ from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl, FlashAttentionMetadata
+
+_RATIO_EPS = 1.0e-6  # tolerance: a request counts as pruned when its ratio is below ``1.0 - _RATIO_EPS``
+
+
+def _get_flash_attn_metadata(layer_name: str) -> "FlashAttentionMetadata | None":
+    """Return the forward-context attention metadata for ``layer_name``, or None when unavailable."""
+    try:
+        fc = get_forward_context()
+    except AssertionError:  # presumably raised when no forward context is set — confirm
+        return None
+    am = fc.attn_metadata
+    if isinstance(am, list):  # list form: only the first entry is consulted
+        am = am[0] if am else None
+    if am is None:  # metadata not set (or empty list): avoid AttributeError on ``None.get``
+        return None
+    return am.get(layer_name)
+
+
+def try_sparse_prefill_forward(
+    impl: "FlashAttentionImpl",
+    layer: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    attn_metadata: "FlashAttentionMetadata",
+    output: torch.Tensor,
+    num_actual_tokens: int,
+) -> bool:
+    """Run compactor ``causal_sparse_varlen_with_cache`` when eligible, writing ``output`` in place. Returns True if ran."""
+    state = get_kv_prune_state()
+    if state is None or not state.is_prefill:
+        return False
+    comp = state.compression_ratio_gpu[: state.num_reqs]
+    pruned = comp < 1.0 - _RATIO_EPS
+    if not torch.any(pruned):
+        return False
+    mids = state.compression_method_id_gpu[: state.num_reqs]
+    if torch.unique(mids).numel() > 1:  # requests with different compression methods are not handled
+        return False
+    # Mixed pruned + non-pruned requests: keep default FlashAttention path for now.
+    if torch.any(pruned) and torch.any(~pruned):
+        return False
+    if impl.num_kv_heads != 1:  # only a single KV head is handled by this path
+        return False
+    if impl.kv_cache_dtype.startswith("fp8"):
+        return False
+    if impl.alibi_slopes is not None:
+        return False
+    if impl.sliding_window != (-1, -1):
+        return False
+    d = impl.head_size
+    if d <= 0 or (d & (d - 1)) != 0:  # head size must be a positive power of two
+        return False
+
+    num_reqs = state.num_reqs
+    cu = state.query_start_loc[: num_reqs + 1].to(device=query.device, dtype=torch.int32)
+    seq_lens = attn_metadata.seq_lens[:num_reqs].to(torch.int32)
+    seqlen_q = cu[1:] - cu[:-1]
+    cached = seq_lens - seqlen_q  # per-request sequence length minus the current query length
+    if torch.any(cached < 0):
+        return False
+
+    seq_lens_bh = cached.unsqueeze(1).expand(-1, 1).contiguous()  # shape [num_reqs, 1]
+    block_table = attn_metadata.block_table[:num_reqs]
+    max_batches = block_table.shape[0]
+    n_lp = block_table.shape[1]  # NOTE(review): not used below
+    global_page_table = block_table_to_global_page_table(
+        block_table, impl.num_kv_heads, max_batches=max_batches
+    )
+    batch_mapping = build_batch_mapping(num_reqs, query.device)
+
+    try:
+        k_flat, v_flat = flatten_kv_cache_plane(key_cache, value_cache, impl.num_kv_heads)
+    except ValueError:
+        return False
+
+    page_size = key_cache.shape[1]
+    if page_size <= 0 or k_flat.shape[0] % page_size != 0:
+        return False
+
+    q3 = query[:num_actual_tokens].view(num_actual_tokens, impl.num_heads, d)
+    k3 = key[:num_actual_tokens].view(num_actual_tokens, 1, d)
+    v3 = value[:num_actual_tokens].view(num_actual_tokens, 1, d)
+
+    max_seqlen_q = int(attn_metadata.max_query_len)
+    max_cached = int(seq_lens_bh.max().item()) if seq_lens_bh.numel() else 0
+
+    out = causal_sparse_varlen_with_cache(
+        q3,
+        k3,
+        v3,
+        k_flat,
+        v_flat,
+        seq_lens_bh,
+        global_page_table,
+        batch_mapping,
+        cu,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k_cache=max_cached,
+        HKV=1,
+        PAGE_SIZE=page_size,
+        sm_scale=None,
+    )
+    output[:num_actual_tokens].copy_(out.reshape(num_actual_tokens, impl.num_heads * d))
+    return True
+
+
+def _build_tail_topk_indices(
+    cu_seqlens: torch.Tensor,
+    num_reqs: int,
+    hkv: int,
+    compression_ratio: float | torch.Tensor,
+    max_sel: int,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Return (indices [B, max_sel], num_pairs_to_retain [B]) for tail tokens × heads.
+
+    Request ``b`` keeps its last ``k_tok = clamp(round(len_b * ratio_b), 1, len_b)`` tokens.
+    Pair ids are ``token * hkv + head`` (``token`` is the row offset taken from
+    ``cu_seqlens``), listed token-major and truncated to ``max_sel`` per request.
+
+    Args:
+        cu_seqlens: Cumulative sequence boundaries; entries ``0..num_reqs`` are read.
+        compression_ratio: Keep ratio; a scalar, or a tensor indexed by request.
+    """
+    indices = torch.zeros(num_reqs, max_sel, dtype=torch.int32, device=device)
+    n_pairs = torch.zeros(num_reqs, dtype=torch.int32, device=device)
+    # Read all boundaries / ratios with one transfer each instead of ``.item()`` per request.
+    bounds = cu_seqlens[: num_reqs + 1].tolist()
+    per_request = isinstance(compression_ratio, torch.Tensor)
+    ratios = compression_ratio.tolist() if per_request else compression_ratio
+    for b in range(num_reqs):
+        end = bounds[b + 1]
+        chunk_len = end - bounds[b]
+        if chunk_len <= 0:
+            continue
+        r_b = float(ratios[b]) if per_request else ratios
+        k_tok = min(max(1, int(round(chunk_len * r_b))), chunk_len)
+        first = (end - k_tok) * hkv  # tail pair ids form one contiguous range starting here
+        n = min(k_tok * hkv, max_sel)
+        if n > 0:
+            indices[b, :n] = torch.arange(first, first + n, dtype=torch.int32, device=device)
+            n_pairs[b] = n
+    return indices, n_pairs
+
+
+def try_prefill_kv_store(
+    layer: torch.nn.Module,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+) -> bool:
+    """Top-k or full compactor prefill KV store; updates per-layer logical lengths. Returns False when the step is not eligible."""
+    state = get_kv_prune_state()
+    if state is None or not state.is_prefill:
+        return False
+    num_reqs = state.num_reqs
+    comp = state.compression_ratio_gpu[:num_reqs]
+    pruned = comp < 1.0 - _RATIO_EPS
+    if not torch.any(pruned):
+        return False
+    if torch.any(pruned) and torch.any(~pruned):  # mixed pruned / non-pruned batches are not handled
+        return False
+    mids = state.compression_method_id_gpu[:num_reqs]
+    if torch.unique(mids).numel() > 1:  # requests with different compression methods are not handled
+        return False
+
+    meta = _get_flash_attn_metadata(layer.layer_name)
+    if meta is None:
+        return False
+
+    num_kv_heads = key.shape[1]
+    d = key.shape[2]
+    if d <= 0 or (d & (d - 1)) != 0:  # head size must be a positive power of two
+        return False
+
+    key_cache, value_cache = kv_cache.unbind(0)
+    page_size = key_cache.shape[1]
+    nb = key_cache.shape[0]
+    bs = key_cache.shape[1]  # same value as page_size
+    head_major = num_kv_heads > 1
+    try:
+        if head_major:
+            k_flat, v_flat = flatten_kv_cache_head_major(key_cache, value_cache)
+        else:
+            k_flat, v_flat = flatten_kv_cache_plane(
+                key_cache, value_cache, num_kv_heads
+            )
+    except ValueError:
+        return False
+
+    block_table = meta.block_table[:num_reqs]
+    max_batches = block_table.shape[0]
+    if head_major:
+        global_page_table = build_page_table_head_major(
+            block_table,
+            num_kv_heads,
+            num_blocks=nb,
+            block_size=bs,
+            page_size=page_size,
+            max_batches=max_batches,
+        )
+    else:
+        global_page_table = block_table_to_global_page_table(
+            block_table, num_kv_heads, max_batches=max_batches
+        )
+    batch_mapping = build_batch_mapping(num_reqs, key.device)
+
+    cu = state.query_start_loc[: num_reqs + 1].to(device=key.device, dtype=torch.int32)
+    seq_lens = meta.seq_lens[:num_reqs].to(torch.int32)
+    seqlen_q = cu[1:] - cu[:-1]
+    cached = (seq_lens - seqlen_q).unsqueeze(1).expand(-1, num_kv_heads).contiguous()
+
+    layer_idx = layer_index_from_layer_name(layer.layer_name)
+    max_seqlen_k = int(seqlen_q.max().item()) if seqlen_q.numel() else 0
+
+    max_sel = min(max_seqlen_k * num_kv_heads, 8192)  # selection width is capped at 8192 pairs
+    max_sel = max(max_sel, 1)
+    mid = int(state.compression_method_id_gpu[0].item())  # all requests share one method (checked above)
+    method_enum = compression_method_id_to_enum(mid)
+    registry_out = try_topk_indices_from_registry(
+        method_enum, key, value, cu, num_reqs, comp, max_sel, key.device
+    )
+    if registry_out is not None:
+        indices, n_pairs = registry_out
+    else:
+        indices, n_pairs = _build_tail_topk_indices(
+            cu, num_reqs, num_kv_heads, comp, max_sel, key.device
+        )
+    bh = cached.clone()  # presumably updated in place by prefill_store_topk_kv — confirm
+    prefill_store_topk_kv(
+        new_keys=key,
+        new_vals=value,
+        indices_topk=indices,
+        num_tokens_to_retain=n_pairs,
+        page_table=global_page_table,
+        batch_mapping=batch_mapping,
+        bh_lens=bh,
+        k_cache=k_flat,
+        v_cache=v_flat,
+        PAGE_SIZE=page_size,
+        PAD_TO_PAGE_SIZE=False,
+        cu_seqlens_k=None,
+    )
+    if head_major:
+        write_head_major_flat_to_interleaved(
+            k_flat, v_flat, key_cache, value_cache
+        )
+
+    new_lens = bh.to(torch.int32)
+    if state.logical_seq_lens_gpu.dim() == 3:
+        state.logical_seq_lens_gpu[layer_idx, :num_reqs, :] = new_lens
+    else:
+        state.logical_seq_lens_gpu[layer_idx, :num_reqs] = new_lens.max(
+            dim=1
+        ).values  # 2-D state: store the maximum over heads
+
+    return True
+
+
+__all__ = [  # names exported by ``from ... import *``
+    "try_sparse_prefill_forward",
+    "try_prefill_kv_store",
+]
diff --git a/vllm/kvprune_legacy_save/compression/prefill_registry.py b/vllm/kvprune_legacy_save/compression/prefill_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b077694858b8815a3d5ea580247d58c96a62677
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/prefill_registry.py
@@ -0,0 +1,201 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Map COMPRESSION_REGISTRY scoring to prefill top-k indices (with tail fallback)."""
+
+from __future__ import annotations
+
+import logging
+
+import torch
+
+from vllm.kvprune.compression import COMPRESSION_REGISTRY
+from vllm.kvprune.compression.compression_config import CompressionMethod
+from vllm.kvprune.utils.context import CompressionContext, Context
+
+logger = logging.getLogger(__name__)
+
+
+def _scores_to_topk_pair_indices(
+    cu_seqlens: torch.Tensor,
+    num_reqs: int,
+    hkv: int,
+    scores: torch.Tensor,
+    compression_ratio: float | torch.Tensor,
+    max_sel: int,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Select (token, head) pairs with highest scores per request up to budget.
+
+    Args:
+        cu_seqlens: Cumulative token offsets; request ``b`` owns tokens
+            ``[cu_seqlens[b], cu_seqlens[b + 1])``.
+        num_reqs: Number of requests to process.
+        hkv: Number of KV heads; a pair is encoded as ``token * hkv + head``.
+        scores: Per-token scores. A 1-D tensor is broadcast over heads, a 2-D
+            tensor is read as ``[token, head]``, and higher ranks are flattened
+            per token and truncated to the first ``hkv`` columns.
+        compression_ratio: Fraction of tokens to keep, either one float for
+            all requests or a tensor indexed by request.
+        max_sel: Width of the returned index table (hard cap per request).
+        device: Device of the returned tensors.
+
+    Returns:
+        ``(indices, n_pairs)``: ``indices`` is ``[num_reqs, max_sel]`` int32,
+        holding the selected pair ids in descending score order and
+        zero-padded; ``n_pairs`` is ``[num_reqs]`` int32 with the number of
+        valid entries per row.
+
+    Raises:
+        IndexError: If ``scores`` does not cover every (token, head) pair of a
+            non-empty request.
+    """
+    if scores.dim() == 1:
+        scores = scores.unsqueeze(-1).expand(-1, hkv)
+    elif scores.dim() > 2:
+        scores = scores.reshape(scores.shape[0], -1)
+    scores = scores[:, :hkv]
+
+    indices = torch.zeros(num_reqs, max_sel, dtype=torch.int32, device=device)
+    n_pairs = torch.zeros(num_reqs, dtype=torch.int32, device=device)
+
+    # One host transfer up front instead of a device sync per element.
+    bounds = cu_seqlens[: num_reqs + 1].detach().cpu().tolist()
+    ratios = None
+    if isinstance(compression_ratio, torch.Tensor):
+        ratios = compression_ratio.detach().cpu().tolist()
+
+    for b in range(num_reqs):
+        start, end = int(bounds[b]), int(bounds[b + 1])
+        chunk_len = end - start
+        if chunk_len <= 0:
+            continue
+        r_b = float(ratios[b]) if ratios is not None else compression_ratio
+        k_tok = min(max(1, int(round(chunk_len * r_b))), chunk_len)
+        budget = min(k_tok * hkv, max_sel)
+        if budget <= 0:
+            continue
+
+        # Row-major flattening: local position j <-> (start + j // hkv, j % hkv).
+        seg = scores[start:end].reshape(-1)
+        if seg.numel() != chunk_len * hkv:
+            raise IndexError(
+                "scores do not cover every (token, head) pair of request "
+                f"{b}: expected {chunk_len * hkv}, got {seg.numel()}"
+            )
+        if seg.dtype in (torch.float16, torch.bfloat16):
+            # Order-preserving upcast; keeps sorting on widely supported dtypes.
+            seg = seg.float()
+        # Stable descending sort keeps the earliest pair first among ties.
+        order = torch.sort(seg, descending=True, stable=True).indices[:budget]
+        n = int(order.numel())
+        indices[b, :n] = (order + start * hkv).to(device=device, dtype=torch.int32)
+        n_pairs[b] = n
+    return indices, n_pairs
+
+
+def try_topk_indices_from_registry(
+    method: CompressionMethod,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu: torch.Tensor,
+    num_reqs: int,
+    compression_ratio: torch.Tensor,
+    max_sel: int,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor] | None:
+    """Return (indices, n_pairs) using registry scoring, or None to use tail fallback.
+
+    Args:
+        method: Compression method whose registry scorer should be used.
+        key: Keys of the packed prefill tokens, unpacked as
+            ``[n_tokens, hkv, d]``.
+        value: Values, reshaped to the same ``[n_tokens, hkv, d]`` as ``key``.
+        cu: Cumulative token offsets; the first ``num_reqs + 1`` entries are
+            used.
+        num_reqs: Number of requests in the batch.
+        compression_ratio: Keep ratio forwarded to the top-k selection.
+        max_sel: Maximum number of (token, head) pairs selected per request.
+        device: Device of the returned tensors.
+
+    Returns:
+        The ``(indices, n_pairs)`` pair produced by
+        ``_scores_to_topk_pair_indices``, or ``None`` when the method has no
+        registry path here, the batch is empty, the scorer returns ``None``,
+        or scoring raises (logged at debug level so the caller can fall back).
+    """
+    failure_messages = {
+        CompressionMethod.COMPACTOR: (
+            "Compactor pre_rope scoring failed; using tail fallback"
+        ),
+        CompressionMethod.CRITICALADAKV: (
+            "CriticalAdaKV registry path failed; using tail fallback"
+        ),
+        CompressionMethod.SNAPKV: "SnapKV registry path failed; using tail fallback",
+    }
+    failure_message = failure_messages.get(method)
+    if failure_message is None:
+        # CompressionMethod.NONE and any method without a registry path.
+        return None
+
+    n_tokens, hkv, d = key.shape
+    # num_reqs <= 0 would make the max() below run on an empty sequence.
+    if n_tokens <= 0 or hkv <= 0 or num_reqs <= 0:
+        return None
+
+    k_flat = key.reshape(n_tokens, hkv, d)
+    v_flat = value.reshape(n_tokens, hkv, d)
+
+    cu_host = cu[: num_reqs + 1].detach().cpu()
+    context_lens = [int(n) for n in (cu_host[1:] - cu_host[:-1]).tolist()]
+    max_seqlen_q = max(context_lens)
+
+    try:
+        ctx = _build_registry_scoring_context(
+            method, cu, context_lens, max_seqlen_q, num_reqs, d, key.device
+        )
+        scorer = COMPRESSION_REGISTRY[method]
+        # NOTE(review): the query is all zeros, so query-dependent scorers see
+        # no real query signal on this path - confirm this is intended.
+        q_dummy = torch.zeros_like(k_flat)
+        if method == CompressionMethod.COMPACTOR:
+            scores = scorer.pre_rope_scoring(q_dummy, k_flat, v_flat, context=ctx)
+        elif method == CompressionMethod.CRITICALADAKV:
+            pre_scores = scorer.pre_rope_scoring(
+                q_dummy, k_flat, v_flat, context=ctx
+            )
+            scores = scorer.post_rope_scoring(
+                q_dummy, k_flat, v_flat, pre_scores, context=ctx
+            )
+        else:  # CompressionMethod.SNAPKV
+            scores = scorer.post_rope_scoring(
+                q_dummy, k_flat, v_flat, None, context=ctx
+            )
+        if scores is None:
+            return None
+        return _scores_to_topk_pair_indices(
+            cu, num_reqs, hkv, scores, compression_ratio, max_sel, device
+        )
+    except Exception:
+        # Deliberate best effort: any scorer failure degrades to the tail
+        # fallback in the caller instead of failing the prefill.
+        logger.debug(failure_message, exc_info=True)
+        return None
+
+
+def _build_registry_scoring_context(
+    method: CompressionMethod,
+    cu: torch.Tensor,
+    context_lens: list[int],
+    max_seqlen_q: int,
+    num_reqs: int,
+    head_dim: int,
+    device: torch.device,
+) -> Context:
+    """Build the ``Context`` that the registry scorer for ``method`` receives."""
+    if method == CompressionMethod.SNAPKV:
+        return Context(
+            is_prefill=True,
+            do_compression=True,
+            cu_seqlens_q=cu,
+            cu_seqlens_k=cu,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_q,
+            compression_context=CompressionContext(
+                compression_method=CompressionMethod.SNAPKV
+            ),
+        )
+
+    # COMPACTOR and CRITICALADAKV share the projection-based setup.
+    # NOTE(review): PHI is re-sampled with torch.randn on every call, so the
+    # projection differs between calls - confirm that is acceptable.
+    phi = torch.randn(
+        head_dim, min(64, head_dim), device=device, dtype=torch.float32
+    )
+    cc = CompressionContext(
+        compression_method=method,
+        context_lens=context_lens,
+        PHI=phi,
+        compression_chunk_size=512,
+        protected_first_tokens=[0] * num_reqs,
+        protected_last_tokens=[0] * num_reqs,
+    )
+    return Context(
+        is_prefill=True,
+        do_compression=True,
+        cu_seqlens_q=cu,
+        max_seqlen_q=max_seqlen_q,
+        compression_context=cc,
+    )
+
+
+# Only the registry entry point is exported; `_scores_to_topk_pair_indices` is not.
+__all__ = ["try_topk_indices_from_registry"]
diff --git a/vllm/kvprune_legacy_save/compression/snapkv.py b/vllm/kvprune_legacy_save/compression/snapkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a228138a63fa8daa746249fabff8caf38295e9f
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/snapkv.py
@@ -0,0 +1,545 @@
+import math
+from typing import Optional
+
+import torch
+import triton
+from triton import language as tl
+
+from vllm.kvprune.compression.common import BaseCompressionMethod
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
+
+# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
+DEFAULT_SNAPKV_WINDOW_SIZE = 64
+DEFAULT_SNAPKV_KERNEL_SIZE = 5
+
+
+class SnapKVCompression(BaseCompressionMethod):
+    """SnapKV scoring hooks: no pre-RoPE score, query-aware post-RoPE score."""
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Return ``None``; this method produces no pre-RoPE scores."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Score keys with ``query_aware_key_scores`` via ``maybe_execute_in_stream``.
+
+        ``v`` and ``pre_rope_scores`` are accepted for interface parity and are
+        not read. Sequence boundaries and the stream come from ``context``.
+        """
+        scoring_options = dict(
+            w=DEFAULT_SNAPKV_WINDOW_SIZE,
+            kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
+            STORE_STREAM=context.STORE_STREAM,
+        )
+        return maybe_execute_in_stream(
+            query_aware_key_scores,
+            q,
+            k,
+            context.cu_seqlens_q,
+            context.cu_seqlens_k,
+            **scoring_options,
+        )
+
+
+# Pass 1 of the SnapKV scoring pipeline.
+# Grid: (batch b, KV head hk, row tile rid). A "row" is one (window query,
+# query head within the GQA group) pair: row = q_off * QUERY_GROUP_SIZE + hq_local.
+# For every row the kernel
+#   * computes base-2 scaled logits against all keys of the request, causally
+#     masked,
+#   * stores the logits of the prefix keys (nk < k_end - win) into LOGITS, and
+#   * writes the running max `m` and the sum `S` of exp2(logit - m) over all
+#     unmasked keys into out_m / out_S.
+# Programs return without writing anything when win <= 0 or the request has no
+# prefix keys (k_end - win <= k_beg).
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
+ )
+ for bq in [32, 64]
+ for bk in [32, 64]
+ for num_warps in [4, 8]
+ for num_stages in [3, 4]
+ ],
+ key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
+ cache_results=True,
+)
+@triton.jit
+def _lse_and_store_logits_kernel(
+ Q,
+ K,
+ cu_q,
+ cu_k,
+ w_b, # int32 pointers
+ out_m,
+ out_S, # [B, Hk, ROWS_MAX] float32
+ LOGITS, # [Nk, Hk, ROWS_MAX] float32
+ sm_scale, # float
+ QUERY_GROUP_SIZE: tl.constexpr,
+ D: tl.constexpr,
+ STRIDE_Q_NQ,
+ STRIDE_Q_HQ,
+ STRIDE_K_NK,
+ STRIDE_K_HK,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ ROWS_MAX,
+):
+ # program ids
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+ rid = tl.program_id(2) # row-tile id
+ # batch segment bounds
+ q_end = tl.load(cu_q + b + 1)
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ # The observation window is the last `win` queries; prefix keys end where it starts.
+ q_win_beg = q_end - win
+ k_eff_end = k_end - win
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ # rows for this (b,hk)
+ rows_b = win * QUERY_GROUP_SIZE
+ row0 = rid * BLOCK_Q
+ if row0 >= rows_b:
+ return
+
+ # exp(x) = exp2(x * 1/ln2)
+ qk_scale = sm_scale * 1.4426950408889634
+
+ offs_qrow = row0 + tl.arange(0, BLOCK_Q)
+ row_mask = offs_qrow < rows_b
+
+ # map row -> (q_idx, hq_local)
+ hq_local = offs_qrow % QUERY_GROUP_SIZE
+ q_off = offs_qrow // QUERY_GROUP_SIZE
+ q_idx = q_win_beg + q_off
+ hq_glob = hk * QUERY_GROUP_SIZE + hq_local
+
+ offs_d = tl.arange(0, D)
+
+ # The last-dimension stride is taken to be 1 (no stride multiplies offs_d).
+ q_ptrs = (
+ Q
+ + q_idx[:, None] * STRIDE_Q_NQ
+ + hq_glob[:, None] * STRIDE_Q_HQ
+ + offs_d[None, :]
+ )
+ q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+ m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
+ S = tl.zeros([BLOCK_Q], dtype=tl.float32)
+
+ # Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
+ for ks in tl.range(k_beg, k_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_end
+
+ k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
+ k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
+
+ s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
+ s = tl.where(kmask[None, :], s, -float("inf"))
+ # Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
+ # NOTE(review): nk is an offset into K (cu_k) while q_idx is an offset into Q (cu_q);
+ # the comparison assumes both use the same packed token positions - confirm.
+ causal_ok = nk[None, :] <= q_idx[:, None]
+ s = tl.where(causal_ok, s, -float("inf"))
+
+ # store prefix logits only (for marginal probs on prefix keys)
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ store_mask = kmask & (nk < k_eff_end)
+ tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
+
+ # log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
+ cur_max = tl.max(s, 1) # [BQ]
+ n_m = tl.maximum(m, cur_max)
+ rescale = tl.math.exp2(m - n_m)
+ S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
+ m = n_m
+
+ # store m,S for these rows
+ m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
+ tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
+
+
+# Pass 2 of the SnapKV scoring pipeline.
+# Grid: (batch b, KV head hk). For every prefix key and every window row it turns
+# the stored base-2 logit into a probability, probs = exp2(logit - m) / S, using the
+# per-row max `m` and sum `S` read from in_m / in_S, and writes it to PROBS.
+# Rows whose S is not positive yield probability 0.
+# NOTE(review): the autotune key names "HK" and "HQ", which are not parameters of
+# this kernel - confirm the key has the intended effect.
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
+ for bq in [16, 32, 64]
+ for bk in [32, 64, 128]
+ ],
+ key=["HK", "HQ"],
+ cache_results=True,
+)
+@triton.jit
+def _prefix_probs_kernel(
+ cu_k,
+ w_b,
+ in_m,
+ in_S, # [B, Hk, ROWS_MAX] f32
+ LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
+ PROBS, # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
+ #
+ QUERY_GROUP_SIZE: tl.constexpr,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ STRIDE_PB_NK,
+ STRIDE_PB_HK,
+ STRIDE_PB_R,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ # Same early-exit condition as the logits kernel: nothing was stored for this request.
+ k_eff_end = k_end - win
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ rows_b = win * QUERY_GROUP_SIZE
+
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ for row0 in tl.range(0, rows_b, BLOCK_Q):
+ r_idx = row0 + tl.arange(0, BLOCK_Q)
+ rmask = r_idx < rows_b
+
+ m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ m = tl.load(
+ m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
+ mask=rmask,
+ other=-float("inf"),
+ )
+ S = tl.load(
+ S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
+ )
+
+ # Substitute neutral m/S for rows without a positive sum so the division below
+ # stays finite; those rows are zeroed again after the division.
+ valid_row = S > 0
+ m = tl.where(valid_row, m, 0.0)
+ S = tl.where(valid_row, S, 1.0)
+
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ s_T = tl.load(
+ log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
+ ) # [BK, BQ]
+
+ probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
+ probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
+
+ prob_ptrs = (
+ PROBS
+ + nk[:, None] * STRIDE_PB_NK
+ + hk * STRIDE_PB_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
+ )
+ tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
+
+
+# Optional normalisation pass. Grid: (batch b,).
+# Standardises OUT in place per request: mean and variance are taken over all
+# prefix keys [k_beg, k_end - win) and all HK heads together, then every such
+# entry becomes (value - mean) / sqrt(var + EPS). Entries at or after
+# k_end - win are not touched.
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+ key=["HK"],
+ cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue(
+ OUT, # [Nk, Hk], float32
+ cu_k,
+ w_b, # [B+1], [B] int32
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ HK: tl.constexpr, # Hk
+ EPS: tl.constexpr, # e.g., 1e-12
+ BLOCK_K: tl.constexpr, # e.g., 128
+):
+ b = tl.program_id(0)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ if k_eff_end <= k_beg:
+ return
+
+ sumv = tl.zeros([], dtype=tl.float32)
+ sumsq = tl.zeros([], dtype=tl.float32)
+ count = ((k_eff_end - k_beg) * HK).to(tl.float32)
+
+ # First sweep: accumulate the sum and the sum of squares.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ sumv += tl.sum(vals, 0)
+ sumsq += tl.sum(vals * vals, 0)
+
+ # var = E[x^2] - E[x]^2, clamped at 0 against negative rounding error.
+ mean = sumv / count
+ var = tl.maximum(sumsq / count - mean * mean, 0.0)
+ invstd = 1.0 / tl.sqrt(var + EPS)
+
+ # Second sweep: rewrite every entry with its standardised value.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ vals = (vals - mean) * invstd
+ tl.store(ptrs, vals, mask=kmask)
+
+
+# Grid: (channel c, tile of BLOCK_T output positions along the length axis).
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
+ key=["KERNEL_SIZE"],
+ cache_results=True,
+)
+@triton.jit
+def _snapkv_avg_pool1d_kernel(
+ IN,
+ OUT,
+ Lp,
+ STRIDE_IN_C,
+ STRIDE_IN_L,
+ STRIDE_OUT_C,
+ STRIDE_OUT_L,
+ KERNEL_SIZE: tl.constexpr,
+ PAD: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+):
+ """
+ Symmetric 1D average pool on the last dimension, matching
+ `F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
+ (equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
+ """
+ c = tl.program_id(0)
+ t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
+ mask = t0 < Lp
+
+ acc = tl.zeros([BLOCK_T], dtype=tl.float32)
+ # Taps that fall outside [0, Lp) load 0.0, and the divisor below is always
+ # KERNEL_SIZE, so edge outputs are averaged as if zero-padded.
+ # NOTE(review): exactly Lp outputs are written for any KERNEL_SIZE - confirm
+ # callers only pass odd kernel sizes if parity with F.avg_pool1d is required.
+ for j in tl.static_range(KERNEL_SIZE):
+ idx = t0 - PAD + j
+ valid = (idx >= 0) & (idx < Lp)
+ ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
+ v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
+ acc += v
+ acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
+
+ out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
+ tl.store(out_ptrs, acc, mask=mask)
+
+
+def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
+    """
+    Smooth ``x`` along its last axis with ``_snapkv_avg_pool1d_kernel``.
+
+    ``x`` must be float32 with shape [Hk, G, Lp]. The first two axes are folded
+    into one channel axis for the launch and unfolded again on return. An
+    input with ``Lp == 0`` is returned unchanged.
+    """
+    assert x.dtype == torch.float32
+    num_heads, group, length = x.shape
+    if length == 0:
+        return x
+
+    channels = num_heads * group
+    flat_in = x.reshape(channels, length).contiguous()
+    flat_out = torch.empty_like(flat_in)
+    in_stride_c, in_stride_l = flat_in.stride()
+    out_stride_c, out_stride_l = flat_out.stride()
+
+    # One program per (channel, BLOCK_T-sized tile of positions).
+    launch_grid = lambda meta: (channels, triton.cdiv(length, meta["BLOCK_T"]))
+    _snapkv_avg_pool1d_kernel[launch_grid](
+        flat_in,
+        flat_out,
+        length,
+        in_stride_c,
+        in_stride_l,
+        out_stride_c,
+        out_stride_l,
+        KERNEL_SIZE=kernel_size,
+        PAD=kernel_size // 2,
+    )
+    return flat_out.view(num_heads, group, length)
+
+
+def _snapkv_kvpress_epilogue(
+    probs_buf: torch.Tensor,
+    out: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    w: torch.Tensor,
+    G: int,
+    Hk: int,
+    kernel_size: int,
+) -> None:
+    """
+    Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
+    → mean over GQA groups → pad tail with global max of prefix scores.
+
+    Writes into ``out`` in place. Requests with a non-positive window, or with
+    no keys before the window, are skipped and keep whatever ``out`` holds.
+
+    Args:
+        probs_buf: Per-row probabilities indexed ``[key, head, row]``; only the
+            first ``win * G`` rows of a request are read.
+        out: Score table indexed ``[key, head]``.
+        cu_seqlens_k: Cumulative key offsets, ``B + 1`` entries.
+        w: Window length per request, ``B`` entries.
+        G: Query heads per KV head.
+        Hk: Number of KV heads.
+        kernel_size: Width of the smoothing window.
+    """
+    B = cu_seqlens_k.numel() - 1
+    if B <= 0:
+        return
+    # Two host transfers in total instead of three `.item()` syncs per request.
+    bounds = cu_seqlens_k.tolist()
+    windows = w.tolist()
+    for b in range(B):
+        k_beg = int(bounds[b])
+        k_end = int(bounds[b + 1])
+        win = int(windows[b])
+        k_eff_end = k_end - win
+        if win <= 0 or k_eff_end <= k_beg:
+            continue
+        Lp = k_eff_end - k_beg
+        rows_b = win * G
+        p = probs_buf[k_beg:k_eff_end, :, :rows_b]
+        # [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
+        x = p.view(Lp, Hk, win, G).mean(dim=2)
+        x = x.permute(1, 2, 0).contiguous()  # [Hk, G, Lp]
+        x = _snapkv_avg_pool1d_triton(x, kernel_size)
+        x = x.mean(dim=1)
+        seg = x.permute(1, 0).contiguous()
+        out[k_beg:k_eff_end, :] = seg
+        # Window keys get the largest prefix score of this request.
+        out[k_eff_end:k_end, :] = seg.max()
+
+
+def query_aware_key_scores(
+    q: torch.Tensor,  # [N_q, Hq, D]
+    k: torch.Tensor,  # [N_k, Hk, D]
+    cu_seqlens_q: torch.Tensor,  # [B+1], int32
+    cu_seqlens_k: torch.Tensor,  # [B+1], int32
+    w: torch.Tensor | int,  # [B], int32
+    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
+    *,
+    kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
+    accum_scores: Optional[torch.Tensor] = None,
+    accum_blending: Optional[float] = None,
+    normalize: bool = False,
+) -> Optional[torch.Tensor]:
+    """Score every key by the attention it receives from the last ``w`` queries.
+
+    Pipeline: logits and softmax normalisers (``_lse_and_store_logits_kernel``)
+    → per-row probabilities on prefix keys (``_prefix_probs_kernel``) →
+    averaging, pooling and tail padding (``_snapkv_kvpress_epilogue``) →
+    optional per-request z-score (``_zscore_per_batch_epilogue``).
+
+    Args:
+        q: Queries ``[N_q, Hq, D]`` with a contiguous last dimension.
+        k: Keys ``[N_k, Hk, D]`` with a contiguous last dimension; ``Hq`` must
+            be a multiple of ``Hk``.
+        cu_seqlens_q: Cumulative query offsets, ``B + 1`` entries.
+        cu_seqlens_k: Cumulative key offsets, ``B + 1`` entries.
+        w: Window length, one int for all requests or a ``[B]`` tensor.
+        sm_scale: Softmax scale; ``1 / sqrt(D)`` when ``None``.
+        kernel_size: Width of the smoothing window used by the epilogue.
+        accum_scores: Optional ``[N_k, Hk]`` buffer updated in place.
+        accum_blending: Factor applied to ``accum_scores`` before adding.
+        normalize: Whether to z-score the prefix scores per request.
+
+    Returns:
+        ``accum_scores`` (scaled by ``accum_blending`` if given, plus the new
+        scores) when it is provided, otherwise a new float32 ``[N_k, Hk]``
+        tensor. Requests that the kernels skip keep a score of 0.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
+    # The kernel reads K rows with Q's head dim, so a mismatch would read
+    # out of bounds instead of failing.
+    assert Dk == D, "q and k must share the same head dimension"
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    B = cu_seqlens_q.numel() - 1
+    assert B == cu_seqlens_k.numel() - 1
+
+    G = Hq // Hk
+    if isinstance(w, int):
+        max_w = w
+        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
+    else:
+        max_w = int(w.max().item())
+    assert w.numel() == B
+    ROWS_MAX = max_w * G
+
+    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+    if ROWS_MAX != 0:
+        m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+        S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+        logits_buf = torch.empty(
+            (N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device
+        )
+        probs_buf = torch.empty(
+            (N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device
+        )
+
+        # Stride keyword groups, named exactly as the kernel signatures expect.
+        STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
+        STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
+        m_strides = dict(
+            zip(("STRIDE_M_B", "STRIDE_M_H", "STRIDE_M_R"), m_scratch.stride())
+        )
+        s_strides = dict(
+            zip(("STRIDE_S_B", "STRIDE_S_H", "STRIDE_S_R"), S_scratch.stride())
+        )
+        lg_strides = dict(
+            zip(("STRIDE_LG_NK", "STRIDE_LG_HK", "STRIDE_LG_R"), logits_buf.stride())
+        )
+        pb_strides = dict(
+            zip(("STRIDE_PB_NK", "STRIDE_PB_HK", "STRIDE_PB_R"), probs_buf.stride())
+        )
+        STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
+
+        def grid(META):
+            return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
+
+        _lse_and_store_logits_kernel[grid](
+            q,
+            k,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            w,
+            m_scratch,
+            S_scratch,
+            logits_buf,
+            sm_scale,
+            QUERY_GROUP_SIZE=G,
+            D=D,
+            STRIDE_Q_NQ=STRIDE_Q_NQ,
+            STRIDE_Q_HQ=STRIDE_Q_HQ,
+            STRIDE_K_NK=STRIDE_K_NK,
+            STRIDE_K_HK=STRIDE_K_HK,
+            **m_strides,
+            **s_strides,
+            **lg_strides,
+            ROWS_MAX=ROWS_MAX,
+        )
+
+        _prefix_probs_kernel[(B, Hk)](
+            cu_seqlens_k,
+            w,
+            m_scratch,
+            S_scratch,
+            logits_buf,
+            probs_buf,
+            QUERY_GROUP_SIZE=G,
+            **m_strides,
+            **s_strides,
+            **lg_strides,
+            **pb_strides,
+        )
+        _snapkv_kvpress_epilogue(
+            probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
+        )
+        if normalize:
+            _zscore_per_batch_epilogue[(B,)](
+                out,
+                cu_seqlens_k,
+                w,
+                STRIDE_OUT_NK,
+                STRIDE_OUT_HK,
+                HK=Hk,
+                EPS=1e-12,
+            )
+
+    # With an empty window `out` stays all zero, and the accumulator is still
+    # blended and returned, the same contract as the non-empty path.
+    if accum_scores is None:
+        return out
+    if accum_blending is not None:
+        accum_scores.mul_(accum_blending)
+    accum_scores.add_(out)
+    return accum_scores
diff --git a/vllm/kvprune_legacy_save/compression/snapkv_origin.py b/vllm/kvprune_legacy_save/compression/snapkv_origin.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eaaba64a384f4bc74840724168710439eec9a16
--- /dev/null
+++ b/vllm/kvprune_legacy_save/compression/snapkv_origin.py
@@ -0,0 +1,449 @@
+import math
+from typing import Optional
+
+import torch
+import triton
+from triton import language as tl
+
+from compactor_vllm.compression.common import BaseCompressionMethod
+from compactor_vllm.utils.helpers import maybe_execute_in_stream
+from compactor_vllm.utils.triton_compat import autotune as triton_autotune
+
+
+class SnapKVCompression(BaseCompressionMethod):
+    """SnapKV scoring hooks: no pre-RoPE score, query-aware post-RoPE score."""
+
+    @staticmethod
+    def pre_rope_scoring(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
+    ) -> Optional[torch.Tensor]:
+        """Return ``None``; this method produces no pre-RoPE scores."""
+        return None
+
+    @staticmethod
+    def post_rope_scoring(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        pre_rope_scores: torch.Tensor,
+        context,
+    ) -> Optional[torch.Tensor]:
+        """Score keys with ``query_aware_key_scores`` using a 32-token window.
+
+        ``v`` and ``pre_rope_scores`` are accepted for interface parity and are
+        not read. Sequence boundaries and the stream come from ``context``.
+        """
+        scoring_options = dict(w=32, STORE_STREAM=context.STORE_STREAM)
+        return maybe_execute_in_stream(
+            query_aware_key_scores,
+            q,
+            k,
+            context.cu_seqlens_q,
+            context.cu_seqlens_k,
+            **scoring_options,
+        )
+
+
+# Pass 1 of the original SnapKV scoring pipeline.
+# Grid: (batch b, KV head hk, row tile rid). A "row" is one (window query,
+# query head within the GQA group) pair: row = q_off * QUERY_GROUP_SIZE + hq_local.
+# For every row the kernel computes base-2 scaled logits against the prefix keys
+# [k_beg, k_end - win) only, with no causal mask, stores them into LOGITS and
+# writes the running max `m` and the sum `S` of exp2(logit - m) to out_m / out_S.
+# Programs return without writing anything when win <= 0 or the request has no
+# prefix keys.
+@triton_autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
+ )
+ for bq in [32, 64]
+ for bk in [32, 64]
+ for num_warps in [4, 8]
+ for num_stages in [3, 4]
+ ],
+ key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
+ cache_results=True,
+)
+@triton.jit
+def _lse_and_store_logits_kernel(
+ Q,
+ K,
+ cu_q,
+ cu_k,
+ w_b, # int32 pointers
+ out_m,
+ out_S, # [B, Hk, ROWS_MAX] float32
+ LOGITS, # [Nk, Hk, ROWS_MAX] float32
+ sm_scale, # float
+ QUERY_GROUP_SIZE: tl.constexpr,
+ D: tl.constexpr,
+ STRIDE_Q_NQ,
+ STRIDE_Q_HQ,
+ STRIDE_K_NK,
+ STRIDE_K_HK,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ ROWS_MAX,
+):
+ # program ids
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+ rid = tl.program_id(2) # row-tile id
+ # batch segment bounds
+ q_end = tl.load(cu_q + b + 1)
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ # The observation window is the last `win` queries; prefix keys end where it starts.
+ q_win_beg = q_end - win
+ k_eff_end = k_end - win
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ # rows for this (b,hk)
+ rows_b = win * QUERY_GROUP_SIZE
+ row0 = rid * BLOCK_Q
+ if row0 >= rows_b:
+ return
+
+ # exp(x) = exp2(x * 1/ln2)
+ qk_scale = sm_scale * 1.4426950408889634
+
+ offs_qrow = row0 + tl.arange(0, BLOCK_Q)
+ row_mask = offs_qrow < rows_b
+
+ # map row -> (q_idx, hq_local)
+ hq_local = offs_qrow % QUERY_GROUP_SIZE
+ q_off = offs_qrow // QUERY_GROUP_SIZE
+ q_idx = q_win_beg + q_off
+ hq_glob = hk * QUERY_GROUP_SIZE + hq_local
+
+ offs_d = tl.arange(0, D)
+
+ # The last-dimension stride is taken to be 1 (no stride multiplies offs_d).
+ q_ptrs = (
+ Q
+ + q_idx[:, None] * STRIDE_Q_NQ
+ + hq_glob[:, None] * STRIDE_Q_HQ
+ + offs_d[None, :]
+ )
+ q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+ m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
+ S = tl.zeros([BLOCK_Q], dtype=tl.float32)
+
+ # Only prefix keys take part, so the softmax normaliser excludes the window keys.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
+ k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
+
+ s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
+ s = tl.where(kmask[None, :], s, -float("inf"))
+
+ # store into LOGITS[nk, hk, row] -> [BK, BQ]
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
+
+ # log2 streaming LSE update
+ cur_max = tl.max(s, 1) # [BQ]
+ n_m = tl.maximum(m, cur_max)
+ rescale = tl.math.exp2(m - n_m)
+ S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
+ m = n_m
+
+ # store m,S for these rows
+ m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
+ tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
+
+
+# Pass 2 of the original SnapKV scoring pipeline.
+# Grid: (batch b, KV head hk). For every prefix key it sums, over all window rows,
+# the probability exp2(logit - m) / S, optionally smooths the sums, and writes
+# them to OUT. Keys inside the window [k_end - win, k_end) are set to +inf.
+# Requests with win <= 0 or no prefix keys are skipped, leaving OUT untouched.
+# NOTE(review): the autotune key names "HK" and "HQ", which are not parameters of
+# this kernel - confirm the key has the intended effect.
+@triton_autotune(
+ configs=[
+ triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
+ for bq in [16, 32, 64]
+ for bk in [32, 64, 128]
+ ],
+ key=["HK", "HQ"],
+ cache_results=True,
+)
+@triton.jit
+def _scores_from_logits_kernel(
+ cu_k,
+ w_b,
+ in_m,
+ in_S, # [B, Hk, ROWS_MAX] f32
+ LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits
+ OUT, # [Nk, Hk] f32
+ #
+ QUERY_GROUP_SIZE: tl.constexpr,
+ STRIDE_M_B,
+ STRIDE_M_H,
+ STRIDE_M_R,
+ STRIDE_S_B,
+ STRIDE_S_H,
+ STRIDE_S_R,
+ STRIDE_LG_NK,
+ STRIDE_LG_HK,
+ STRIDE_LG_R,
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ BLOCK_Q: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ #
+ DO_POOL: tl.constexpr, # set True to enable in-place avg pool
+ KPOOL: tl.constexpr, # kernel size for avg pool (stride=1)
+):
+ b = tl.program_id(0)
+ hk = tl.program_id(1)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ if (win <= 0) or (k_eff_end <= k_beg):
+ return
+
+ rows_b = win * QUERY_GROUP_SIZE
+
+ # === scores over computed region ===
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+
+ scores = tl.zeros([BLOCK_K], dtype=tl.float32)
+
+ for row0 in tl.range(0, rows_b, BLOCK_Q):
+ r_idx = row0 + tl.arange(0, BLOCK_Q)
+ rmask = r_idx < rows_b
+
+ # load m, S for rows
+ m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
+ S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
+ m = tl.load(
+ m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
+ mask=rmask,
+ other=-float("inf"),
+ )
+ S = tl.load(
+ S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
+ )
+
+ # Substitute neutral m/S for rows without a positive sum so the division below
+ # stays finite; those rows are zeroed again after the division.
+ valid_row = S > 0
+ m = tl.where(valid_row, m, 0.0)
+ S = tl.where(valid_row, S, 1.0)
+
+ # load stored logits^T: [BK, BQ]
+ log_ptrs = (
+ LOGITS
+ + nk[:, None] * STRIDE_LG_NK
+ + hk * STRIDE_LG_HK
+ + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
+ )
+ s_T = tl.load(
+ log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
+ ) # [BK, BQ]
+
+ # probs^T = exp2(s_T - m) / S, sum over rows
+ probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
+ probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
+
+ scores += tl.sum(probs_T, 1) # [BK]
+
+ # Trailing average: position i averages positions j with i - KPOOL < j <= i.
+ # The band is built from tile-local indices, so the average never reaches into
+ # the previous BLOCK_K tile and the divisor shrinks at the start of each tile;
+ # the smoothed values therefore depend on the autotuned BLOCK_K.
+ if DO_POOL and (KPOOL > 1):
+ i = tl.arange(0, BLOCK_K)[:, None]
+ j = tl.arange(0, BLOCK_K)[None, :]
+ band = (j <= i) & ((i - j) < KPOOL)
+ band = band & kmask[None, :]
+ # sum within band
+ sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1) # [BK]
+ denom = tl.sum(band, 1).to(tl.float32) # [BK]
+ denom = tl.where(denom > 0, denom, 1.0)
+ scores = sums / denom
+
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(out_ptrs, scores, mask=kmask)
+
+ # Window keys receive +inf, the largest possible score.
+ pad_beg = k_eff_end
+ pad_end = k_end
+ if pad_end > pad_beg:
+ for ks in tl.range(pad_beg, pad_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < pad_end
+ out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
+ tl.store(
+ out_ptrs, tl.full([BLOCK_K], float("inf"), dtype=tl.float32), mask=kmask
+ )
+
+
+# Optional normalisation pass. Grid: (batch b,).
+# Standardises OUT in place per request: mean and variance are taken over all
+# prefix keys [k_beg, k_end - win) and all HK heads together, then every such
+# entry becomes (value - mean) / sqrt(var + EPS). Entries at or after
+# k_end - win are not touched.
+@triton_autotune(
+ configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
+ key=["HK"],
+ cache_results=True,
+)
+@triton.jit
+def _zscore_per_batch_epilogue(
+ OUT, # [Nk, Hk], float32
+ cu_k,
+ w_b, # [B+1], [B] int32
+ STRIDE_OUT_NK,
+ STRIDE_OUT_HK,
+ HK: tl.constexpr, # Hk
+ EPS: tl.constexpr, # e.g., 1e-12
+ BLOCK_K: tl.constexpr, # e.g., 128
+):
+ b = tl.program_id(0)
+
+ k_beg = tl.load(cu_k + b)
+ k_end = tl.load(cu_k + b + 1)
+ win = tl.load(w_b + b)
+
+ k_eff_end = k_end - win
+ if k_eff_end <= k_beg:
+ return
+
+ sumv = tl.zeros([], dtype=tl.float32)
+ sumsq = tl.zeros([], dtype=tl.float32)
+ count = ((k_eff_end - k_beg) * HK).to(tl.float32)
+
+ # First sweep: accumulate the sum and the sum of squares.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ sumv += tl.sum(vals, 0)
+ sumsq += tl.sum(vals * vals, 0)
+
+ # var = E[x^2] - E[x]^2, clamped at 0 against negative rounding error.
+ mean = sumv / count
+ var = tl.maximum(sumsq / count - mean * mean, 0.0)
+ invstd = 1.0 / tl.sqrt(var + EPS)
+
+ # Second sweep: rewrite every entry with its standardised value.
+ for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
+ nk = ks + tl.arange(0, BLOCK_K)
+ kmask = nk < k_eff_end
+ for h in tl.range(0, HK):
+ ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
+ vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
+ vals = (vals - mean) * invstd
+ tl.store(ptrs, vals, mask=kmask)
+
+
+def query_aware_key_scores(
+    q: torch.Tensor,  # [N_q, Hq, D]
+    k: torch.Tensor,  # [N_k, Hk, D]
+    cu_seqlens_q: torch.Tensor,  # [B+1], int32
+    cu_seqlens_k: torch.Tensor,  # [B+1], int32
+    w: torch.Tensor | int,  # [B], int32
+    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
+    *,
+    accum_scores: Optional[torch.Tensor] = None,
+    accum_blending: Optional[float] = None,
+    normalize: bool = False,
+) -> Optional[torch.Tensor]:
+    """Score every key by the attention it receives from the last ``w`` queries.
+
+    Pipeline: logits and softmax normalisers (``_lse_and_store_logits_kernel``)
+    → summed, pooled probabilities with a +inf tail
+    (``_scores_from_logits_kernel``) → optional per-request z-score
+    (``_zscore_per_batch_epilogue``).
+
+    Args:
+        q: Queries ``[N_q, Hq, D]`` with a contiguous last dimension.
+        k: Keys ``[N_k, Hk, D]`` with a contiguous last dimension; ``Hq`` must
+            be a multiple of ``Hk``.
+        cu_seqlens_q: Cumulative query offsets, ``B + 1`` entries.
+        cu_seqlens_k: Cumulative key offsets, ``B + 1`` entries.
+        w: Window length, one int for all requests or a ``[B]`` tensor.
+        sm_scale: Softmax scale; ``1 / sqrt(D)`` when ``None``.
+        accum_scores: Optional ``[N_k, Hk]`` buffer updated in place.
+        accum_blending: Factor applied to ``accum_scores`` before adding.
+        normalize: Whether to z-score the prefix scores per request.
+
+    Returns:
+        ``accum_scores`` (scaled by ``accum_blending`` if given, plus the new
+        scores) when it is provided, otherwise a new float32 ``[N_k, Hk]``
+        tensor. Requests that the kernels skip keep a score of 0.
+    """
+    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
+    device = q.device
+    _, Hq, D = q.shape
+    N_k, Hk, Dk = k.shape
+    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
+    # The kernel reads K rows with Q's head dim, so a mismatch would read
+    # out of bounds instead of failing.
+    assert Dk == D, "q and k must share the same head dimension"
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    B = cu_seqlens_q.numel() - 1
+    assert B == cu_seqlens_k.numel() - 1
+
+    G = Hq // Hk
+    if isinstance(w, int):
+        max_w = w
+        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
+    else:
+        max_w = int(w.max().item())
+    assert w.numel() == B
+    ROWS_MAX = max_w * G
+
+    # Zero-initialised: the scoring kernel returns early for requests that are
+    # not longer than their window, which would otherwise leave uninitialised
+    # memory in the returned scores.
+    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
+    if ROWS_MAX != 0:
+        m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+        S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
+        logits_buf = torch.empty(
+            (N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device
+        )
+
+        # Stride keyword groups, named exactly as the kernel signatures expect.
+        STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
+        STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
+        m_strides = dict(
+            zip(("STRIDE_M_B", "STRIDE_M_H", "STRIDE_M_R"), m_scratch.stride())
+        )
+        s_strides = dict(
+            zip(("STRIDE_S_B", "STRIDE_S_H", "STRIDE_S_R"), S_scratch.stride())
+        )
+        lg_strides = dict(
+            zip(("STRIDE_LG_NK", "STRIDE_LG_HK", "STRIDE_LG_R"), logits_buf.stride())
+        )
+        STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
+
+        def grid(META):
+            return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
+
+        _lse_and_store_logits_kernel[grid](
+            q,
+            k,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            w,
+            m_scratch,
+            S_scratch,
+            logits_buf,
+            sm_scale,
+            QUERY_GROUP_SIZE=G,
+            D=D,
+            STRIDE_Q_NQ=STRIDE_Q_NQ,
+            STRIDE_Q_HQ=STRIDE_Q_HQ,
+            STRIDE_K_NK=STRIDE_K_NK,
+            STRIDE_K_HK=STRIDE_K_HK,
+            **m_strides,
+            **s_strides,
+            **lg_strides,
+            ROWS_MAX=ROWS_MAX,
+        )
+
+        _scores_from_logits_kernel[(B, Hk)](
+            cu_seqlens_k,
+            w,
+            m_scratch,
+            S_scratch,
+            logits_buf,
+            out,
+            QUERY_GROUP_SIZE=G,
+            **m_strides,
+            **s_strides,
+            **lg_strides,
+            STRIDE_OUT_NK=STRIDE_OUT_NK,
+            STRIDE_OUT_HK=STRIDE_OUT_HK,
+            DO_POOL=True,
+            KPOOL=5,
+        )
+        if normalize:
+            _zscore_per_batch_epilogue[(B,)](
+                out,
+                cu_seqlens_k,
+                w,
+                STRIDE_OUT_NK,
+                STRIDE_OUT_HK,
+                HK=Hk,
+                EPS=1e-12,
+            )
+
+    # With an empty window `out` stays all zero, and the accumulator is still
+    # blended and returned, the same contract as the non-empty path.
+    if accum_scores is None:
+        return out
+    if accum_blending is not None:
+        accum_scores.mul_(accum_blending)
+    accum_scores.add_(out)
+    return accum_scores
diff --git a/vllm/kvprune_legacy_save/config/__init__.py b/vllm/kvprune_legacy_save/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..717459650025a6551cdc91bb5136c450984eaca6
--- /dev/null
+++ b/vllm/kvprune_legacy_save/config/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Engine / sampling / kernel constants (compactor-compatible)."""
+
+from vllm.kvprune.config.constants import RESERVED_BATCH, TRITON_RESERVED_BATCH
+
+__all__ = ["RESERVED_BATCH", "TRITON_RESERVED_BATCH"]
diff --git a/vllm/kvprune_legacy_save/config/constants.py b/vllm/kvprune_legacy_save/config/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ff12d82a54466fd9f7264bcc81f5ba4653e65b3
--- /dev/null
+++ b/vllm/kvprune_legacy_save/config/constants.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Batch index that is set aside (reserved) by the engine.
RESERVED_BATCH = 0

# Value handed to Triton kernels for the reserved batch index.
# It is deliberately a plain Python int: ``tl.constexpr`` belongs in kernel
# signatures/annotations, and some Triton builds reject ``tl.constexpr(...)``
# objects passed as constexpr values. The kernel signature declares the
# parameter constexpr; the runtime value stays an ordinary int.
TRITON_RESERVED_BATCH = RESERVED_BATCH
diff --git a/vllm/kvprune_legacy_save/config/engine_config.py b/vllm/kvprune_legacy_save/config/engine_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab165ffb92b93a21643f6270501bf3c61fe280dd
--- /dev/null
+++ b/vllm/kvprune_legacy_save/config/engine_config.py
@@ -0,0 +1,129 @@
+import os
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import List, Optional
+
+from transformers import AutoConfig
+
+
class AttentionBackend(Enum):
    """Legacy coarse backend toggle (prefer :class:`KvpruneAttentionSchedule`)."""

    FLASH_ATTENTION = auto()
    COMPACTOR_TRITON = auto()


class KvpruneAttentionSchedule(Enum):
    """FlashAttention vs Triton split for prefill / decode (KV **writes** stay Triton)."""

    # Default: FA varlen prefill; decode uses ``head_sparse_decode_attention`` (Triton).
    FA_PREFILL_TRITON_DECODE = auto()
    # Prefill attention uses ``causal_sparse_varlen_with_cache`` (Triton); decode Triton.
    TRITON_PREFILL_TRITON_DECODE = auto()
    # "PDFA": FA prefill + FA decode; paged KV **storage** (incl. pruned top-k) unchanged.
    PDFA = auto()


@dataclass
class LLMConfig:
    """Configuration for the :class:`LLM` engine.

    Parameters
    ----------
    model : str
        Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
        a local model name that can be resolved by
        :func:`transformers.AutoConfig.from_pretrained`.
    path : str, optional
        Local directory containing the model weights. If ``None``, the engine
        will attempt to resolve a local snapshot for ``model`` using
        :func:`huggingface_hub.snapshot_download`.
    nccl_port : int, optional, default 1218
        NOTE(review): not consumed in this module; presumably the rendezvous
        port used by the tensor-parallel workers. Confirm against the runner.
    max_num_seqs : int, default 256
        Upper bound on the number of concurrent batches that the scheduler and
        KV-cache manager are allowed to handle. This affects the size of the
        page table and some internal buffers.
    max_model_len : int, default 40960
        Maximum context length (in tokens) that the engine will allocate KV cache
        and CUDA graphs for. During initialization this value is clamped to
        ``hf_config.max_position_embeddings`` for the chosen model.
    gpu_memory_utilization : float, default 0.9
        Fraction of the total GPU memory that may be used for KV cache and model
        activations. Values should be in ``(0, 1]``. If this budget is too small,
        the KV-cache manager may raise an error at warmup time due
        to insufficient memory.
    tensor_parallel_size : int, default 1
        Number of tensor-parallel workers to shard the model
        across. Must be between 1 and 8, and must evenly divide the model's
        number of key/value heads.
    enforce_eager : bool, default False
        If ``True``, disable CUDA graph capture and always run the model in
        eager mode during decoding. This reduces throughput. When ``False``,
        the engine will capture and reuse CUDA graphs for supported
        batch sizes and sequence lengths.
    hf_config : transformers.AutoConfig, optional
        Pre-loaded Hugging Face configuration for the model. If ``None``,
        it is populated in ``__post_init__`` via
        ``AutoConfig.from_pretrained(model)``.
    eos : int, default -1
        Primary stop token id (warmup / single-id paths). If ``-1``, the
        :class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
        tokenizer.
    eos_token_ids : list of int, optional
        All token ids that terminate generation (e.g. HF tokenizers may expose
        ``eos_token_id`` as a list for chat models). If ``None``, inferred in
        :class:`LLM` from the tokenizer and model type.
    kvcache_page_size : int, default 128
        Number of tokens stored in a single KV-cache page. Smaller pages improve
        allocation flexibility but increase page-table overhead; larger pages
        reduce overhead but have coarser granularity.
    leverage_sketch_size : int, default 48
        Sketch dimension used by the Compactor leverage-score estimator.
    attention_schedule : KvpruneAttentionSchedule, default FA_PREFILL_TRITON_DECODE
        Which **attention** implementation runs on prefill vs decode. KV **writes**
        (``prefill_store_*``, ``decode_store_kv``, pruned top-k) always use the
        existing Triton store kernels. Env ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` uses
        short names: ``fa_triton`` (default), ``pdtriton``, ``pdfa``. Enum values:
        ``FA_PREFILL_TRITON_DECODE`` — FA prefill, Triton decode;
        ``TRITON_PREFILL_TRITON_DECODE`` — Triton prefill + decode;
        ``PDFA`` — FA prefill + FA decode (still Triton KV I/O).
    attention_backend : AttentionBackend, optional
        Deprecated. When not ``None`` it **overrides** ``attention_schedule``:
        ``FLASH_ATTENTION`` maps to ``FA_PREFILL_TRITON_DECODE`` and any other
        value maps to ``TRITON_PREFILL_TRITON_DECODE``.
    show_progress_bar : bool, default True
        NOTE(review): not consumed in this module; presumably toggles progress
        reporting during generation. Confirm against the engine.

    Raises
    ------
    NotADirectoryError
        If ``path`` is given but is not an existing directory.
    ValueError
        If ``tensor_parallel_size`` is outside ``[1, 8]``.
    """

    model: str
    path: Optional[str] = None
    nccl_port: Optional[int] = 1218
    max_num_seqs: int = 256
    max_model_len: int = 40960
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: AutoConfig | None = None
    eos: int = -1
    eos_token_ids: Optional[List[int]] = None
    kvcache_page_size: int = 128
    leverage_sketch_size: int = 48
    attention_schedule: KvpruneAttentionSchedule = (
        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
    )
    attention_backend: AttentionBackend | None = None
    show_progress_bar: bool = True

    def __post_init__(self):
        """Validate the fields and derive the values that depend on the HF config."""
        # Backward compatibility: the deprecated coarse toggle wins over the
        # schedule field whenever the caller still supplies it.
        if self.attention_backend is not None:
            if self.attention_backend == AttentionBackend.FLASH_ATTENTION:
                self.attention_schedule = (
                    KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
                )
            else:
                self.attention_schedule = (
                    KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
                )
        if self.path is not None and not os.path.isdir(self.path):
            raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
        # Input validation must raise, not assert: an assert placed in this
        # branch would always fail first (masking the ValueError) and would be
        # stripped entirely under ``python -O``.
        if not 1 <= self.tensor_parallel_size <= 8:
            raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
        if self.hf_config is None:
            self.hf_config = AutoConfig.from_pretrained(self.model)
        # Never allocate for more positions than the model supports.
        self.max_model_len = min(
            self.max_model_len, self.hf_config.max_position_embeddings
        )
+
diff --git a/vllm/kvprune_legacy_save/config/sampling_params.py b/vllm/kvprune_legacy_save/config/sampling_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..8202ad67d07ed082822eedcc926e3fa85cf40234
--- /dev/null
+++ b/vllm/kvprune_legacy_save/config/sampling_params.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+
+
@dataclass
class SamplingParams:
    """Per-request sampling settings.

    Construction fails with ``ValueError`` when ``temperature`` is negative.
    """

    # Sampling temperature; must not be negative.
    temperature: float = 1.0
    # Default budget of newly generated tokens.
    max_new_tokens: int = 256

    def __post_init__(self):
        """Reject a negative temperature right after field assignment."""
        temperature_is_negative = self.temperature < 0
        if temperature_is_negative:
            raise ValueError("Temperature cannot be negative")
diff --git a/vllm/kvprune_legacy_save/core/__init__.py b/vllm/kvprune_legacy_save/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaaa9491a8294328ccca53ffeac6a32422ebda51
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/__init__.py
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Core: compactor ``LLMEngine`` stack (``llm_engine``, ``scheduler``, …) plus helpers
+(``runtime``, ``flash_integration``, ``block_budget``) used **inside** the compactor path.
+
+v1 does not import these; use :meth:`vllm.LLM.generate` with ``compression=`` for the
+``LLM`` + compactor integration.
+"""
+
+from vllm.kvprune.core.block_budget import (
+ TailReclaimHint,
+ build_tail_reclaim_hint,
+ tail_blocks_if_logical_shorter,
+)
+from vllm.kvprune.core.compression_bridge import (
+ VALID_ALIASES_FOR_SAMPLING,
+ compression_method_id_to_enum,
+ compression_method_str_to_id,
+)
+from vllm.kvprune.core.flash_integration import (
+ do_kv_cache_update_kv_prune,
+ merge_seq_lens_with_kv_prune,
+)
+from vllm.kvprune.core.runtime import (
+ KVPruneForwardState,
+ build_kv_prune_forward_state,
+ get_kv_prune_state,
+ layer_index_from_layer_name,
+)
+
+__all__ = [
+ "KVPruneForwardState",
+ "TailReclaimHint",
+ "VALID_ALIASES_FOR_SAMPLING",
+ "build_kv_prune_forward_state",
+ "build_tail_reclaim_hint",
+ "compression_method_id_to_enum",
+ "compression_method_str_to_id",
+ "do_kv_cache_update_kv_prune",
+ "get_kv_prune_state",
+ "layer_index_from_layer_name",
+ "merge_seq_lens_with_kv_prune",
+ "tail_blocks_if_logical_shorter",
+]
diff --git a/vllm/kvprune_legacy_save/core/block_budget.py b/vllm/kvprune_legacy_save/core/block_budget.py
new file mode 100644
index 0000000000000000000000000000000000000000..e88233e9c8b94fb36c80313fe86e89362caa3a3a
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/block_budget.py
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Block budget helpers for compactor KV pruning (logical vs physical length).
+
+Used by the **compactor** ``LLMEngine`` path (``PagedKVCache`` / logical lengths),
+not by v1's scheduler. The helpers compare logical KV length to a physical token
+count and return how many full tail blocks can be reclaimed when logical shrinks.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
@dataclass(frozen=True)
class TailReclaimHint:
    """Immutable record of how many tail blocks a request could give back.

    It pairs the inputs of the tail-block computation (allocated vs logical
    token counts and the block size) with the resulting block count.
    """

    # Identifier of the request the hint refers to.
    request_id: str
    # Number of tokens the current allocation covers.
    allocated_tokens: int
    # Number of tokens that are logically retained.
    logical_tokens: int
    # Tokens per block.
    block_size: int
    # Count of whole tail blocks that could be freed.
    reclaimable_tail_blocks: int
+
+
def tail_blocks_if_logical_shorter(
    allocated_tokens: int,
    logical_tokens: int,
    block_size: int,
) -> int:
    """Count the whole tail blocks left unused once ``logical < allocated``.

    The result is block-granular: only blocks lying entirely past the block
    that holds the last retained logical token are counted.

    Returns ``0`` when ``block_size`` is not positive or when the logical
    length is not shorter than the allocation. When nothing is retained
    (``logical_tokens <= 0``) every allocated block is counted.
    """
    nothing_to_reclaim = block_size <= 0 or logical_tokens >= allocated_tokens
    if nothing_to_reclaim:
        return 0

    if logical_tokens <= 0:
        # No retained token: all blocks of the allocation (ceiling division).
        return -(-allocated_tokens // block_size)

    # Index of the block holding the last token, for each of the two lengths.
    final_allocated_block = (allocated_tokens - 1) // block_size
    final_logical_block = (logical_tokens - 1) // block_size
    spare_blocks = final_allocated_block - final_logical_block
    return spare_blocks if spare_blocks > 0 else 0
+
+
def build_tail_reclaim_hint(
    request_id: str,
    allocated_tokens: int,
    logical_tokens: int,
    block_size: int,
) -> TailReclaimHint:
    """Bundle the tail-block computation and its inputs into a ``TailReclaimHint``.

    The reclaimable count comes from :func:`tail_blocks_if_logical_shorter`;
    the other fields are the arguments, stored unchanged.
    """
    reclaimable = tail_blocks_if_logical_shorter(
        allocated_tokens, logical_tokens, block_size
    )
    return TailReclaimHint(
        request_id,
        allocated_tokens,
        logical_tokens,
        block_size,
        reclaimable,
    )
+
+
+__all__ = [
+ "TailReclaimHint",
+ "build_tail_reclaim_hint",
+ "tail_blocks_if_logical_shorter",
+]
diff --git a/vllm/kvprune_legacy_save/core/compression_bridge.py b/vllm/kvprune_legacy_save/core/compression_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..43046dccf82b7233d7169d81bc299616ff3314f8
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/compression_bridge.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Map compression method strings (e.g. from :class:`~vllm.kvprune.integration.CompressionParams`) to kvprune GPU / enum IDs."""
+
+from __future__ import annotations
+
+from vllm.kvprune.compression.compression_config import CompressionMethod
+
# Integer ids as stored on device in an int32 buffer of shape [num_reqs_padded].
# Kernels depend on these exact values, so the numbering must stay stable.
COMPRESSION_METHOD_ID_NONE = 0
COMPRESSION_METHOD_ID_CRITICALADAKV = 1
COMPRESSION_METHOD_ID_COMPACTOR = 2
COMPRESSION_METHOD_ID_SNAPKV = 3

# Normalized (stripped, lower-cased) method string -> device id.
_STR_TO_ID: dict[str, int] = dict(
    none=COMPRESSION_METHOD_ID_NONE,
    criticaladakv=COMPRESSION_METHOD_ID_CRITICALADAKV,
    compactor=COMPRESSION_METHOD_ID_COMPACTOR,
    snapkv=COMPRESSION_METHOD_ID_SNAPKV,
)

# Method strings accepted from callers (compared case-insensitively after strip).
VALID_ALIASES_FOR_SAMPLING: frozenset[str] = frozenset(_STR_TO_ID)

# Device id -> ``CompressionMethod`` enum member.
_ID_TO_COMPRESSION_METHOD: dict[int, CompressionMethod] = {
    COMPRESSION_METHOD_ID_NONE: CompressionMethod.NONE,
    COMPRESSION_METHOD_ID_CRITICALADAKV: CompressionMethod.CRITICALADAKV,
    COMPRESSION_METHOD_ID_COMPACTOR: CompressionMethod.COMPACTOR,
    COMPRESSION_METHOD_ID_SNAPKV: CompressionMethod.SNAPKV,
}
+
+
def compression_method_str_to_id(s: str) -> int:
    """Translate a user-supplied method string into its stable int id (0..3).

    The input is stripped and lower-cased first; an empty or ``None`` value is
    treated as ``"none"``.

    Raises:
        ValueError: if the normalized string is not a known method name.
    """
    normalized = (s or "none").strip().lower()
    method_id = _STR_TO_ID.get(normalized)
    if method_id is None:
        raise ValueError(
            f"Unknown compression_method {s!r}; expected one of "
            f"{sorted(VALID_ALIASES_FOR_SAMPLING)}"
        )
    return method_id
+
+
def compression_method_id_to_enum(method_id: int) -> CompressionMethod:
    """Map a device method id to its enum member; unknown ids give ``NONE``."""
    return _ID_TO_COMPRESSION_METHOD.get(method_id, CompressionMethod.NONE)
+
+
+__all__ = [
+ "COMPRESSION_METHOD_ID_NONE",
+ "COMPRESSION_METHOD_ID_CRITICALADAKV",
+ "COMPRESSION_METHOD_ID_COMPACTOR",
+ "COMPRESSION_METHOD_ID_SNAPKV",
+ "VALID_ALIASES_FOR_SAMPLING",
+ "compression_method_id_to_enum",
+ "compression_method_str_to_id",
+]
diff --git a/vllm/kvprune_legacy_save/core/flash_integration.py b/vllm/kvprune_legacy_save/core/flash_integration.py
new file mode 100644
index 0000000000000000000000000000000000000000..247c7ea9e45f4011d5177a25683e14416b0368a2
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/flash_integration.py
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""FlashAttention + KV cache hooks for kvprune."""
+
+from __future__ import annotations
+
+import torch
+
+from vllm.kvprune.core.runtime import KVPruneForwardState, get_kv_prune_state
+
+_RATIO_ONE = 1.0 - 1e-6
+
+
def merge_seq_lens_with_kv_prune(
    base_seq_lens: torch.Tensor,
    layer_name: str,
    max_query_len: int,
) -> torch.Tensor:
    """Substitute per-layer logical lengths for scheduler lengths where pruning applies.

    ``base_seq_lens`` is returned untouched when no kvprune state is active or
    when ``max_query_len > 1`` (prefill). Otherwise a clone is returned in
    which the first ``state.num_reqs`` entries are replaced by the layer's
    logical length for every request whose compression ratio is below
    ``_RATIO_ONE``.
    """
    state = get_kv_prune_state()
    # Prefill keeps the scheduler lengths: logical buffers are only reliable
    # once the compactor store has run for every layer, and may still be zero
    # when pruning was requested but ineligible (e.g. unsupported dtype).
    if state is None or max_query_len > 1:
        return base_seq_lens

    layer = _layer_idx(layer_name)
    active = state.num_reqs
    ratios = state.compression_ratio_gpu[:active]
    layer_lens = state.logical_seq_lens_gpu[layer, :active]
    if layer_lens.dim() == 2:
        # A second axis is present: reduce it to its maximum per request.
        layer_lens = layer_lens.max(dim=-1).values

    merged = base_seq_lens.clone()
    merged[:active] = torch.where(
        ratios < _RATIO_ONE,
        layer_lens.to(merged.dtype),
        base_seq_lens[:active],
    )
    return merged
+
+
def _layer_idx(layer_name: str) -> int:
    """Resolve a layer name to its index via ``layer_index_from_layer_name``."""
    # The helper is imported at call time rather than at module import.
    from vllm.kvprune.core.runtime import layer_index_from_layer_name as _resolve

    return _resolve(layer_name)
+
+
def do_kv_cache_update_kv_prune(
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    reshape_and_cache_flash,
    kv_cache_dtype: str,
) -> bool:
    """Run the kvprune KV-cache write for this step, if kvprune is active.

    Returns ``True`` when kvprune handled the update (the caller should skip
    its default path) and ``False`` otherwise: no active state, or a prefill
    step for which ``try_prefill_kv_store`` declined.

    On a non-prefill step the regular ``reshape_and_cache_flash`` write is
    issued, and then the layer's logical length is incremented in place by one
    for every request whose compression ratio is below ``_RATIO_ONE``.
    """
    state = get_kv_prune_state()
    if state is None:
        return False

    layer_idx = _layer_idx(layer.layer_name)
    num_reqs = state.num_reqs

    if state.is_prefill:
        from vllm.kvprune.compression.prefill import try_prefill_kv_store

        stored = try_prefill_kv_store(layer, key, value, kv_cache)
        return True if stored else False

    # kv_cache is split along dim 0 into its key and value halves.
    key_cache, value_cache = kv_cache.unbind(0)
    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    # 1 for requests being pruned, 0 for the rest.
    step = (state.compression_ratio_gpu[:num_reqs] < _RATIO_ONE).to(torch.int32)
    lens_view = state.logical_seq_lens_gpu[layer_idx, :num_reqs]
    if lens_view.dim() == 2:
        # Broadcast the per-request increment over the trailing axis.
        step = step.unsqueeze(-1)
    # In-place add so the update lands in the state's buffer.
    lens_view += step
    return True
diff --git a/vllm/kvprune_legacy_save/core/llm_engine.py b/vllm/kvprune_legacy_save/core/llm_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..0813cbcabc2a36f92be188a13c01c39672c36c2d
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/llm_engine.py
@@ -0,0 +1,441 @@
+from __future__ import annotations
+
+import atexit
+import inspect
+import logging
+from pathlib import Path
+from typing import Any, List, Optional, Union
+
+import torch.nn as nn
+import torch.multiprocessing as mp
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.config.sampling_params import SamplingParams
+from vllm.kvprune.core.model_runner import ModelRunner
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.kvprune.utils.sequence import Sequence
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+PromptLike = Union[str, List[int]]
+
+
+def _infer_stop_token_ids(tokenizer, hf_config) -> list[int]:
+ """
+ Build the set of token ids that should end generation.
+
+ Newer HF chat tokenizers often expose ``eos_token_id`` as a *list* of ids.
+ The engine must not compare generated ids to that list as a single ``int``;
+ see :attr:`LLMConfig.eos_token_ids` and decode-time ``torch.isin``.
+
+ Qwen chat uses ```` (im_end) as the assistant turn boundary; include it
+ when present in ``additional_special_tokens`` / ``added_tokens_encoder``. We
+ avoid loose substring matches like ``\"end\"`` that can tag unrelated tokens.
+ """
+ raw = tokenizer.eos_token_id
+ ids: list[int] = []
+ if isinstance(raw, (list, tuple)):
+ ids.extend(int(x) for x in raw)
+ elif raw is not None:
+ ids.append(int(raw))
+ unk_id = getattr(tokenizer, "unk_token_id", None)
+
+ def _maybe_add_tid(tid: int) -> None:
+ if not isinstance(tid, int) or tid < 0:
+ return
+ if unk_id is not None and tid == unk_id:
+ return
+ if tid not in ids:
+ ids.append(tid)
+
+ model_type = getattr(hf_config, "model_type", None)
+ if model_type in ("qwen2", "qwen3", "qwen2_moe", "qwen3_moe"):
+ enc = getattr(tokenizer, "added_tokens_encoder", None)
+ if isinstance(enc, dict):
+ for key, tid in enc.items():
+ if isinstance(key, str) and "im_end" in key:
+ _maybe_add_tid(int(tid))
+ for extra in getattr(tokenizer, "additional_special_tokens", []) or []:
+ if not isinstance(extra, str) or "im_end" not in extra:
+ continue
+ try:
+ tid = tokenizer.convert_tokens_to_ids(extra)
+ except (TypeError, ValueError, KeyError):
+ continue
+ _maybe_add_tid(tid)
+
+ if not ids:
+ raise ValueError(
+ "Could not infer stop token ids from the tokenizer; set "
+ "LLMConfig(eos_token_ids=[...]) explicitly."
+ )
+ return ids
+
+
+def _merge_apply_chat_template_kwargs(
+ tokenizer,
+ user_kwargs: Optional[dict[str, Any]],
+) -> dict[str, Any]:
+ """
+ Merge user kwargs with defaults for HF chat templates that support them.
+
+ Qwen3 (and similar) instruct models expect `add_generation_prompt=True` so
+ the first generated token continues the assistant turn; without it, output
+ can repeat punctuation / template fragments. `enable_thinking=False` avoids
+ the Qwen3 reasoning channel when the tokenizer supports it.
+ """
+ out = dict(user_kwargs or {})
+ try:
+ sig = inspect.signature(tokenizer.apply_chat_template)
+ except (TypeError, ValueError):
+ return out
+ if "add_generation_prompt" in sig.parameters and "add_generation_prompt" not in out:
+ out["add_generation_prompt"] = True
+ if "enable_thinking" in sig.parameters and "enable_thinking" not in out:
+ out["enable_thinking"] = False
+ return out
+
+
def _runner_entry(config: LLMConfig, rank: int, evt):
    """Process target for a non-zero rank: build a ``ModelRunner`` and run its loop.

    Any ``Exception`` raised while constructing or running the runner is
    logged with its traceback instead of propagating. ``runner.exit()`` is
    always attempted when construction succeeded.
    """
    worker = None
    try:
        worker = ModelRunner(config, rank, evt)
        worker.loop()
    except Exception as exc:
        logging.exception(f"Rank {rank}: {repr(exc)}")
    finally:
        # ``worker`` is still None when the constructor itself failed.
        if worker is not None:
            worker.exit()
+
+
class LLMEngine:
    """High-level engine coordinating model runners and scheduling.

    Rank 0 runs in the calling process as ``master_model_runner``. For
    ``tensor_parallel_size > 1`` the remaining ranks are started as daemon
    processes (``spawn`` context) running :func:`_runner_entry`.
    """

    def __init__(self, config: LLMConfig, external_model: nn.Module | None = None):
        """Resolve the model location, load the tokenizer and start the runners.

        Args:
            config: Engine configuration. It is mutated in place: ``path``,
                ``eos`` and ``eos_token_ids`` are filled in / normalized.
            external_model: Optional pre-built module forwarded to the rank-0
                ``ModelRunner``; only allowed with ``tensor_parallel_size == 1``.

        Raises:
            ValueError: if the model type is not in ``MODEL_REGISTRY``, or if
                ``external_model`` is combined with tensor parallelism.
        """
        self.config = config
        if self.config.hf_config.model_type not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model {self.config.model}")
        if config.path is None:
            # Local directory: use it directly (no Hub round-trip).
            # The local must NOT be called ``mp``: binding that name anywhere
            # in this function would make it function-local and shadow the
            # module-level ``torch.multiprocessing`` alias used further down.
            try:
                model_dir = Path(config.model)
                if model_dir.is_dir() and (model_dir / "config.json").is_file():
                    self.config.path = str(model_dir.resolve())
                    logger.info(
                        "Using local model directory for tokenizer: %s",
                        self.config.path,
                    )
            except OSError:
                # Best effort only: an unusable path falls through to the
                # Hub lookup below.
                pass
        if config.path is None:
            from huggingface_hub import snapshot_download

            # Hub repo id: allow downloading missing shards/tokenizer files when cache
            # is incomplete (local_files_only=False). Local dirs are handled above.
            self.config.path = snapshot_download(
                repo_id=config.model,
                local_files_only=False,
            )
            logger.info(
                "Resolved Hugging Face snapshot for %s @ %s",
                self.config.model,
                self.config.path,
            )
        assert self.config.path is not None
        _trust = bool(getattr(self.config.hf_config, "trust_remote_code", False))
        # Always load tokenizer from the resolved on-disk tree so we do not re-hit
        # the Hub with the repo id (can re-download tokenizer / LFS shards).
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.path,
            use_fast=True,
            trust_remote_code=_trust,
        )

        # Normalize the stop ids: an explicit list wins, then an explicit
        # single ``eos``, and only then inference from the tokenizer.
        if self.config.eos_token_ids is None:
            if self.config.eos != -1:
                self.config.eos_token_ids = [int(self.config.eos)]
            else:
                self.config.eos_token_ids = _infer_stop_token_ids(
                    self.tokenizer, self.config.hf_config
                )
        else:
            self.config.eos_token_ids = [int(x) for x in self.config.eos_token_ids]
        self.config.eos_token_ids = sorted(set(self.config.eos_token_ids))
        if self.config.eos == -1:
            self.config.eos = int(self.config.eos_token_ids[0])
        else:
            self.config.eos = int(self.config.eos)
            # Keep the invariant that the primary id is part of the stop set.
            if self.config.eos not in self.config.eos_token_ids:
                self.config.eos_token_ids = sorted(
                    self.config.eos_token_ids + [self.config.eos]
                )

        if external_model is not None and int(self.config.tensor_parallel_size) != 1:
            raise ValueError(
                "external_model (shared-weight compactor path) only supports "
                "tensor_parallel_size=1"
            )

        self.ps = []
        world_size = int(self.config.tensor_parallel_size)
        self.events = []
        if world_size > 1:
            ctx = mp.get_context("spawn")
            # Rank 0 lives in this process; spawn ranks 1..world_size-1.
            for r in range(1, world_size):
                event = ctx.Event()
                p = ctx.Process(
                    target=_runner_entry,
                    args=(self.config, r, event),
                    daemon=True,
                )
                p.start()
                self.ps.append(p)
                self.events.append(event)

        self.master_model_runner = ModelRunner(
            self.config,
            rank=0,
            peer_events=self.events,
            external_model=external_model,
        )
        atexit.register(self.exit)

    def exit(self):
        """Shut down the master runner and worker processes (idempotent)."""
        if getattr(self, "_exited", False):
            return
        self._exited = True
        runner = getattr(self, "master_model_runner", None)
        if runner is not None:
            try:
                runner.exit()
            except Exception:
                # Keep going so the worker processes are still cleaned up.
                logger.exception("Failed to exit master ModelRunner cleanly")
        for p in getattr(self, "ps", []):
            if p.is_alive():
                p.terminate()
            # Bounded join so shutdown cannot hang on a stuck worker.
            p.join(timeout=1.0)
        if hasattr(self, "events"):
            self.events.clear()

    def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
        """
        Turn a raw prompt into token IDs.

        String prompts go through the tokenizer (``tokenizer_kwargs`` are
        forwarded); anything else is treated as pre-tokenized ids and copied
        into a new list.
        """
        if isinstance(prompt, str):
            return self.tokenizer(prompt, **tokenizer_kwargs)["input_ids"]
        return list(prompt)

    def detokenize_prompt(
        self, sequences: List[Sequence], **detokenizer_kwargs
    ) -> List[str]:
        """
        Turn completed Sequences into strings.

        Decodes each sequence's ``completion_token_ids`` with
        ``tokenizer.batch_decode``. ``skip_special_tokens=True`` is the default
        and can be overridden through ``detokenizer_kwargs``.
        """
        merged: dict[str, Any] = {"skip_special_tokens": True, **detokenizer_kwargs}
        return self.tokenizer.batch_decode(
            [s.completion_token_ids for s in sequences], **merged
        )

    def _build_sequences(
        self,
        prompts: List[PromptLike] | PromptLike,
        sampling_params: SamplingParams | List[SamplingParams],
        per_sequence_compression_params: Optional[
            SequenceCompressionParams | List[SequenceCompressionParams]
        ] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
    ) -> List[Sequence]:
        """
        Build Sequence objects from prompts, sampling params, and optional
        per-sequence compression parameters.

        A single ``SamplingParams`` / ``SequenceCompressionParams`` is
        broadcast to every prompt; lists must match ``prompts`` in length.
        With no compression params, each sequence gets
        ``SequenceCompressionParams(1.0)``.

        When a prompt is not longer than its protected regions
        (``protected_first_tokens + protected_last_tokens >= len(token_ids)``)
        that sequence is given a private copy of its params with
        ``compression_ratio = 1.0``. The caller's object, and any other
        sequence sharing it, is left untouched.
        """
        import copy

        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs

        if not isinstance(prompts, list):
            prompts = [prompts]

        if isinstance(sampling_params, SamplingParams):
            sampling_params_list: List[SamplingParams] = [sampling_params] * len(
                prompts
            )
        else:
            sampling_params_list = sampling_params
        assert len(sampling_params_list) == len(prompts), (
            "sampling_params list must match prompts length"
        )
        if per_sequence_compression_params is None:
            compression_params_list: List[SequenceCompressionParams] = [
                SequenceCompressionParams(1.0) for _ in prompts
            ]
        elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
            compression_params_list = [per_sequence_compression_params] * len(prompts)
        else:
            # list-like
            assert len(per_sequence_compression_params) == len(prompts), (
                "per_sequence_compression_params list must match prompts length"
            )
            compression_params_list = list(per_sequence_compression_params)

        seqs: List[Sequence] = []
        for prompt, sparams, cparams in zip(
            prompts, sampling_params_list, compression_params_list
        ):
            token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
            protected = cparams.protected_first_tokens + cparams.protected_last_tokens
            if protected >= len(token_ids):
                # Nothing outside the protected regions: disable compression
                # for this sequence only. Copy first, because the same params
                # object may be shared by other sequences and by the caller.
                cparams = copy.copy(cparams)
                cparams.compression_ratio = 1.0
            seqs.append(
                Sequence(
                    prompt_token_ids=token_ids,
                    sampling_params=sparams,
                    compression_params=cparams,
                )
            )
        return seqs

    def generate(
        self,
        prompts: List[PromptLike] | PromptLike,
        sampling_params: SamplingParams | List[SamplingParams],
        batch_compression_params: BatchCompressionParams,
        *,
        per_sequence_compression_params: Optional[
            Union[List[SequenceCompressionParams], SequenceCompressionParams]
        ] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        detokenizer_kwargs: Optional[dict[str, Any]] = None,
        return_sequences: bool = False,
    ) -> List[str] | tuple[List[str], List[Sequence]]:
        """
        Generate completions for the given prompts.
        Args:
            :param prompts:
                Single prompt or list of prompts, each either a raw text prompt,
                or pre-tokenized input IDs.
            :param sampling_params:
                A single SamplingParams for all prompts in this batch or a list of
                SamplingParams with the same length as ``prompts``.
            :param batch_compression_params:
                Compression settings for this batch.
            :param per_sequence_compression_params:
                Per-sequence compression parameters, including the compression
                ratio to be applied and the size of the protected regions of the
                sequence (how many start tokens and end tokens to keep uncompressed).
                If a SequenceCompressionParams instance, the same params will be
                applied to all sequences in this batch; if a list is provided,
                each SequenceCompressionParams will be attached to the corresponding
                prompt in the batch. If ``None``, every sequence uses
                ``SequenceCompressionParams(1.0)``.
            :param tokenizer_kwargs:
                Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
                string prompts.
            :param detokenizer_kwargs:
                Passed through to `tokenizer.batch_decode`.
            :param return_sequences:
                Whether to also return the Sequence objects.
        Returns:
            :return List[str] or tuple[List[str], List[Sequence]]:
                One decoded completion per input prompt. When
                ``return_sequences`` is true, a ``(strings, sequences)`` tuple.
        """
        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
        seqs = self._build_sequences(
            prompts,
            sampling_params=sampling_params,
            per_sequence_compression_params=per_sequence_compression_params,
            tokenizer_kwargs=tokenizer_kwargs,
        )
        self.master_model_runner.generate(seqs, batch_compression_params)
        output_strings = self.detokenize_prompt(seqs, **detokenizer_kwargs)
        if return_sequences:
            return output_strings, seqs
        return output_strings

    def generate_chat(
        self,
        messages_batch: List[List[dict]],
        sampling_params: SamplingParams | List[SamplingParams],
        batch_compression_params: BatchCompressionParams,
        per_sequence_compression_params: Union[
            SequenceCompressionParams, List[SequenceCompressionParams]
        ],
        *,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        detokenizer_kwargs: Optional[dict[str, Any]] = None,
        return_sequences: bool = False,
    ) -> List[str] | tuple[List[str], List[Sequence]]:
        """
        Convenience API for chat-style prompts using HF `apply_chat_template`.
        Args:
            :param messages_batch:
                List of conversations, where each conversation is a list of
                message dicts like:
                    {"role": "system" | "user" | "assistant", "content": str}
            :param sampling_params:
                A single SamplingParams for all prompts in this batch or a list of
                SamplingParams with the same length as ``messages_batch``.
            :param batch_compression_params:
                Batch Level compression settings. Can set compression_method.
            :param per_sequence_compression_params:
                Per-sequence compression parameters, including the compression
                ratio to be applied and the size of the protected regions of the
                sequence (how many start tokens and end tokens to keep uncompressed).
                If a SequenceCompressionParams instance, the same params will be
                applied to all sequences in this batch; if a list is provided,
                each SequenceCompressionParams will be attached to the corresponding
                conversation in the batch.
            :param tokenizer_kwargs:
                Passed through to `tokenizer.apply_chat_template` (after the
                defaults from ``_merge_apply_chat_template_kwargs`` are added).
            :param detokenizer_kwargs:
                Passed through to `tokenizer.batch_decode`.
            :param return_sequences:
                Whether to also return the Sequence objects.
        Returns:
            :return List[str] or tuple[List[str], List[Sequence]]:
                One string per conversation.
        """
        prompts_token_ids: List[List[int]] = []
        tokenizer_kwargs = _merge_apply_chat_template_kwargs(
            self.tokenizer, tokenizer_kwargs
        )
        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
        for messages in messages_batch:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                **tokenizer_kwargs,
            )
            # Array/tensor results are converted to plain Python lists.
            if hasattr(input_ids, "tolist"):
                input_ids = input_ids.tolist()
            prompts_token_ids.append(input_ids)

        # The prompts are already token ids, so ``generate`` will not re-run
        # the tokenizer on them; the kwargs are forwarded unchanged.
        return self.generate(
            prompts_token_ids,
            sampling_params=sampling_params,
            batch_compression_params=batch_compression_params,
            per_sequence_compression_params=per_sequence_compression_params,
            tokenizer_kwargs=tokenizer_kwargs,
            detokenizer_kwargs=detokenizer_kwargs,
            return_sequences=return_sequences,
        )

    def generate_from_sequences(
        self,
        seqs: List[Sequence],
        batch_compression_params: BatchCompressionParams,
    ) -> List[Sequence]:
        """
        Run generation on caller-built Sequence objects.
        Args:
            :param seqs:
                List of Sequence instances
            :param batch_compression_params:
                Compression settings.

        Returns:
            :return List[Sequence]:
                Same list, mutated in-place with completions.
        """
        self.master_model_runner.generate(seqs, batch_compression_params)
        return seqs
+
diff --git a/vllm/kvprune_legacy_save/core/memory_manager.py b/vllm/kvprune_legacy_save/core/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3ee2ce1abe60a857e488b53e4d18cef20e4663
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/memory_manager.py
@@ -0,0 +1,237 @@
+import logging
+import os
+from typing import Iterable, List, Optional
+
+import torch
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.kv_cache.page_table import KVAllocationStatus, PagedKVCache
+from vllm.kvprune.utils.tp_utils import kv_heads_shard_divisor
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class KVCacheManager:
+    def __init__(
+        self,
+        rank: int,
+        config: LLMConfig,
+        *,
+        device: str | None = None,
+    ):
+        """Record cache geometry from ``config``; the cache is built by ``init_cache``.
+
+        Args:
+            rank: Rank of this process; also selects the default device
+                ``cuda:{rank}`` when ``device`` is not given.
+            config: Engine configuration (page size, sequence limits, memory
+                fraction, and the Hugging Face model config).
+            device: Optional explicit device string for the cache tensors.
+
+        Raises:
+            AssertionError: If the KV-head shard divisor does not divide
+                ``num_key_value_heads``.
+        """
+        super().__init__()
+        hf_config = config.hf_config
+        self.rank = rank
+        self.gpu_frac = config.gpu_memory_utilization
+        self.page_size = config.kvcache_page_size
+        self.world_size = config.tensor_parallel_size
+        self.max_num_batches = config.max_num_seqs
+        self.max_model_len = config.max_model_len
+        self.num_layers = hf_config.num_hidden_layers
+        self.model_dtype = hf_config.torch_dtype
+        # Many HF configs omit ``head_dim`` (or set it to None). Without a value the
+        # page-size arithmetic in ``get_num_pages`` fails, so derive it the usual way.
+        head_dim = getattr(hf_config, "head_dim", None)
+        if head_dim is None:
+            hidden_size = getattr(hf_config, "hidden_size", None)
+            num_heads = getattr(hf_config, "num_attention_heads", None)
+            if hidden_size and num_heads:
+                head_dim = hidden_size // num_heads
+        self.head_dim = head_dim
+        # Ceil-divide: logical pages needed to hold one sequence of max_model_len.
+        self.max_pages_per_batch = (
+            self.max_model_len + self.page_size - 1
+        ) // self.page_size
+        _ws = kv_heads_shard_divisor()
+        # Validate before dividing so a bad shard count is reported, not truncated.
+        assert hf_config.num_key_value_heads % _ws == 0, (
+            "tensor-parallel world size needs to divide num_kv_heads"
+        )
+        self.num_kv_heads = hf_config.num_key_value_heads // _ws
+        self._cache_device = device if device is not None else f"cuda:{self.rank}"
+
+        # Filled in by init_cache / estimate_max_batched_tokens.
+        self.num_pages = None
+        self.paged_cache: Optional[PagedKVCache] = None
+        self.max_batched_tokens = None
+
+        # seq_id -> batch index inside the paged cache.
+        self.seq_id_to_batch = {}
+
+    def allocate_sequences(
+        self, seq_ids: List[int], max_positions: List[int]
+    ) -> tuple[bool, Optional[torch.Tensor]]:
+        """Give every sequence a cache batch and reserve its token capacity.
+
+        Sequences that already own a batch keep it; new ones get a fresh batch
+        from the paged cache. If any step fails, batches created by this call
+        are released again so a failed attempt leaves no stale mappings behind.
+
+        Args:
+            seq_ids: Sequence identifiers to allocate for.
+            max_positions: Token capacity to reserve, aligned with ``seq_ids``.
+
+        Returns:
+            ``(True, mapping)`` where ``mapping`` is an int32 tensor of batch
+            indices on the cache device, or ``(False, None)`` on failure.
+        """
+        batch_mapping = []
+        # Sequence ids whose batch was created by *this* call (rollback set).
+        created_here: List[int] = []
+        for seq_id, len_to_alloc in zip(seq_ids, max_positions):
+            if seq_id not in self.seq_id_to_batch:
+                batch_id = self.paged_cache.new_batch()
+                if batch_id is None:
+                    logger.warning("Failed to allocate batch!")
+                    self.free_sequences(created_here)
+                    return False, None
+                self.seq_id_to_batch[seq_id] = int(batch_id)
+                created_here.append(seq_id)
+            batch_mapping.append(self.seq_id_to_batch[seq_id])
+            alloc_status = self.paged_cache.reserve_tokens(
+                self.seq_id_to_batch[seq_id], len_to_alloc
+            )
+            if alloc_status != KVAllocationStatus.SUCCESS:
+                logger.warning(f"Failed to allocate pages ({alloc_status})!")
+                self.free_sequences(created_here)
+                return False, None
+        # Build the mapping on the cache's own device, not whatever "cuda"
+        # currently resolves to.
+        batch_mapping = torch.as_tensor(
+            batch_mapping, dtype=torch.int32, device=self._cache_device
+        )
+        return True, batch_mapping
+
+    def free_sequences(self, seq_ids: Iterable[int]):
+        """Release the cache batches owned by ``seq_ids``.
+
+        Ids without a batch mapping (never allocated, or already freed) are
+        skipped instead of passing ``None`` to the paged cache.
+
+        Args:
+            seq_ids: Sequence identifiers to release.
+        """
+        for seq_id in seq_ids:
+            global_batch_id = self.seq_id_to_batch.pop(seq_id, None)
+            if global_batch_id is None:
+                # Nothing registered for this id: there is no batch to return.
+                continue
+            self.paged_cache.free_batch(global_batch_id)
+
+    def init_cache(self, model: nn.Module):
+        """Size the page pool, build the paged KV cache, and attach it to ``model``.
+
+        Args:
+            model: Model whose attention layers receive the cache slices.
+        """
+        page_budget = self.get_num_pages(self.gpu_frac, self.max_pages_per_batch)
+        self.num_pages = page_budget
+        cache_kwargs = dict(
+            num_layers=self.num_layers,
+            H_kv=self.num_kv_heads,
+            head_dim=self.head_dim,
+            page_size=self.page_size,
+            num_pages=int(page_budget),
+            max_num_batches=self.max_num_batches,
+            device=self._cache_device,
+            dtype=self.model_dtype,
+            max_logical_pages_per_head=int(self.max_pages_per_batch),
+        )
+        self.paged_cache = PagedKVCache(**cache_kwargs)
+        self._assign_cache_to_layers(model)
+
+    def _assign_cache_to_layers(self, model) -> None:
+        """Attach each layer's slice of the paged cache to its attention module.
+
+        Args:
+            model: Object exposing ``model.layers``; each layer must provide
+                ``self_attn.attn``.
+        """
+        layer_no = 0
+        for decoder_layer in model.model.layers:
+            attention = decoder_layer.self_attn.attn
+            keys, values, table, head_lens = self.paged_cache.layer_slices(layer_no)
+            attention.k_cache = keys
+            attention.v_cache = values
+            attention.page_table = table
+            attention.bh_seq_lens = head_lens
+            attention.page_size = self.page_size
+            layer_no += 1
+
+ def get_num_pages(self, frac: float, n_logical_pages_max: int):
+ """Compute how many KV pages fit in the memory budget.
+
+ The budget starts from ``frac`` of total device memory (scaled by 0.9),
+ minus what is already in use and the allocator's peak-over-current gap.
+ If that is not positive, a fraction of currently free memory is used
+ instead (``VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC``, clamped to [0.05, 0.95]).
+ Page-table and per-head length metadata are subtracted before dividing
+ by the per-page byte size across all layers.
+
+ Args:
+ frac: Target fraction of total device memory.
+ n_logical_pages_max: Logical pages per head used to size the page table.
+
+ Returns:
+ Number of pages, at least 1.
+
+ Raises:
+ RuntimeError: If no memory is left for the cache, or the metadata alone
+ does not fit in the fallback budget.
+ """
+ free, total = torch.cuda.mem_get_info()
+ used = total - free
+ stats = torch.cuda.memory_stats()
+ peak = int(stats["allocated_bytes.all.peak"])
+ current = int(stats["allocated_bytes.all.current"])
+ # Budget = 90% of the target share, minus memory in use, minus the allocator's
+ # peak-over-current gap (peak - current).
+ bytes_for_kv_budget = int(total * frac * 0.9) - used - peak + current
+
+ if bytes_for_kv_budget <= 0:
+ # Standalone compactor: ``frac`` is a fraction of total VRAM. When a second
+ # engine shares the GPU with vLLM (shared weights), most VRAM is already
+ # committed; the formula above goes negative. Fall back to a slice of
+ # *currently free* memory for the compactor KV pool.
+ free_frac = float(
+ os.environ.get("VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC", "0.55")
+ )
+ free_frac = max(0.05, min(free_frac, 0.95))
+ bytes_for_kv_budget = int(free * free_frac)
+ logger.warning(
+ "KV cache budget from gpu_memory_utilization (%.2f) is exhausted "
+ "(%.2f MiB free on device); using %.0f%% of free memory (~%.2f MiB) "
+ "for compactor KV (set VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC to adjust).",
+ frac,
+ free / (1024**2),
+ free_frac * 100,
+ bytes_for_kv_budget / (1024**2),
+ )
+ if bytes_for_kv_budget <= 0:
+ raise RuntimeError(
+ "Insufficient memory for compactor KV cache: no free GPU memory left "
+ "after the primary vLLM engine. Lower vLLM gpu_memory_utilization or "
+ "max_model_len, shorten prompts, or run compactor-only / vLLM-only "
+ "sessions. Raising gpu_memory_utilization here does not help."
+ )
+ # page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
+ int32_sz = torch.empty((), dtype=torch.int32).element_size() # 4
+ page_table_bytes_per_layer = (
+ self.max_num_batches
+ * self.num_kv_heads
+ * n_logical_pages_max
+ * int32_sz # page_table
+ + self.max_num_batches * self.num_kv_heads * int32_sz
+ )
+ total_page_table_bytes = self.num_layers * page_table_bytes_per_layer
+ kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
+ if kv_bytes_net <= 0:
+ # Tight VRAM: metadata alone can exceed the first budget; reserve page
+ # tables plus a slice of remaining free for KV tensors.
+ bytes_for_kv_budget = min(
+ int(free * 0.95),
+ total_page_table_bytes + max(int(free * 0.25), 8 * 1024 * 1024),
+ )
+ kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
+ if kv_bytes_net <= 0:
+ raise RuntimeError(
+ "page-table footprint exceeds available GPU memory for compactor KV. "
+ f"Reduce vLLM max_num_seqs (compactor uses {self.max_num_batches}) "
+ f"or max_model_len ({self.max_model_len}), or free GPU memory."
+ )
+ dtype_sz = torch.empty((), dtype=self.model_dtype).element_size()
+ # One page stores page_size tokens of K and V (factor 2) in every layer.
+ bytes_per_page_across_layers = self.num_layers * (
+ 2 * self.page_size * self.head_dim * dtype_sz
+ )
+ return max(1, kv_bytes_net // bytes_per_page_across_layers)
+
+    def estimate_max_batched_tokens(
+        self,
+        warmup_tokens: int,
+        bytes_used_before_warmup: int,
+        bytes_peak_after_warmup: int,
+    ) -> int:
+        """
+        Bound the number of tokens that may be processed at once without OOM.
+
+        Two limits are combined: one from activation memory (per-token cost
+        measured during warmup) and one from KV-cache capacity. The result is
+        stored in ``self.max_batched_tokens`` and returned.
+        """
+        assert warmup_tokens > 0, "warmup_tokens must be > 0"
+        page = self.page_size
+
+        # Activation bytes per token: ceil(warmup growth / warmup tokens), >= 1.
+        growth = int(bytes_peak_after_warmup) - int(bytes_used_before_warmup)
+        if growth < 0:
+            growth = 0
+        per_token_bytes = max(1, -(-growth // warmup_tokens))
+
+        free_bytes, total_bytes = torch.cuda.mem_get_info()
+        budget_target = int(total_bytes * self.gpu_frac)
+        in_use = int(total_bytes - free_bytes)
+        # Keep headroom equal to the allocator's peak-over-current gap so far.
+        mem_stats = torch.cuda.memory_stats()
+        headroom = max(
+            0,
+            int(mem_stats.get("allocated_bytes.all.peak", 0))
+            - int(mem_stats.get("allocated_bytes.all.current", 0)),
+        )
+
+        activation_budget = int(max(0, budget_target - in_use - headroom) * 0.95)
+        # Both limits are rounded down to a multiple of the page size.
+        by_activation = ((activation_budget // per_token_bytes) // page) * page
+        by_kv = (((self.num_pages * page) // self.num_kv_heads) // page) * page
+
+        # On a GPU shared with vLLM, memory in use often exceeds the target, so the
+        # activation limit collapses to zero or one page and would cap prefill
+        # at ~page_size tokens despite a large KV pool. In that case trust KV
+        # capacity (bounded by max_model_len) instead.
+        activation_degenerate = by_activation <= page
+        if by_kv > 0 and activation_degenerate and by_kv > by_activation:
+            by_activation = min(by_kv, self.max_model_len)
+
+        limit = min(by_kv, by_activation)
+        # Last resort: allow one page when KV exists but the limit is still 0.
+        if limit == 0 and self.num_pages > 0 and by_kv > 0:
+            limit = min(by_kv, page)
+        self.max_batched_tokens = limit
+        return self.max_batched_tokens
+
+    @property
+    def num_free_batches(self) -> int:
+        """Number of entries in the paged cache's ``free_batches`` collection."""
+        free_slots = self.paged_cache.free_batches
+        return len(free_slots)
+
+    @property
+    def num_free_pages(self) -> int:
+        """Smallest free-page count over the entries of ``paged_cache.free_pages``."""
+        return min(map(len, self.paged_cache.free_pages))
+
+    def reclaim_pages(
+        self,
+        seq_ids_to_reclaim: Iterable[int],
+        future_reserved_buffer: List[int] | torch.Tensor,
+    ) -> int:
+        """Ask the paged cache to reclaim pages for each listed sequence.
+
+        Args:
+            seq_ids_to_reclaim: Sequences whose batches should be trimmed.
+            future_reserved_buffer: Per-sequence value passed through to the
+                cache, indexed in the same order as ``seq_ids_to_reclaim``.
+
+        Returns:
+            Sum of the values returned by the cache for every sequence.
+        """
+        freed_total = 0
+        position = 0
+        for seq_id in seq_ids_to_reclaim:
+            cache_batch = self.seq_id_to_batch[seq_id]
+            keep_for_future = future_reserved_buffer[position]
+            freed_total += self.paged_cache.reclaim_pages(cache_batch, keep_for_future)
+            position += 1
+        return freed_total
diff --git a/vllm/kvprune_legacy_save/core/model_runner.py b/vllm/kvprune_legacy_save/core/model_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..c454a6fea5c8fbf05d8831aad14b09172c6482b0
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/model_runner.py
@@ -0,0 +1,794 @@
+import atexit
+import logging
+import os
+import inspect
+from typing import Any, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from vllm.kvprune.attention.sparse_decode_kernel import num_splits_heuristic
+from vllm.kvprune.compression.compression_config import BatchCompressionParams
+from vllm.kvprune.config.constants import RESERVED_BATCH
+from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
+from vllm.kvprune.core.memory_manager import KVCacheManager
+from vllm.kvprune.core.scheduler import Scheduler
+from vllm.kvprune.layers.sampler import Sampler
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.kvprune.utils.arguments import (
+ DecodeBatchArguments,
+ DecodeBatchOutput,
+ PackedTensorArguments,
+ PrefillBatchArguments,
+)
+from vllm.kvprune.utils.context import CompressionContext, reset_context, set_context
+from vllm.kvprune.utils.kv_dist import barrier_sync, broadcast_from_tp_rank0
+from vllm.kvprune.utils.sequence import Sequence
+from torch.multiprocessing import Event
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+class ModelRunner:
+ """Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ rank: int,
+ batch_ready: Optional[Event] = None,
+ peer_events: List[Event] = None,
+ external_model: Optional[nn.Module] = None,
+ *,
+ embedded_in_vllm_worker: bool = False,
+ device: Optional[torch.device] = None,
+ ):
+ """Set up one rank: device, process group, model, KV cache, and decode graphs.
+
+ Steps, in order: resolve rank and device; join or validate the process
+ group; load the model (or adopt ``external_model``); warm up without KV;
+ build the KV cache manager; estimate the token budget; warm up with KV;
+ capture CUDA graphs unless eager mode is enforced; register ``exit`` with
+ ``atexit``.
+
+ Args:
+ config: Engine configuration; ``eos_token_ids`` must be non-empty.
+ rank: Rank used when not embedded in a vLLM worker. In embedded mode the
+ rank comes from vLLM's tensor-parallel state instead.
+ batch_ready: Event polled by ``loop`` on peer ranks.
+ peer_events: Events set by ``generate`` and cleared by
+ ``maybe_release_peers``; ``None`` is stored as an empty list.
+ external_model: Pre-built model used instead of loading ``config.path``.
+ embedded_in_vllm_worker: Reuse the existing process group instead of
+ creating one.
+ device: Explicit device; defaults to the current CUDA device (embedded)
+ or ``cuda:{rank}``.
+
+ Raises:
+ RuntimeError: On a tensor-parallel or world-size mismatch, or when
+ embedded mode finds ``torch.distributed`` uninitialized.
+ NotImplementedError: When embedded and the global world size differs
+ from the tensor-parallel size.
+ """
+ self.config = config
+ self.embedded_in_vllm_worker = embedded_in_vllm_worker
+ if embedded_in_vllm_worker:
+ from vllm.distributed.parallel_state import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
+
+ tp_ws = get_tensor_model_parallel_world_size()
+ tp_rank = get_tensor_model_parallel_rank()
+ if tp_ws != config.tensor_parallel_size:
+ raise RuntimeError(
+ f"tensor parallel world size {tp_ws} != "
+ f"LLMConfig.tensor_parallel_size {config.tensor_parallel_size}"
+ )
+ self.rank = tp_rank
+ _dev = device if device is not None else torch.device(
+ f"cuda:{torch.cuda.current_device()}"
+ )
+ if not dist.is_initialized():
+ raise RuntimeError(
+ "embedded_in_vllm_worker requires torch.distributed to be "
+ "initialized (vLLM worker)."
+ )
+ if dist.get_world_size() != tp_ws:
+ raise NotImplementedError(
+ "KV-prune compactor embedded in vLLM currently requires "
+ "dist.get_world_size() == tensor_parallel_size "
+ "(pipeline_parallel_size=1, data_parallel_size=1). "
+ f"Got dist.get_world_size()={dist.get_world_size()}, "
+ f"tp_ws={tp_ws}."
+ )
+ else:
+ self.rank = rank
+ _dev = device if device is not None else torch.device(f"cuda:{rank}")
+
+ self._device = _dev
+ assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
+ "LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
+ )
+ # Tokens that end a sequence; compared against sampled ids in run_decode_loop.
+ self._stop_token_ids = torch.tensor(
+ config.eos_token_ids, dtype=torch.int64, device=_dev
+ )
+ hf_config = config.hf_config
+ self.enforce_eager = config.enforce_eager
+ if config.attention_schedule == KvpruneAttentionSchedule.PDFA:
+ if not self.enforce_eager and self.rank == 0:
+ logger.info(
+ "attention_schedule=PDFA: disabling compactor decode CUDA graphs "
+ "(FlashAttention decode path)."
+ )
+ self.enforce_eager = True
+ # Embedded in vLLM worker (TP>1): respect :attr:`LLMConfig.enforce_eager` from
+ # ``v1_tp_runner._apply_compactor_env_overrides``. Set
+ # ``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` to force eager if graph replay is unstable
+ # with shared vLLM VRAM / streams / NCCL on your stack.
+ if embedded_in_vllm_worker:
+ _tp_graph = os.environ.get(
+ "VLLM_KVPRUNE_TP_EMBEDDED_GRAPH", "1"
+ ).strip().lower()
+ if _tp_graph in ("0", "false", "no"):
+ if not self.enforce_eager:
+ logger.info(
+ "embedded_in_vllm_worker: VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0 → "
+ "forcing compactor enforce_eager=True (skip compactor CUDA graph "
+ "capture)."
+ )
+ self.enforce_eager = True
+ self.world_size = config.tensor_parallel_size
+ self.leverage_sketch_size = config.leverage_sketch_size
+ self.show_progress_bar = config.show_progress_bar
+ self.max_num_batches = config.max_num_seqs
+ self.max_model_len = config.max_model_len
+ self.num_layers = hf_config.num_hidden_layers
+ self.model_dtype = hf_config.torch_dtype
+ self.head_dim = getattr(hf_config, "head_dim", None)
+
+ init_kwargs = {}
+ if not embedded_in_vllm_worker:
+ # Pass ``device_id`` only when this torch version's init_process_group accepts it.
+ if "device_id" in inspect.signature(dist.init_process_group).parameters:
+ init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
+ if not dist.is_initialized():
+ dist.init_process_group(
+ "nccl",
+ f"tcp://localhost:{config.nccl_port}",
+ world_size=self.world_size,
+ rank=rank,
+ **init_kwargs,
+ )
+ else:
+ ws = dist.get_world_size()
+ if ws != self.world_size:
+ raise RuntimeError(
+ "torch.distributed is already initialized with "
+ f"world_size={ws}, but compactor ModelRunner expects "
+ f"tensor_parallel_size={self.world_size}. "
+ "Use tensor_parallel_size matching the active process group "
+ "(typically 1 when sharing weights with vLLM)."
+ )
+ torch.cuda.set_device(_dev)
+ # Temporarily switch torch defaults to the model dtype / CUDA; they are
+ # set back to the saved dtype and "cpu" once the KV cache exists.
+ default_dtype = torch.get_default_dtype()
+ torch.set_default_dtype(hf_config.torch_dtype)
+ torch.set_default_device("cuda")
+ model_type = hf_config.model_type
+ if external_model is not None:
+ self.model = external_model
+ else:
+ self.model = MODEL_REGISTRY[model_type](hf_config)
+ self.model.load_model(
+ config.path, use_tqdm=self.is_master and self.show_progress_bar
+ )
+ self.sampler = Sampler()
+
+ # Memory before / peak after the KV-less warmup feed the per-token activation
+ # estimate in KVCacheManager.estimate_max_batched_tokens.
+ pre_warmup_mem = torch.cuda.memory_stats().get("allocated_bytes.all.current", 0)
+ # No paged KV yet: FA-only varlen path (see :meth:`warmup`).
+ self.warmup(num_warmup_tokens=self.max_model_len, with_kv=False)
+ post_warmup_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+
+ self.kv_manager = KVCacheManager(
+ self.rank, config, device=str(self._device)
+ )
+ self.kv_manager.init_cache(self.model)
+
+ self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
+ torch.set_default_device("cpu")
+ torch.set_default_dtype(default_dtype)
+
+ self.batch_ready = batch_ready
+ self.peer_events = peer_events if peer_events is not None else []
+ # Embedded TP peers: session end is signaled via TP-group broadcast in
+ # maybe_release_peers (no multiprocessing.Event — not pickleable over RPC).
+ self._embedded_peer_continue = True
+ # captured_graphs[batch_size][max_seqlen_k] -> captured graph and its buffers.
+ self.captured_graphs = {}
+ self.min_captured_len = {}
+ self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
+ self.max_model_len, pre_warmup_mem, post_warmup_peak
+ )
+ if self.is_master:
+ logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")
+ self.warmup(num_warmup_tokens=self.max_model_len, with_kv=True)
+
+ if not self.enforce_eager:
+ # Capture at power-of-two batch sizes: 1, 2, 4, ...
+ bs = [1 << i for i in range(self.max_num_batches.bit_length())]
+ # NOTE(review): the loop variable below reuses the name ``bs`` of the list above.
+ for bs in (
+ tqdm(bs, desc="Capturing CUDA Graphs")
+ if self.is_master and self.show_progress_bar
+ else bs
+ ):
+ for seq_len in [1024, 4096, 8192, 16384]:
+ self.capture_cudagraph(bs, seq_len)
+
+ if not self.captured_graphs:
+ logger.warning(
+ "No compactor CUDA graphs were captured (KV budget tight or "
+ "allocate_sequences failed during capture). Using eager decode "
+ "for this session."
+ )
+ self.enforce_eager = True
+
+ self.packed_args = PackedTensorArguments(
+ rank=self.rank,
+ max_batched_tokens=self.max_batched_tokens,
+ config=self.config,
+ device=self._device,
+ use_tp_group_for_collectives=embedded_in_vllm_worker,
+ )
+ atexit.register(self.exit)
+
+    @torch.inference_mode()
+    def warmup(self, num_warmup_tokens: int, *, with_kv: bool):
+        """Run two dummy prefill passes over ``num_warmup_tokens`` EOS tokens.
+
+        With ``with_kv=False`` the pass runs before the paged cache exists and
+        uses the ``FA_PREFILL_TRITON_DECODE`` schedule. With ``with_kv=True`` it
+        uses the configured schedule and a temporary KV reservation under
+        sequence id ``-1``. The attention context and the temporary reservation
+        are released even if the forward pass raises.
+
+        Args:
+            num_warmup_tokens: Length of the dummy sequence.
+            with_kv: Whether to allocate and exercise the paged KV cache.
+
+        Raises:
+            AssertionError: If the warmup KV reservation cannot be allocated.
+        """
+        sched = (
+            self.config.attention_schedule
+            if with_kv
+            else KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+        )
+        if self.rank == 0:
+            logger.info(
+                "Warming up compactor attention (%s KV init): schedule=%s",
+                "after" if with_kv else "before",
+                sched.name,
+            )
+        device = self._device
+        input_ids = torch.tensor(
+            [self.config.eos] * num_warmup_tokens, device=device, dtype=torch.int64
+        )
+        positions = torch.arange(num_warmup_tokens, device=device, dtype=torch.int64)
+        cu_seqlens_q = torch.tensor(
+            [0, num_warmup_tokens], device=device, dtype=torch.int32
+        )
+        cu_seqlens_k = torch.tensor(
+            [0, num_warmup_tokens], device=device, dtype=torch.int32
+        )
+        batch_mapping = None
+        if with_kv:
+            success, batch_mapping = self.kv_manager.allocate_sequences(
+                [-1], [num_warmup_tokens]
+            )
+            if not success:
+                # Release a half-made reservation, then fail loudly. An explicit
+                # raise (unlike a bare assert) still fires under ``python -O``.
+                if -1 in self.kv_manager.seq_id_to_batch:
+                    self.kv_manager.free_sequences([-1])
+                raise AssertionError("warmup KV allocation failed")
+        try:
+            set_context(
+                is_prefill=True,
+                do_compression=False,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                cu_seqlens_q_host=(0, num_warmup_tokens),
+                cu_seqlens_k_host=(0, num_warmup_tokens),
+                max_seqlen_q=num_warmup_tokens,
+                max_seqlen_k=num_warmup_tokens,
+                batch_mapping=batch_mapping,
+                attention_schedule=sched,
+            )
+            for _ in range(2):
+                # Reset so the peak stat reflects this pass only.
+                torch.cuda.reset_peak_memory_stats()
+                h = self.model(input_ids, positions)
+                self.model.compute_logits(h)
+                barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+                if with_kv:
+                    # Zero the per-head lengths of the warmup batch so the next
+                    # pass starts from an empty cache entry.
+                    self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
+                        1, batch_mapping.to(torch.long), 0
+                    )
+        finally:
+            reset_context()
+            if with_kv:
+                self.kv_manager.free_sequences([-1])
+
+    def exit(self):
+        """Drop captured CUDA graphs and tear down the process group if owned.
+
+        Idempotent: a second call is a no-op. When the runner is embedded in a
+        vLLM worker the process group is left alone.
+        """
+        if getattr(self, "_exited", False):
+            return
+        self._exited = True
+        try:
+            if hasattr(self, "captured_graphs"):
+                self.captured_graphs.clear()
+        finally:
+            # No ``return`` inside ``finally``: it would swallow any exception
+            # raised above (and is flagged by PEP 765).
+            owns_process_group = not getattr(self, "embedded_in_vllm_worker", False)
+            if owns_process_group and dist.is_initialized():
+                dist.destroy_process_group()
+
+    def loop(self):
+        """Peer service loop: wait on ``batch_ready`` in 1 s slices and serve a
+        peer session each time the wait reports the event as set. Never returns."""
+        while True:
+            signalled = self.batch_ready.wait(1.0)
+            if not signalled:
+                continue
+            self._process_batches_peer()
+
+    @torch.inference_mode()
+    def run_prefill(
+        self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
+    ):
+        """Run one prefill forward pass and return the logits.
+
+        Args:
+            prefill_args: Packed prefill tensors and compression settings.
+            batch_mapping: Cache batch index for each sequence in the batch.
+
+        Returns:
+            Output of ``model.compute_logits`` for the batch.
+        """
+        assert prefill_args.B > 0 and prefill_args.N > 0
+        cached_lens = self.kv_manager.paged_cache.bh_seq_lens.index_select(
+            1, index=batch_mapping
+        )
+        max_bh_len = cached_lens.max().item()
+        compression_context = CompressionContext(
+            compression_method=prefill_args.compression_method,
+            compression_chunk_size=prefill_args.compression_chunk_size,
+            batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
+            max_tokens_to_retain=prefill_args.max_tokens_to_retain,
+            context_lens=prefill_args.context_lens.tolist(),
+            PHI=prefill_args.PHI,
+            sketch_dimension=self.leverage_sketch_size,
+            protected_first_tokens=prefill_args.protected_first,
+            protected_last_tokens=prefill_args.protected_last,
+            compression_ratio=prefill_args.compression_ratio,
+        )
+        # Host-side copies of the cumulative sequence lengths as int tuples.
+        cu_q_host = tuple(
+            map(int, prefill_args.cu_seqlens_q.detach().cpu().view(-1).tolist())
+        )
+        cu_k_host = tuple(
+            map(int, prefill_args.cu_seqlens_k.detach().cpu().view(-1).tolist())
+        )
+        set_context(
+            is_prefill=True,
+            do_compression=prefill_args.do_compression,
+            cu_seqlens_q=prefill_args.cu_seqlens_q,
+            cu_seqlens_k=prefill_args.cu_seqlens_k,
+            cu_seqlens_q_host=cu_q_host,
+            cu_seqlens_k_host=cu_k_host,
+            max_seqlen_q=prefill_args.max_seqlen_q,
+            max_seqlen_k=prefill_args.max_seqlen_k,
+            batch_mapping=batch_mapping,
+            max_bh_len=max_bh_len,
+            compression_context=compression_context,
+            STORE_STREAM=self.store_stream,
+            attention_schedule=self.config.attention_schedule,
+        )
+        # int32 token ids break vLLM-delegated embedding (expects long indices) on
+        # some paths, so both index tensors are promoted to int64 when needed.
+        token_ids = prefill_args.input_ids
+        if token_ids.dtype != torch.int64:
+            token_ids = token_ids.long()
+        position_ids = prefill_args.positions
+        if position_ids.dtype != torch.int64:
+            position_ids = position_ids.long()
+        logits = self.model.compute_logits(self.model(token_ids, position_ids))
+        reset_context()
+        return logits
+
+    def maybe_broadcast(self, tensor: torch.Tensor, *, label: str = "tensor") -> None:
+        """Broadcast ``tensor`` from TP rank 0 when more than one rank exists.
+
+        ``label`` is accepted for call-site readability and is not used here.
+        """
+        if self.world_size <= 1:
+            return None
+        broadcast_from_tp_rank0(tensor, use_tp_group=self.embedded_in_vllm_worker)
+        return None
+
+    def maybe_release_peers(self, do_release=False):
+        """Synchronize ranks after a decode round and optionally end the session.
+
+        Single-rank runs return at once. In standalone multi-process mode the
+        master clears the peer events when ``do_release`` is true, then all
+        ranks meet at a barrier. In embedded mode the master broadcasts a
+        continue flag (0 = stop, 1 = continue) that peers store in
+        ``_embedded_peer_continue`` before the barrier.
+        """
+        if self.world_size <= 1:
+            return
+        if not self.embedded_in_vllm_worker:
+            if self.is_master and do_release:
+                for event in self.peer_events:
+                    event.clear()
+            barrier_sync(use_tp_group=False)
+            return
+        flag = torch.zeros(1, dtype=torch.int32, device=self._device)
+        if self.is_master:
+            flag[0] = 0 if do_release else 1
+        broadcast_from_tp_rank0(flag, use_tp_group=True)
+        if not self.is_master:
+            self._embedded_peer_continue = bool(flag[0].item())
+        barrier_sync(use_tp_group=True)
+
+    def _peer_outer_loop_active(self) -> bool:
+        """Tell a peer rank whether to keep serving batches.
+
+        Uses ``batch_ready`` when one was supplied, otherwise the embedded
+        continue flag, otherwise ``False``.
+        """
+        ready = self.batch_ready
+        if ready is not None:
+            return ready.is_set()
+        return self._embedded_peer_continue if self.embedded_in_vllm_worker else False
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        all_sequences: List[Sequence],
+        batch_compression_params: Optional[BatchCompressionParams] = None,
+    ):
+        """Master entry point: wake peer processes and process all sequences.
+
+        Args:
+            all_sequences: Sequences to run to completion.
+            batch_compression_params: Compression settings; a default
+                ``BatchCompressionParams()`` is used when omitted.
+        """
+        assert self.is_master, "generate can only be called on the master process"
+        if not self.embedded_in_vllm_worker:
+            # Standalone mode: signal each peer process to begin its session.
+            for begin_execution_event in self.peer_events:
+                begin_execution_event.set()
+        params = (
+            BatchCompressionParams()
+            if batch_compression_params is None
+            else batch_compression_params
+        )
+        self._process_batches_master(all_sequences, params)
+
+    @property
+    def is_master(self):
+        """Whether this runner is rank 0."""
+        master_rank = 0
+        return self.rank == master_rank
+
+ @torch.inference_mode()
+ def _process_batches_master(
+ self,
+ all_sequences: List[Sequence],
+ batch_compression_params: BatchCompressionParams,
+ ):
+ """Master-rank driver: alternate prefill batches and decode loops until done.
+
+ Each iteration asks the scheduler for a prefill batch. If one is returned
+ it is allocated, prefilled, sampled, and merged into the running decode
+ batch. The decode loop runs when no prefill batch was returned or when
+ the scheduler cannot prefill another batch right away.
+
+ Args:
+ all_sequences: Sequences to schedule.
+ batch_compression_params: Compression settings used when packing
+ prefill arguments.
+
+ Raises:
+ RuntimeError: If sequences are pending but none can be scheduled.
+ """
+ assert self.is_master
+ compression_details = f"Applying Compression Method: {batch_compression_params.compression_method}"
+ if any(seq.compression_params.compression_ratio < 1.0 for seq in all_sequences):
+ logger.info(compression_details)
+ scheduler = Scheduler(
+ all_sequences=all_sequences,
+ kv_manager=self.kv_manager,
+ use_tqdm=self.show_progress_bar,
+ )
+ decode_batch = DecodeBatchArguments()
+ # [run_decode, occupancy]; broadcast to peers after every prefill when TP > 1.
+ decode_flags = torch.empty(2, dtype=torch.int32, device=self._device)
+ while not scheduler.is_finished():
+ sequences = scheduler.get_prefill_batch()
+ if not sequences:
+ if scheduler.pending_sequence_ids:
+ raise RuntimeError(
+ "KV-prune compactor cannot schedule any prefill (KV/token budget). "
+ f"max_batched_tokens={self.kv_manager.max_batched_tokens}, "
+ f"pending_sequences={len(scheduler.pending_sequence_ids)}. "
+ "Lower v1 gpu_memory_utilization / max_model_len, set "
+ "VLLM_KVPRUNE_RELEASE_V1_KV=1 to discard v1 KV (sleep+wake), "
+ "or free GPU memory. Diagnostics: "
+ f"{scheduler.diagnose_prefill_failure()}"
+ )
+ # Pending is empty: either finished or decode-only continuation.
+ if decode_batch.token_ids is None:
+ break
+ run_decode = True
+ occupancy = -1
+ else:
+ seq_ids_cpu = [seq.seq_id for seq in sequences]
+ scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
+ temps = torch.tensor(
+ [s.sampling_params.temperature for s in sequences],
+ dtype=torch.float32,
+ pin_memory=True,
+ ).to(device=self._device, non_blocking=True)
+ prefill_arguments = self.packed_args.build_prefill_args(
+ sequences, batch_compression_params=batch_compression_params
+ )
+ # Per-sequence capacity to reserve: prompt length plus max new tokens.
+ max_ctx_lens = (
+ prefill_arguments.max_new_tokens + prefill_arguments.context_lens
+ )
+
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ seq_ids_cpu, max_ctx_lens.tolist()
+ )
+ assert success, "failed to allocate pages for sequences"
+
+ logits = self.run_prefill(prefill_arguments, batch_mapping)
+ # Must match prefill `positions` dtype (int64). `context_lens` is int32
+ # from the packed buffer; using int32 here breaks RoPE indexing
+ # (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
+ positions = prefill_arguments.context_lens.to(dtype=torch.int64)
+ token_ids = self.sampler(logits, temps)
+ # Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
+ # reads bh_seq_lens on the default stream and must not race.
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+ # TODO: synchronize page counts across dist
+ if self.world_size == 1:
+ self.kv_manager.reclaim_pages(
+ seq_ids_cpu, prefill_arguments.max_new_tokens
+ )
+ # with logging_redirect_tqdm():
+ # logger.info(
+ # f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
+ # )
+
+ if scheduler.any_pending_sequences():
+ num_pending_batches = (
+ 0
+ if decode_batch.token_ids is None
+ else decode_batch.token_ids.shape[0]
+ )
+ # While prompts are still pending, the decode loop stops once the
+ # live batch count falls to 66% of the current total.
+ occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
+ else:
+ occupancy = -1
+ run_decode = not scheduler.can_prefill_another_batch()
+ decode_batch = decode_batch.update(
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ prefill_arguments.seq_ids,
+ temps,
+ occupancy,
+ )
+ if self.world_size > 1:
+ decode_flags[0] = int(run_decode)
+ decode_flags[1] = occupancy
+ self.maybe_broadcast(decode_flags, label="decode_flags")
+ if not run_decode:
+ continue
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+
+ decode_output, decode_batch = self.run_decode_loop(decode_batch)
+ # Sequences no longer present in the returned decode batch have finished.
+ finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
+ decode_batch.seq_ids.tolist()
+ )
+ scheduler.record_finished_sequence_ids(
+ finished_sequence_ids, update_status=True
+ )
+ self.kv_manager.free_sequences(finished_sequence_ids)
+ self.maybe_release_peers(scheduler.is_finished())
+ scheduler.update_sequences(
+ decode_output.output_tokens.tolist(),
+ decode_output.output_seq_ids.tolist(),
+ )
+ scheduler.close()
+
+    @torch.inference_mode()
+    def run_peer_session(self) -> None:
+        """Serve one peer session on a non-master TP rank (embedded-in-vLLM use).
+
+        In embedded mode the continue flag is re-armed first so the peer loop
+        runs until the master broadcasts a stop.
+        """
+        embedded = self.embedded_in_vllm_worker
+        if embedded:
+            self._embedded_peer_continue = True
+        self._process_batches_peer()
+
+ @torch.inference_mode()
+ def _process_batches_peer(self):
+ """Peer-rank counterpart of the master loop.
+
+ Each iteration builds prefill arguments with no local sequences, allocates
+ KV for them, runs prefill, then receives the ``[run_decode, occupancy]``
+ flags by broadcast. Token ids are an uninitialized placeholder here; the
+ decode loop overwrites them via broadcast. The loop runs while
+ ``_peer_outer_loop_active()`` is true.
+ """
+ assert not self.is_master
+ scheduler = Scheduler([], kv_manager=self.kv_manager)
+ decode_batch = DecodeBatchArguments()
+ decode_flags = torch.empty(2, dtype=torch.int32, device=self._device)
+ while self._peer_outer_loop_active():
+ # NOTE(review): presumably filled from data broadcast by the master — confirm
+ # in PackedTensorArguments.build_prefill_args.
+ prefill_arguments = self.packed_args.build_prefill_args()
+
+ B = prefill_arguments.B
+ max_ctx_lens = (
+ prefill_arguments.max_new_tokens + prefill_arguments.context_lens
+ )
+
+ seq_ids_cpu = prefill_arguments.seq_ids.tolist()
+ scheduler.add_running_sequence_ids(seq_ids_cpu)
+ success, batch_mapping = self.kv_manager.allocate_sequences(
+ seq_ids_cpu, max_ctx_lens.tolist()
+ )
+ assert success, "failed to allocate pages for sequences"
+
+ # Logits are discarded: only the master samples.
+ self.run_prefill(prefill_arguments, batch_mapping)
+ positions = prefill_arguments.context_lens.to(dtype=torch.int64)
+ self.maybe_broadcast(decode_flags, label="decode_flags")
+ run_decode = bool(decode_flags[0].item())
+ occupancy = int(decode_flags[1].item())
+ # Placeholder buffer; real token ids arrive via broadcast in run_decode_loop.
+ token_ids = torch.empty(B, dtype=torch.int64, device=self._device)
+ decode_batch = decode_batch.update(
+ batch_mapping,
+ token_ids,
+ positions,
+ max_ctx_lens,
+ prefill_arguments.seq_ids,
+ None, # temps not used in peer process
+ occupancy,
+ )
+
+ if not run_decode:
+ continue
+ if self.store_stream is not None:
+ torch.cuda.default_stream().wait_stream(self.store_stream)
+
+ _, decode_batch = self.run_decode_loop(decode_batch)
+ finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
+ decode_batch.seq_ids.tolist()
+ )
+ scheduler.record_finished_sequence_ids(finished_sequence_ids)
+ self.kv_manager.free_sequences(finished_sequence_ids)
+ self.maybe_release_peers()
+ scheduler.close()
+
+ @torch.inference_mode()
+ def run_decode_loop(
+ self,
+ decode_batch: DecodeBatchArguments,
+ ) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
+ """Decode step by step until the batch drains or drops to the target occupancy.
+
+ Every step broadcasts the current token ids, removes sequences that hit
+ a stop token or reached their position limit, and then computes logits
+ for the rest. Only the master samples and records tokens.
+
+ Args:
+ decode_batch: Running decode state; its tensors are filtered in place.
+
+ Returns:
+ ``(output, decode_batch)``. On the master, ``output`` holds the
+ concatenated CPU token ids and sequence ids gathered in this call; on
+ peers both fields are ``None``.
+ """
+ if self.is_master:
+ # Entries before ``num_stashed_batches`` are skipped when seeding the buffers.
+ # NOTE(review): presumably they were already reported by an earlier call — confirm.
+ num_stashed_batches = decode_batch.num_stashed_batches
+ tok_buffer = [
+ decode_batch.token_ids[num_stashed_batches:].to(
+ "cpu", non_blocking=True
+ )
+ ]
+ seq_buffer = [
+ decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
+ ]
+ while True:
+ self.maybe_broadcast(decode_batch.token_ids, label="decode_token_ids")
+ not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
+ # A sequence keeps running while it is below its position limit and has not
+ # produced a stop token.
+ running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
+ not_stopped
+ )
+ decode_batch.token_ids = torch.masked_select(
+ decode_batch.token_ids, running_batches
+ )
+ decode_batch.positions = torch.masked_select(
+ decode_batch.positions, running_batches
+ )
+ decode_batch.batch_mapping = torch.masked_select(
+ decode_batch.batch_mapping, running_batches
+ )
+ decode_batch.max_ctx_lens = torch.masked_select(
+ decode_batch.max_ctx_lens, running_batches
+ )
+ decode_batch.seq_ids = torch.masked_select(
+ decode_batch.seq_ids, running_batches
+ )
+ if self.is_master:
+ decode_batch.temps = torch.masked_select(
+ decode_batch.temps, running_batches
+ )
+ num_remaining = decode_batch.token_ids.numel()
+ # Stop when nothing is left, or when few enough sequences remain that the
+ # caller should go back to prefilling.
+ if (
+ num_remaining == 0
+ or num_remaining <= decode_batch.desired_batch_occupancy
+ ):
+ decode_batch.num_stashed_batches = num_remaining
+ break
+ logits = self._decode_step_logits(decode_batch)
+
+ if self.is_master:
+ decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
+ tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
+ seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
+ decode_batch.positions += 1
+
+ if self.is_master:
+ # non_blocking D2H copies must finish before cat/tolist read CPU data.
+ torch.cuda.synchronize()
+ output = DecodeBatchOutput(
+ output_tokens=torch.cat(tok_buffer),
+ output_seq_ids=torch.cat(seq_buffer),
+ )
+ else:
+ output = DecodeBatchOutput(None, None)
+ return output, decode_batch
+
+    def _decode_logits_eager(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        batch_mapping: torch.Tensor,
+    ):
+        """Compute decode logits with a plain (non-graph) forward pass.
+
+        Token ids and positions are promoted to int64 when they arrive in
+        another dtype.
+        """
+        set_context(
+            is_prefill=False,
+            do_compression=False,
+            batch_mapping=batch_mapping,
+            attention_schedule=self.config.attention_schedule,
+        )
+        if input_ids.dtype != torch.int64:
+            input_ids = input_ids.long()
+        if positions.dtype != torch.int64:
+            positions = positions.long()
+        return self.model.compute_logits(self.model(input_ids, positions))
+
+    @torch.inference_mode()
+    def _decode_step_logits(self, decode_batch: DecodeBatchArguments):
+        """Return logits for one decode step, preferring CUDA-graph replay.
+
+        Falls back to eager decode when graphs are disabled or absent. If graph
+        decode raises, the failure is logged and eager mode is used for this
+        and all later steps.
+        """
+
+        def eager_logits():
+            return self._decode_logits_eager(
+                decode_batch.token_ids,
+                decode_batch.positions,
+                decode_batch.batch_mapping,
+            )
+
+        if self.enforce_eager or not self.captured_graphs:
+            return eager_logits()
+        try:
+            return self.run_graph_decode(
+                decode_batch.token_ids,
+                decode_batch.positions,
+                decode_batch.batch_mapping,
+            )
+        except Exception as e:
+            logger.warning(
+                "CUDA graph decode failed (%s); switching to eager decode for "
+                "remaining steps.",
+                e,
+            )
+            self.enforce_eager = True
+            return eager_logits()
+
+    @torch.inference_mode()
+    def run_graph_decode(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        batch_mapping: torch.Tensor,
+    ):
+        """Replay a captured decode graph, or decode eagerly if none fits.
+
+        The inputs are copied into the graph's static buffers. Unused
+        batch-mapping slots are set to ``RESERVED_BATCH`` before replay.
+
+        Returns:
+            A contiguous copy of the first ``bs`` rows of the graph's logits.
+        """
+        bs = input_ids.shape[0]
+        graph_dict = self.get_cuda_graph(bs, int(positions.max()))
+        if graph_dict is None:
+            return self._decode_logits_eager(input_ids, positions, batch_mapping)
+        set_context(
+            is_prefill=False,
+            do_compression=False,
+            batch_mapping=batch_mapping,
+            attention_schedule=self.config.attention_schedule,
+        )
+        live_rows = slice(None, bs)
+        static_mapping = graph_dict["batch_mapping"]
+        graph_dict["input_ids"][live_rows] = input_ids
+        graph_dict["positions"][live_rows] = positions
+        static_mapping.fill_(RESERVED_BATCH)
+        static_mapping[live_rows] = batch_mapping
+        graph_dict["graph"].replay()
+        return graph_dict["logits"][live_rows].contiguous()
+
+    @torch.inference_mode()
+    def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
+        """Capture a decode CUDA graph for ``batch_size`` at ``max_seqlen_k``.
+
+        Temporary KV space is reserved under sequence ids ``0..batch_size-1``.
+        The reservation is released whether capture succeeds, allocation fails
+        part-way, or the capture itself raises. On allocation failure the graph
+        is skipped and nothing is recorded.
+
+        Args:
+            batch_size: Number of sequences the graph's static buffers hold.
+            max_seqlen_k: Sequence length passed to the split heuristic and
+                used as the graph's lookup key.
+        """
+        barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+        device = torch.device("cuda")
+        logger.debug(
+            f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
+        )
+        _g_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
+        _g_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
+        key_split = num_splits_heuristic(
+            batch_size * self.kv_manager.num_kv_heads,
+            max_seq_len=max_seqlen_k,
+            num_sms=torch.cuda.get_device_properties(device).multi_processor_count,
+            max_splits=12,
+        )
+
+        capture_seq_ids = list(range(batch_size))
+        success, _g_batch_mapping = self.kv_manager.allocate_sequences(
+            capture_seq_ids, [256] * batch_size
+        )
+        if not success:
+            # A failed allocation may have registered some of the ids; release
+            # them so they do not occupy cache batches for the whole session.
+            partially_allocated = [
+                sid for sid in capture_seq_ids
+                if sid in self.kv_manager.seq_id_to_batch
+            ]
+            if partially_allocated:
+                self.kv_manager.free_sequences(partially_allocated)
+            # Shared GPU with vLLM: compactor KV pool is small; large batch capture
+            # often cannot reserve [256]*batch_size per sequence. Skip this graph.
+            logger.warning(
+                "Skipping CUDA graph capture for batch_size=%s max_seqlen_k=%s "
+                "(KV allocate_sequences failed; decode will use eager or other graphs).",
+                batch_size,
+                max_seqlen_k,
+            )
+            barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+            return
+
+        try:
+            set_context(
+                is_prefill=False,
+                do_compression=False,
+                batch_mapping=_g_batch_mapping,
+                key_split=key_split,
+                attention_schedule=self.config.attention_schedule,
+            )
+            # One eager pass before capture.
+            _gw = self.model(_g_input_ids, _g_positions)
+            self.model.compute_logits(_gw)
+            barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
+            decode_graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(decode_graph):
+                _g_hidden = self.model(_g_input_ids, _g_positions)
+                _g_logits = self.model.compute_logits(_g_hidden)
+            graph_vars = {
+                "graph": decode_graph,
+                "input_ids": _g_input_ids,
+                "positions": _g_positions,
+                "batch_mapping": _g_batch_mapping,
+                "hidden": _g_hidden,
+                "logits": _g_logits,
+                "key_split": key_split,
+            }
+            if batch_size not in self.captured_graphs:
+                self.captured_graphs[batch_size] = {}
+                self.min_captured_len[batch_size] = float("inf")
+
+            self.captured_graphs[batch_size][max_seqlen_k] = graph_vars
+            self.min_captured_len[batch_size] = min(
+                max_seqlen_k, self.min_captured_len[batch_size]
+            )
+        finally:
+            # Always hand the temporary capture reservation back to the cache.
+            self.kv_manager.free_sequences(capture_seq_ids)
+
+    def get_cuda_graph(
+        self, batch_size: int, max_seqlen_k: int
+    ) -> Optional[dict[str, Any]]:
+        """Look up the captured graph to replay, or ``None`` when none qualifies.
+
+        Chooses the smallest captured batch size that is at least
+        ``batch_size``. Within it, chooses the largest captured sequence
+        length that does not exceed ``max_seqlen_k``.
+        """
+        graphs = self.captured_graphs
+        if not graphs:
+            return None
+        bs_key = min((b for b in graphs if b >= batch_size), default=None)
+        if bs_key is None:
+            return None
+        per_length = graphs[bs_key]
+        best_sl = max((sl for sl in per_length if sl <= max_seqlen_k), default=None)
+        return None if best_sl is None else per_length[best_sl]
+
diff --git a/vllm/kvprune_legacy_save/core/runtime.py b/vllm/kvprune_legacy_save/core/runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff23b6764fbe25b4c011c87b815835ebd25b917a
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/runtime.py
@@ -0,0 +1,130 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+from vllm.forward_context import get_forward_context
+
+from vllm.kvprune.core.compression_bridge import (
+ COMPRESSION_METHOD_ID_NONE,
+ compression_method_str_to_id,
+)
+
+
+@dataclass
+class KVPruneForwardState:
+    """State carried through one forward pass while KV pruning is in effect."""
+
+    active: bool
+    # [num_reqs_padded] ratio in (0, 1]; 1.0 means no pruning for that row.
+    compression_ratio_gpu: torch.Tensor
+    # [num_reqs_padded] int32 method ids as defined in ``compression_bridge``
+    # (0 = none).
+    compression_method_id_gpu: torch.Tensor
+    # [num_reqs_padded + 1] int32 on device.
+    query_start_loc: torch.Tensor
+    num_reqs: int
+    num_reqs_padded: int
+    num_layers: int
+    # Logical KV length per layer: ``[num_layers, num_reqs_padded]`` or, when
+    # lengths are tracked per KV head,
+    # ``[num_layers, num_reqs_padded, num_kv_heads]``.
+    logical_seq_lens_gpu: torch.Tensor
+    is_prefill: bool
+    device: torch.device
+
+    def logical_seq_lens_for_layer(self, layer_idx: int) -> torch.Tensor:
+        """Return one length per request row for layer ``layer_idx``.
+
+        When the selected layer slice is 2-D (per-head lengths), the maximum
+        over the last dimension is returned; otherwise the slice is returned
+        unchanged.
+        """
+        layer_lens = self.logical_seq_lens_gpu[layer_idx]
+        has_head_axis = layer_lens.dim() == 2
+        return layer_lens.max(dim=-1).values if has_head_axis else layer_lens
+
+
+def build_kv_prune_forward_state(
+    *,
+    req_ids: list[str],
+    requests: dict[str, object],
+    query_start_loc: torch.Tensor,
+    num_reqs: int,
+    num_reqs_padded: int,
+    num_layers: int,
+    max_num_scheduled_tokens: int,
+    device: torch.device,
+    logical_seq_lens_gpu: torch.Tensor,
+) -> KVPruneForwardState | None:
+    """Assemble the per-step pruning state, or ``None`` when nothing is pruned.
+
+    For each of the first ``num_reqs`` ids, the request's ``sampling_params`` is
+    inspected for ``compression_ratio`` and ``compression_method``. Requests that
+    are missing or have no sampling params get ratio 1.0; any row whose ratio is
+    not below ``1.0 - 1e-6`` gets the "none" method id. Rows past ``num_reqs``
+    (padding) keep ratio 1.0 and method id 0.
+
+    Returns ``None`` when ``num_reqs`` or ``num_layers`` is non-positive, or when
+    no request asks for a ratio below the threshold. ``is_prefill`` is derived
+    from ``max_num_scheduled_tokens > 1``.
+    """
+    if num_reqs <= 0 or num_layers <= 0:
+        return None
+
+    threshold = 1.0 - 1e-6
+    ratios: list[float] = []
+    method_ids: list[int] = []
+    for rid in req_ids[:num_reqs]:
+        req = requests.get(rid)
+        sp = None if req is None else getattr(req, "sampling_params", None)
+        ratio = 1.0
+        method_id = COMPRESSION_METHOD_ID_NONE
+        if sp is not None:
+            ratio = float(getattr(sp, "compression_ratio", 1.0))
+            if not ratio >= threshold:
+                name = getattr(sp, "compression_method", "none") or "none"
+                method_id = compression_method_str_to_id(str(name))
+        ratios.append(ratio)
+        method_ids.append(method_id)
+
+    # Nothing to do unless at least one row actually requests pruning.
+    if not any(ratio < threshold for ratio in ratios):
+        return None
+
+    # Padded rows default to "no compression"; real rows are written over them.
+    ratio_tensor = torch.ones((num_reqs_padded,), dtype=torch.float32, device=device)
+    ratio_tensor[:num_reqs] = torch.tensor(ratios, dtype=torch.float32, device=device)
+    method_tensor = torch.zeros((num_reqs_padded,), dtype=torch.int32, device=device)
+    method_tensor[:num_reqs] = torch.tensor(
+        method_ids, dtype=torch.int32, device=device
+    )
+
+    return KVPruneForwardState(
+        active=True,
+        compression_ratio_gpu=ratio_tensor,
+        compression_method_id_gpu=method_tensor,
+        query_start_loc=query_start_loc,
+        num_reqs=num_reqs,
+        num_reqs_padded=num_reqs_padded,
+        num_layers=num_layers,
+        logical_seq_lens_gpu=logical_seq_lens_gpu,
+        is_prefill=max_num_scheduled_tokens > 1,
+        device=device,
+    )
+
+
+def layer_index_from_layer_name(layer_name: str) -> int:
+    """Return the layer index encoded in ``layer_name``.
+
+    Delegates to vLLM's ``extract_layer_index``; the import is kept local to
+    the call, as in the original.
+    """
+    from vllm.model_executor.models.utils import extract_layer_index
+
+    layer_idx = extract_layer_index(layer_name)
+    return layer_idx
+
+
+def get_kv_prune_state() -> KVPruneForwardState | None:
+    """Fetch the active pruning state from the current forward context.
+
+    Returns ``None`` when ``get_forward_context`` raises ``AssertionError``,
+    when the context has no ``"kv_prune"`` entry, when that entry is not a
+    :class:`KVPruneForwardState`, or when the state is not marked active.
+    """
+    try:
+        forward_ctx = get_forward_context()
+    except AssertionError:
+        return None
+    candidate = forward_ctx.additional_kwargs.get("kv_prune")
+    if isinstance(candidate, KVPruneForwardState) and candidate.active:
+        return candidate
+    return None
diff --git a/vllm/kvprune_legacy_save/core/scheduler.py b/vllm/kvprune_legacy_save/core/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d55d3a3db052ba8c85aff9afc8fa2b97ce030b96
--- /dev/null
+++ b/vllm/kvprune_legacy_save/core/scheduler.py
@@ -0,0 +1,259 @@
+import time
+from typing import Iterable, List
+
+from vllm.kvprune.core.memory_manager import KVCacheManager
+from vllm.kvprune.utils.sequence import Sequence, SequenceStatus
+from tqdm import tqdm
+
+
+def cdiv(a, b):
+    """Ceiling division of ``a`` by ``b`` (for non-negative ``a``, positive ``b``)."""
+    padded = a + b - 1
+    return padded // b
+
+
+class Scheduler:
+    """
+    Simple sequence scheduler for prefill + decode with a paged KV cache.
+
+    The scheduler tracks three disjoint sets of sequence IDs:
+
+    * ``pending_sequence_ids`` – sequences that have not yet been started.
+    * ``active_sequence_ids`` – sequences currently running.
+    * ``finished_sequence_ids`` – sequences that have generated all tokens.
+
+    At prefill time, :meth:`get_prefill_batch` selects a subset of pending
+    sequences that can fit into the available KV cache and per-step token
+    budget, given the constraints from the associated :class:`KVCacheManager`.
+
+    The class also handles basic bookkeeping of sequence statuses and of the
+    token counters used for the throughput readout.
+
+    Args:
+        :param all_sequences:
+            Iterable of :class:`Sequence` objects to be scheduled. Each
+            sequence must have a unique ``seq_id``. The iterable is consumed
+            exactly once, so a one-shot generator is acceptable.
+        :param kv_manager:
+            A :class:`KVCacheManager` instance that this scheduler will use
+            to determine whether additional batches can be scheduled.
+        :param use_tqdm:
+            If True, a single progress bar is created whose total is the
+            number of sequences. It advances once per sequence recorded as
+            finished (with ``update_status=True``) and its description is
+            updated with the running throughput.
+    """
+
+    def __init__(
+        self,
+        all_sequences: Iterable[Sequence],
+        kv_manager: KVCacheManager,
+        *,
+        use_tqdm=False,
+    ):
+        self.allseq_mapping: dict[int, Sequence] = {s.seq_id: s for s in all_sequences}
+        # Derive the pending set from the mapping instead of iterating
+        # ``all_sequences`` a second time: a generator would already be
+        # exhausted and would leave the pending set empty.
+        self.pending_sequence_ids: set[int] = set(self.allseq_mapping)
+        self.active_sequence_ids: set[int] = set()
+        self.finished_sequence_ids: set[int] = set()
+        self.manager = kv_manager
+        self.use_tqdm = use_tqdm
+        self.start_time = time.perf_counter()
+        self.total_tokens_generated = 0
+        self.total_tokens_input = 0
+        self.pbar = None
+        if use_tqdm:
+            self.pbar = tqdm(
+                total=len(self.pending_sequence_ids),
+                desc="Completed Batches",
+            )
+
+    def get_prefill_batch(self) -> List[Sequence]:
+        """
+        Select a batch of pending sequences to prefill under KV/memory constraints.
+
+        The selection is greedy over ``pending_sequence_ids`` in iteration order.
+        A sequence is added to the batch if:
+
+        * The sum of its prompt length and the total prompt tokens selected so
+          far does not exceed ``manager.max_batched_tokens``, and
+        * There is at least one free KV "batch slot" left
+          (``manager.num_free_batches``), and
+        * The number of KV pages required by the sequence's prompt +
+          max_new_tokens (page count times ``manager.num_kv_heads``) does not
+          exceed the remaining free pages.
+
+        The manager itself is not modified; only local copies of the free
+        counters are decremented while selecting.
+
+        Returns:
+            :return List[Sequence]:
+                The list of :class:`Sequence` objects chosen for prefill in
+                this step. The caller is responsible for marking them as
+                active via :meth:`add_running_sequence_ids`.
+        """
+        total_tok, sequences = 0, []
+        num_free_batches, num_free_pages = (
+            self.manager.num_free_batches,
+            self.manager.num_free_pages,
+        )
+        for seq_id in self.pending_sequence_ids:
+            seq = self.allseq_mapping[seq_id]
+            prompt_length = seq.prompt_len
+            pages_needed = (
+                cdiv(
+                    prompt_length + seq.sampling_params.max_new_tokens,
+                    self.manager.page_size,
+                )
+                * self.manager.num_kv_heads
+            )
+            if (
+                prompt_length + total_tok <= self.manager.max_batched_tokens
+                and num_free_batches > 0
+                and pages_needed <= num_free_pages
+            ):
+                sequences.append(seq)
+                total_tok += prompt_length
+                num_free_pages -= pages_needed
+                num_free_batches -= 1
+        return sequences
+
+    def diagnose_prefill_failure(self) -> str:
+        """Explain why :meth:`get_prefill_batch` may return empty (debugging).
+
+        Returns a ``"; "``-joined summary of the manager's free counters and,
+        when a pending sequence exists, the requirements of the first one in
+        set-iteration order together with a best-guess cause.
+        """
+        num_free_batches = self.manager.num_free_batches
+        num_free_pages = self.manager.num_free_pages
+        parts = [
+            f"num_free_batches={num_free_batches}",
+            f"num_free_pages={num_free_pages}",
+            f"num_pages_per_layer={getattr(self.manager, 'num_pages', None)}",
+        ]
+        seq_id = next(iter(self.pending_sequence_ids), None)
+        if seq_id is None:
+            return "; ".join(parts)
+        seq = self.allseq_mapping[seq_id]
+        pl = seq.prompt_len
+        mn = seq.sampling_params.max_new_tokens
+        pages_needed = (
+            cdiv(pl + mn, self.manager.page_size) * self.manager.num_kv_heads
+        )
+        parts.append(
+            f"first_pending seq_id={seq_id} prompt_len={pl} max_new_tokens={mn} "
+            f"pages_needed~={pages_needed}"
+        )
+        if num_free_batches == 0:
+            parts.append(
+                "likely_cause=no free batch slots (compactor max_num_seqs exhausted)"
+            )
+        elif pl > self.manager.max_batched_tokens:
+            parts.append(
+                f"likely_cause=prompt_len ({pl}) > max_batched_tokens "
+                f"({self.manager.max_batched_tokens})"
+            )
+        elif pages_needed > num_free_pages:
+            parts.append(
+                "likely_cause=KV pool too small: pages_needed exceeds num_free_pages "
+                "(raise VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC / lower v1 memory, or cap "
+                "compactor max_num_seqs to shrink page-table overhead)"
+            )
+        else:
+            parts.append(
+                "likely_cause=batched token sum or greedy order (another sequence may "
+                "block first in set iteration)"
+            )
+        return "; ".join(parts)
+
+    def is_finished(self) -> bool:
+        """
+        Check whether all sequences have completed (nothing pending or active).
+        """
+        return not self.pending_sequence_ids and not self.active_sequence_ids
+
+    def any_pending_sequences(self) -> bool:
+        """
+        Check whether any sequences are still pending (not yet started).
+        """
+        return bool(self.pending_sequence_ids)
+
+    def add_running_sequence_ids(
+        self, active_sequence_ids: Iterable[int], *, update_status: bool = False
+    ):
+        """
+        Mark a set of sequences as active / running. This moves sequence IDs
+        from ``pending_sequence_ids`` into ``active_sequence_ids``.
+
+        Args:
+            :param active_sequence_ids:
+                Iterable of sequence IDs that have been scheduled for prefill
+                or decode and should now be considered running. Consumed
+                once, so generators and views of the scheduler's own sets are
+                safe to pass.
+            :param update_status:
+                If True, set each corresponding :class:`Sequence`'s
+                ``status = SequenceStatus.RUNNING`` and add its prompt length
+                to ``total_tokens_input``.
+        """
+        # Snapshot first: the argument is iterated more than once below and may
+        # be a generator or alias one of the sets mutated here.
+        started_ids = list(active_sequence_ids)
+        self.active_sequence_ids.update(started_ids)
+        self.pending_sequence_ids.difference_update(self.active_sequence_ids)
+        if update_status:
+            for seq_id in started_ids:
+                seq = self.allseq_mapping[seq_id]
+                seq.status = SequenceStatus.RUNNING
+                self.total_tokens_input += seq.prompt_len
+
+    def get_finished_sequence_ids_from_unfinished(
+        self, unfinished_sequence_ids: Iterable[int]
+    ) -> set[int]:
+        """
+        Infer which active sequences have finished given the
+        unfinished set (for decode steps where the caller knows
+        which sequences are still generating but not necessarily
+        which have just completed).
+
+        Args:
+            :param unfinished_sequence_ids:
+                Iterable of sequence IDs that are still running
+        Returns:
+            :return set[int]:
+                A new set holding the active IDs that are not listed as
+                unfinished. The scheduler's own sets are not modified.
+        """
+        return self.active_sequence_ids.difference(unfinished_sequence_ids)
+
+    def record_finished_sequence_ids(
+        self, finished_sequence_ids: Iterable[int], *, update_status: bool = False
+    ):
+        """
+        Record that a set of sequences has finished generation.
+
+        This moves IDs from ``active_sequence_ids`` into
+        ``finished_sequence_ids``.
+
+        Args:
+            :param finished_sequence_ids:
+                Iterable of sequence IDs that have completed generation and
+                no longer require KV cache. Consumed once.
+            :param update_status:
+                If True, set each corresponding :class:`Sequence`'s
+                ``status = SequenceStatus.FINISHED`` and advance the progress
+                bar (when enabled) by one per sequence.
+        """
+        # Snapshot first: the argument is iterated up to three times below and
+        # may be a generator or alias ``active_sequence_ids``.
+        done_ids = list(finished_sequence_ids)
+        self.active_sequence_ids.difference_update(done_ids)
+        self.finished_sequence_ids.update(done_ids)
+        if update_status:
+            for seq_id in done_ids:
+                self.allseq_mapping[seq_id].status = SequenceStatus.FINISHED
+                if self.pbar is not None:
+                    self.pbar.update(1)
+
+    def update_sequences(self, tokens: Iterable[int], seq_ids: Iterable[int]):
+        """
+        Append newly generated tokens to their corresponding sequences.
+
+        Each appended token increments ``total_tokens_generated``. When a
+        progress bar exists, its description is refreshed with the throughput
+        (input + generated tokens per second since construction).
+
+        Args:
+            :param tokens:
+                Iterable of generated token IDs, one per sequence.
+            :param seq_ids:
+                Iterable of sequence IDs aligned with ``tokens``.
+        """
+        cur_time = time.perf_counter()
+        for tok, seq_id in zip(tokens, seq_ids):
+            self.allseq_mapping[seq_id].add_new_token(tok)
+            self.total_tokens_generated += 1
+        if self.pbar is not None:
+            elapsed = cur_time - self.start_time
+            # Skip the readout rather than divide by a zero-length interval.
+            if elapsed > 0:
+                self.pbar.set_description(
+                    f"Throughput: {(self.total_tokens_generated + self.total_tokens_input) / elapsed:.2f} tok/s"
+                )
+
+    def close(self):
+        """Close the progress bar, if one was created."""
+        if self.pbar is not None:
+            self.pbar.close()
+
+    def can_prefill_another_batch(self) -> bool:
+        """Return True when :meth:`get_prefill_batch` would select at least one sequence."""
+        return len(self.get_prefill_batch()) > 0
diff --git a/vllm/kvprune_legacy_save/integration/__init__.py b/vllm/kvprune_legacy_save/integration/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1470f0ad554e2282b26acd66c47a98824c12245c
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""KV-pruning integration: compactor ``LLMEngine`` sharing weights with :class:`~vllm.LLM`."""
+
+from vllm.kvprune.integration.compression_params import CompressionParams
+
+__all__ = ["CompressionParams"]
diff --git a/vllm/kvprune_legacy_save/integration/compactor_shared.py b/vllm/kvprune_legacy_save/integration/compactor_shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..148df4f06dd6397f57a96080ec340dc6d9eaa1d0
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/compactor_shared.py
@@ -0,0 +1,140 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Construct compactor :class:`LLMEngine` sharing weight tensors with an in-process vLLM ``LLM``."""
+
+from __future__ import annotations
+
+import os
+
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.core.llm_engine import LLMEngine
+from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
+from vllm.kvprune.integration.vllm_model_access import extract_vllm_causal_lm
+from vllm.kvprune.integration.weight_tie import (
+ delegate_kvprune_compute_logits_to_vllm,
+ delegate_kvprune_embed_tokens_to_vllm,
+ tie_kvprune_rope_buffers_from_vllm,
+ tie_kvprune_weights_from_vllm,
+)
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def build_llm_config_for_compactor(vc: VllmConfig) -> LLMConfig:
+    """Translate a ``VllmConfig`` into the compactor's :class:`LLMConfig` (public entry point)."""
+    llm_config = vllm_config_to_llm_config(vc)
+    return llm_config
+
+
+def create_compactor_engine_with_shared_weights(llm: object) -> LLMEngine:
+    """Single GPU, TP=1: compactor ``LLMEngine`` whose weights alias vLLM tensors.
+
+    Call after the vLLM ``LLM`` has loaded weights. Requires in-process executor
+    (``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
+
+    Environment variables read here:
+
+    * ``VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS`` (default ``"32"``): upper bound
+      applied to ``cfg.max_num_seqs``; ``0`` (or empty) disables the cap.
+    * ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER`` / ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH``:
+      choose the initial ``cfg.enforce_eager`` value.
+    * ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH`` (default ``"0"``): unless truthy,
+      ``cfg.enforce_eager`` is forced to ``True`` at the end regardless of the
+      two variables above.
+
+    Args:
+        llm: Object exposing ``llm_engine.vllm_config`` (the entrypoint ``LLM``).
+
+    Returns:
+        An ``LLMEngine`` built with ``external_model`` set to the tied model.
+
+    Raises:
+        RuntimeError: ``llm`` has no ``llm_engine``.
+        ValueError: tensor parallel size is not 1, the sequence cap is not an
+            integer, the HF config is missing, or the model type has no entry
+            in ``MODEL_REGISTRY``.
+    """
+    llm_engine = getattr(llm, "llm_engine", None)
+    if llm_engine is None:
+        raise RuntimeError("Expected ``llm.llm_engine``.")
+    vc: VllmConfig = llm_engine.vllm_config
+    if vc.parallel_config.tensor_parallel_size != 1:
+        raise ValueError(
+            "Shared-weight compactor backend requires tensor_parallel_size=1"
+        )
+
+    cfg = vllm_config_to_llm_config(vc)
+    # ``cfg.enforce_eager`` is for the compactor ``ModelRunner`` only (decode CUDA
+    # graphs), not v1. v1 graph capture is controlled solely by ``LLM(...,
+    # enforce_eager=...)`` / ``kvprune_compression=True`` on the entrypoint ``LLM``.
+    # Large vLLM max_num_seqs blows up compactor page-table GPU memory; sharing the GPU
+    # with v1 leaves little room for metadata + KV tensors. Default cap 32 so physical
+    # KV pages stay usable; set VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS=0 to disable cap,
+    # or raise (e.g. 128) if you have VRAM headroom.
+    _cap = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
+    if _cap:
+        try:
+            lim = int(_cap)
+        except ValueError as exc:
+            # Surface which variable is malformed instead of a bare int() error.
+            raise ValueError(
+                "VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS must be an integer "
+                f"(0 disables the cap), got {_cap!r}"
+            ) from exc
+        if lim > 0:
+            cfg.max_num_seqs = min(cfg.max_num_seqs, lim)
+
+    # Compactor decode graphs (``enforce_eager=False``): honored for non-shared-weight
+    # engines. **Shared-weight** path (below) forces ``enforce_eager=True`` after
+    # delegating ``compute_logits`` to vLLM unless ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1``.
+    # Opt out of graphs for non-shared runs: ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1`` or
+    # ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0``.
+    _ce = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
+    if _ce in ("1", "true", "yes"):
+        cfg.enforce_eager = True
+        logger.info(
+            "KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 → "
+            "enforce_eager=True (skip compactor decode CUDA graphs)."
+        )
+    elif _ce in ("0", "false", "no"):
+        cfg.enforce_eager = False
+        logger.info(
+            "KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=0 → "
+            "enforce_eager=False (try compactor CUDA graph capture)."
+        )
+    else:
+        _dg = os.environ.get(
+            "VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1"
+        ).strip().lower()
+        if _dg in ("0", "false", "no"):
+            cfg.enforce_eager = True
+            logger.info(
+                "KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 → "
+                "enforce_eager=True (skip compactor decode CUDA graphs)."
+            )
+        else:
+            cfg.enforce_eager = False
+            logger.info(
+                "KV-prune compactor: default try decode CUDA graphs; ModelRunner "
+                "falls back to eager if capture yields none. Set "
+                "VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 or "
+                "VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 to skip capture."
+            )
+
+    hf = cfg.hf_config
+    # Explicit check rather than ``assert``: asserts are stripped under ``-O``.
+    if hf is None:
+        raise ValueError(
+            "Compactor LLMConfig has no hf_config; cannot select a model class."
+        )
+    model_type = hf.model_type
+    if model_type not in MODEL_REGISTRY:
+        raise ValueError(
+            f"Compactor MODEL_REGISTRY has no entry for model_type={model_type!r}; "
+            f"supported: {sorted(MODEL_REGISTRY)}"
+        )
+
+    vllm_model = extract_vllm_causal_lm(llm)
+    # The first parameter stands in for the whole model's placement and dtype.
+    first_param = next(vllm_model.parameters())
+    device = first_param.device
+    dtype = first_param.dtype
+
+    # Build compactor shell on CPU first. **Do not** call ``.to(device)`` before tying:
+    # that allocates a full second copy of weights on GPU; tying then frees the
+    # duplicate but peak memory can OOM on large models. Tie first so parameters
+    # alias vLLM tensors directly (no extra weight VRAM).
+    kv_model: nn.Module = MODEL_REGISTRY[model_type](hf)
+    tie_kvprune_weights_from_vllm(vllm_model, kv_model)
+    # Buffers (e.g. RoPE tables) not in ``named_parameters`` may still be on CPU.
+    kv_model.to(device=device, dtype=dtype)
+    tie_kvprune_rope_buffers_from_vllm(vllm_model, kv_model)
+    delegate_kvprune_embed_tokens_to_vllm(vllm_model, kv_model)
+    delegate_kvprune_compute_logits_to_vllm(vllm_model, kv_model)
+
+    # Compactor decode CUDA graphs capture ``model.forward`` + ``compute_logits`` in one
+    # graph. Here ``compute_logits`` is delegated to vLLM's LM head / LogitsProcessor
+    # (cublas GEMM, padded vocab, etc.). Embedding that in a nested capture commonly
+    # fails with ``CUBLAS_STATUS_EXECUTION_FAILED`` and invalidates stream capture
+    # (``cudaErrorStreamCaptureInvalidated``). Default: skip graphs for this integration.
+    _sw_graph = os.environ.get(
+        "VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH", "0"
+    ).strip().lower() in ("1", "true", "yes")
+    if not _sw_graph:
+        cfg.enforce_eager = True
+        logger.info(
+            "KV-prune shared-weight compactor: enforce_eager=True (skip compactor "
+            "decode CUDA graphs; logits delegated to vLLM). Set "
+            "VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1 only to attempt capture (often fails)."
+        )
+
+    return LLMEngine(cfg, external_model=kv_model)
diff --git a/vllm/kvprune_legacy_save/integration/compressed_generate.py b/vllm/kvprune_legacy_save/integration/compressed_generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed7ff7a8543e94857ad17560297caa743e13d819
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/compressed_generate.py
@@ -0,0 +1,447 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""KV-pruning (compactor) path invoked from :meth:`vllm.entrypoints.llm.LLM.generate`."""
+
+from __future__ import annotations
+
+import os
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any
+
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer
+
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
+from vllm.kvprune.core.compression_bridge import (
+ compression_method_id_to_enum,
+ compression_method_str_to_id,
+)
+from vllm.kvprune.core.llm_engine import LLMEngine, _infer_stop_token_ids
+from vllm.kvprune.integration.compactor_shared import create_compactor_engine_with_shared_weights
+from vllm.kvprune.integration.compression_params import CompressionParams
+from vllm.logger import init_logger
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
+
+logger = init_logger(__name__)
+
+_MP_ENV = "VLLM_ENABLE_V1_MULTIPROCESSING"
+_RELEASE_V1_KV_ENV = "VLLM_KVPRUNE_RELEASE_V1_KV"
+
+
+def _maybe_release_v1_kv_for_compactor(llm: Any) -> None:
+    """Run an opt-in ``sleep(level=1)`` / ``wake_up()`` cycle on ``llm``.
+
+    The cycle only runs when ``VLLM_KVPRUNE_RELEASE_V1_KV`` is ``1``, ``true`` or
+    ``yes`` (case-insensitive); when the variable is unset it is treated as
+    ``"0"`` and this function does nothing. The intent, per the original notes,
+    is to discard v1's KV cache so the compactor has more free GPU memory; v1's
+    KV reservation at engine init is not removed by this.
+
+    Best effort: any exception from the cycle is logged as a warning and not
+    re-raised.
+    """
+    flag = os.environ.get(_RELEASE_V1_KV_ENV, "0").strip().lower()
+    if flag in ("1", "true", "yes"):
+        try:
+            logger.info(
+                "%s=1: discarding v1 KV via sleep(level=1) then wake_up() "
+                "(reloads model weights to GPU).",
+                _RELEASE_V1_KV_ENV,
+            )
+            llm.sleep(level=1, mode="abort")
+            llm.wake_up()
+        except Exception as e:
+            logger.warning("%s: sleep/wake failed: %s", _RELEASE_V1_KV_ENV, e)
+
+
+def ensure_inprocess_engine_for_weight_sharing() -> None:
+    """Force ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` in this process's environment.
+
+    The compactor must see ``worker.get_model()`` in the same process as vLLM.
+    Nothing happens when the variable is already ``"0"``; otherwise it is set
+    and an info message is logged.
+    """
+    if os.environ.get(_MP_ENV, "1") == "0":
+        return
+    os.environ[_MP_ENV] = "0"
+    logger.info(
+        "KV cache pruning: set %s=0 so the model stays in-process for "
+        "shared-weight compactor (no manual env needed).",
+        _MP_ENV,
+    )
+
+
+def _normalize_prompt_list(prompts: Any) -> list[Any]:
+    """Coerce the ``prompts`` argument into a list with one entry per prompt.
+
+    A single prompt may be a ``str``, a ``dict``, or a non-empty ``list[int]``
+    of token ids; each of these is wrapped in a one-element list. Anything
+    else is treated as a collection of prompts and copied into a new list.
+
+    The ``list[int]`` case matters because iterating such a prompt would
+    otherwise yield one bare ``int`` per token, which the compactor input
+    conversion rejects.
+    """
+    if isinstance(prompts, (str, dict)):
+        return [prompts]
+    if (
+        isinstance(prompts, list)
+        and prompts
+        and all(isinstance(t, int) for t in prompts)
+    ):
+        # One tokenized prompt, not a batch of integer prompts.
+        return [prompts]
+    return list(prompts)
+
+
+def _normalize_sampling_params(
+    sampling_params: SamplingParams | Sequence[SamplingParams] | None,
+    n: int,
+) -> list[SamplingParams]:
+    """Expand ``sampling_params`` to exactly ``n`` entries, one per prompt.
+
+    ``None`` yields ``n`` fresh default objects; a single ``SamplingParams`` is
+    repeated (the same object ``n`` times); a sequence is copied to a list and
+    must already have length ``n``, otherwise ``ValueError`` is raised.
+    """
+    if isinstance(sampling_params, SamplingParams):
+        return [sampling_params for _ in range(n)]
+    if sampling_params is None:
+        return [SamplingParams() for _ in range(n)]
+    sps = list(sampling_params)
+    if len(sps) == n:
+        return sps
+    raise ValueError(
+        f"sampling_params length {len(sps)} != prompts length {n}"
+    )
+
+
+def _normalize_compression_params(
+    compression: CompressionParams | Sequence[CompressionParams] | None,
+    n: int,
+) -> list[CompressionParams]:
+    """Expand ``compression`` to exactly ``n`` entries, one per prompt.
+
+    ``None`` yields ``n`` fresh ``CompressionParams(compression_ratio=1.0)``; a
+    single instance is repeated (the same object ``n`` times); a sequence is
+    copied to a list and must have length ``n``, otherwise ``ValueError`` is
+    raised.
+    """
+    if isinstance(compression, CompressionParams):
+        return [compression for _ in range(n)]
+    if compression is None:
+        return [CompressionParams(compression_ratio=1.0) for _ in range(n)]
+    comp = list(compression)
+    if len(comp) == n:
+        return comp
+    raise ValueError(f"compression length {len(comp)} != prompts length {n}")
+
+
+def _any_compactor(comps: list[CompressionParams]) -> bool:
+    """Return True when at least one entry has ``compression_ratio`` below 1.0."""
+    for params in comps:
+        if params.compression_ratio < 1.0:
+            return True
+    return False
+
+
+_FORCE_COMPACTOR_PATH_ENV = "VLLM_KVPRUNE_FORCE_COMPACTOR_PATH"
+
+
+def _should_use_kvprune_compactor_path(comps: list[CompressionParams]) -> bool:
+    """Decide whether the batch should run on the integrated compactor path.
+
+    True when any prompt requests compression (ratio below 1.0). Otherwise the
+    answer comes from ``VLLM_KVPRUNE_FORCE_COMPACTOR_PATH``: ``1`` / ``true`` /
+    ``yes`` (case-insensitive) forces the compactor path even though no KV
+    pruning will occur, so the same code path can be exercised for debugging.
+    """
+    if not _any_compactor(comps):
+        forced = os.environ.get(_FORCE_COMPACTOR_PATH_ENV, "").strip().lower()
+        return forced in (
+            "1",
+            "true",
+            "yes",
+        )
+    return True
+
+
+def _to_compactor_sampling(sp: SamplingParams) -> CompactorSamplingParams:
+    """Map vLLM sampling params onto the compactor's sampling params.
+
+    Only ``temperature`` and ``max_tokens`` are carried over; a ``max_tokens``
+    of ``None`` becomes 16 new tokens.
+    """
+    requested = sp.max_tokens
+    max_new = 16 if requested is None else requested
+    return CompactorSamplingParams(
+        temperature=float(sp.temperature),
+        max_new_tokens=int(max_new),
+    )
+
+
+def _to_sequence_compression(cp: CompressionParams) -> SequenceCompressionParams:
+    """Copy the per-prompt ratio and protected-token counts into compactor form."""
+    ratio = float(cp.compression_ratio)
+    keep_first = int(cp.protected_first_tokens)
+    keep_last = int(cp.protected_last_tokens)
+    return SequenceCompressionParams(
+        compression_ratio=ratio,
+        protected_first_tokens=keep_first,
+        protected_last_tokens=keep_last,
+    )
+
+
+def _batch_compression_from_comps(comps: list[CompressionParams]) -> BatchCompressionParams:
+    """Choose the batch-wide compression method.
+
+    The method of the first entry whose ratio is below 1.0 is used for the whole
+    batch; when no entry compresses, default ``BatchCompressionParams`` are
+    returned.
+    """
+    first_active = next((c for c in comps if c.compression_ratio < 1.0), None)
+    if first_active is None:
+        return BatchCompressionParams()
+    method_id = compression_method_str_to_id(first_active.compression_method)
+    return BatchCompressionParams(
+        compression_method=compression_method_id_to_enum(method_id)
+    )
+
+
+def _kvprune_compactor_hf_tokenizer(llm: Any):
+    """HF tokenizer matching :meth:`vllm.kvprune.core.llm_engine.LLMEngine.__init__`.
+
+    Loads from the **resolved on-disk** model tree (local dir or HF cache snapshot),
+    falling back to the bare model string if resolution fails. The tokenizer is
+    cached on ``llm`` as ``_kvprune_compactor_hf_tokenizer`` and reused on later
+    calls.
+
+    ``trust_remote_code`` is honored when it is set on the engine's model config
+    (where the user's ``LLM(trust_remote_code=...)`` choice is expected to live)
+    or, as before, on the HF config object.
+    """
+    cached = getattr(llm, "_kvprune_compactor_hf_tokenizer", None)
+    if cached is not None:
+        return cached
+    mc = llm.llm_engine.vllm_config.model_config
+    model_s = str(mc.model)
+    src = model_s
+    try:
+        p = Path(model_s)
+        if p.is_dir() and (p / "config.json").is_file():
+            src = str(p.resolve())
+        else:
+            from huggingface_hub import snapshot_download
+
+            src = snapshot_download(repo_id=model_s, local_files_only=False)
+    except Exception as e:
+        # Best effort: let ``from_pretrained`` resolve the raw model string, but
+        # leave a trace so resolution problems are diagnosable.
+        logger.debug(
+            "KV-prune tokenizer: could not resolve %r to a local snapshot: %s",
+            model_s,
+            e,
+        )
+        src = model_s
+    hf_cfg = mc.hf_config
+    _trust = bool(getattr(mc, "trust_remote_code", False)) or (
+        bool(getattr(hf_cfg, "trust_remote_code", False))
+        if hf_cfg is not None
+        else False
+    )
+    tok = AutoTokenizer.from_pretrained(src, use_fast=True, trust_remote_code=_trust)
+    llm._kvprune_compactor_hf_tokenizer = tok
+    return tok
+
+
+def _prompt_to_compactor_input(prompt: Any) -> str | list[int]:
+    """Convert one prompt into text or token ids for the compactor engine.
+
+    Accepted forms: ``str`` (returned as is); a non-empty ``list`` of ``int``
+    token ids (returned as a copy); a ``dict`` with ``"prompt_token_ids"``
+    (returned as is when already a list, otherwise converted with ``list``) or
+    with a string ``"prompt"``. Everything else raises ``TypeError``, including
+    an empty list.
+    """
+    if isinstance(prompt, str):
+        return prompt
+    # Decoder-only `list[int]` token ids (see `vllm.inputs.PromptType`).
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise TypeError("Empty token-id prompt is not supported for compactor path.")
+        if all(isinstance(t, int) for t in prompt):
+            return list(prompt)
+    elif isinstance(prompt, dict):
+        if "prompt_token_ids" in prompt:
+            ids = prompt["prompt_token_ids"]
+            return ids if isinstance(ids, list) else list(ids)
+        text = prompt.get("prompt")
+        if isinstance(text, str):
+            return text
+    raise TypeError(
+        f"Unsupported prompt type for compactor path: {type(prompt)}. "
+        "Use str, list[int] token ids, or dict with 'prompt_token_ids' or 'prompt'."
+    )
+
+
+def _prompt_to_token_ids_for_tp(llm: Any, prompt: Any) -> list[int]:
+    """Return token ids for ``prompt`` on the driver (TP collective path).
+
+    Text prompts are encoded with ``llm.get_tokenizer()``; prompts that already
+    carry token ids are copied into a new list.
+    """
+    comp_in = _prompt_to_compactor_input(prompt)
+    if not isinstance(comp_in, str):
+        return list(comp_in)
+    tokenizer = llm.get_tokenizer()
+    return tokenizer.encode(comp_in)
+
+
+def _compressed_generate_tp_collective(
+    llm: Any,
+    plist: list[Any],
+    sps: list[SamplingParams],
+    comps: list[CompressionParams],
+) -> list[RequestOutput]:
+    """TP>1: run compactor on each worker via ``collective_rpc`` (all ranks).
+
+    Prompts are tokenized on the driver, checked against ``max_model_len``, and
+    shipped together with plain-dict sampling/compression settings to the
+    ``kvprune_v1_compressed_generate`` worker method. The result reported by
+    tensor-parallel rank 0 is converted to ``RequestOutput`` objects.
+
+    Raises:
+        NotImplementedError: pipeline or data parallel size is not 1.
+        ValueError: a prompt is longer than ``max_model_len``.
+        RuntimeError: the RPC was cancelled, or no rank-0 dict was returned.
+    """
+    vc = llm.llm_engine.vllm_config
+    pc = vc.parallel_config
+    if not (pc.pipeline_parallel_size == 1 and pc.data_parallel_size == 1):
+        raise NotImplementedError(
+            "KV-prune TP compression requires pipeline_parallel_size=1 and "
+            f"data_parallel_size=1 (got PP={pc.pipeline_parallel_size}, "
+            f"DP={pc.data_parallel_size})."
+        )
+
+    hf = vc.model_config.hf_config
+    tok = llm.get_tokenizer()
+    eos_token_ids = _infer_stop_token_ids(tok, hf)
+
+    prompt_token_ids = [_prompt_to_token_ids_for_tp(llm, p) for p in plist]
+
+    max_len = int(vc.model_config.max_model_len)
+    for i, ids in enumerate(prompt_token_ids):
+        if len(ids) <= max_len:
+            continue
+        raise ValueError(
+            f"KV-prune TP compressed generate: prompt {i} length {len(ids)} "
+            f"exceeds max_model_len ({max_len}). Shorten the prompt or raise "
+            "max_model_len when constructing LLM()."
+        )
+
+    # Payload must be picklable for multiproc/Ray RPC: do not pass multiprocessing
+    # synchronization primitives (workers are separate processes).
+    sampling_payload = [
+        {
+            "temperature": float(sp.temperature),
+            "max_new_tokens": int(sp.max_tokens if sp.max_tokens is not None else 16),
+        }
+        for sp in sps
+    ]
+    compression_payload = [
+        {
+            "compression_ratio": float(c.compression_ratio),
+            "compression_method": str(c.compression_method),
+            "protected_first_tokens": int(c.protected_first_tokens),
+            "protected_last_tokens": int(c.protected_last_tokens),
+        }
+        for c in comps
+    ]
+    payload: dict[str, Any] = {
+        "eos_token_ids": eos_token_ids,
+        "prompt_token_ids": prompt_token_ids,
+        "sampling_params": sampling_payload,
+        "compression_params": compression_payload,
+    }
+
+    _maybe_release_v1_kv_for_compactor(llm)
+    try:
+        results = llm.llm_engine.collective_rpc(
+            "kvprune_v1_compressed_generate",
+            args=(payload,),
+        )
+    except RuntimeError as e:
+        if "cancelled" not in str(e).lower():
+            raise
+        raise RuntimeError(
+            "collective_rpc was cancelled (a GPU worker likely crashed). "
+            "Scroll up for the first worker traceback — often NCCL/CUDA before "
+            "TCPStore/Broken pipe on the driver."
+        ) from e
+    # Only tensor-parallel rank 0's reply is used.
+    master: dict[str, Any] | None = next(
+        (
+            r
+            for r in results
+            if isinstance(r, dict) and r.get("tensor_parallel_rank") == 0
+        ),
+        None,
+    )
+    if master is None:
+        raise RuntimeError(
+            "collective_rpc did not return a dict from tensor parallel rank 0."
+        )
+    return _tp_payload_to_request_outputs(llm, master)
+
+
+def _tp_payload_to_request_outputs(llm: Any, master: dict[str, Any]) -> list[RequestOutput]:
+    """Convert the rank-0 RPC result into finished ``RequestOutput`` objects.
+
+    ``master`` must provide ``"prompt_token_ids"`` and ``"completion_token_ids"``,
+    two aligned lists with one entry per prompt. Completion text is decoded with
+    ``llm.get_tokenizer()``. Each output has a single completion at index 0, a
+    request id of the form ``kvprune-tp-<i>``, no logprobs, and
+    ``finish_reason="stop"``.
+    """
+    tok = llm.get_tokenizer()
+    out: list[RequestOutput] = []
+    pids_list = master["prompt_token_ids"]
+    cids_list = master["completion_token_ids"]
+    for i, (pids, cids) in enumerate(zip(pids_list, cids_list)):
+        text = tok.decode(cids, skip_special_tokens=True)
+        # Mirror the single-GPU path (``_sequences_to_request_outputs``): if every
+        # emitted id is special, show the raw tokens instead of a blank answer so
+        # both paths return the same text for the same completion ids.
+        if not text.strip() and cids:
+            text = tok.decode(cids, skip_special_tokens=False)
+        co = CompletionOutput(
+            index=0,
+            text=text,
+            token_ids=list(cids),
+            cumulative_logprob=None,
+            logprobs=None,
+            finish_reason="stop",
+        )
+        ro = RequestOutput(
+            request_id=f"kvprune-tp-{i}",
+            prompt=None,
+            prompt_token_ids=list(pids),
+            prompt_logprobs=None,
+            outputs=[co],
+            finished=True,
+        )
+        out.append(ro)
+    return out
+
+
+def _ensure_compactor_engine(llm: Any) -> LLMEngine:
+    """Return the compactor engine cached on ``llm``, creating it on first use.
+
+    Creation requires ``tensor_parallel_size == 1`` (``ValueError`` otherwise)
+    and stores the new engine on ``llm._kvprune_compactor_engine``.
+    """
+    if llm._kvprune_compactor_engine is not None:
+        return llm._kvprune_compactor_engine
+    tp_size = llm.llm_engine.vllm_config.parallel_config.tensor_parallel_size
+    if tp_size != 1:
+        raise ValueError(
+            "KV-pruning compactor path requires tensor_parallel_size=1 "
+            "for shared weights."
+        )
+    llm._kvprune_compactor_engine = create_compactor_engine_with_shared_weights(llm)
+    logger.info("Initialized compactor LLMEngine with weights shared from vLLM.")
+    return llm._kvprune_compactor_engine
+
+
+def try_compressed_generate(
+    llm: Any,
+    prompts: Any,
+    sampling_params: SamplingParams | Sequence[SamplingParams] | None,
+    *,
+    compression: CompressionParams | Sequence[CompressionParams] | None,
+    use_tqdm: bool | Callable[..., tqdm] = True,
+    lora_request: Any = None,
+    priority: list[int] | None = None,
+    tokenization_kwargs: dict[str, Any] | None = None,
+) -> list[RequestOutput] | None:
+    """Generate on the compactor engine, or return ``None`` to defer to normal v1.
+
+    ``None`` is returned only for ``tensor_parallel_size == 1`` when no prompt
+    requests compression and the compactor path is not forced. With TP>1 the
+    collective compactor path is always taken (using ratio 1.0 for every prompt
+    when nothing requests compression).
+
+    ``lora_request`` / ``priority`` / ``tokenization_kwargs`` / ``use_tqdm`` are
+    accepted for API parity with :meth:`~vllm.entrypoints.llm.LLM.generate` but
+    are discarded; they are not passed to the compactor engine yet.
+    """
+    del lora_request, priority, tokenization_kwargs, use_tqdm
+
+    plist = _normalize_prompt_list(prompts)
+    num_prompts = len(plist)
+    sps = _normalize_sampling_params(sampling_params, num_prompts)
+    comps = _normalize_compression_params(compression, num_prompts)
+
+    pc = llm.llm_engine.vllm_config.parallel_config
+    uses_tp = pc.tensor_parallel_size > 1
+    # TP>1: every worker must run the same collective_rpc session. If all
+    # compression_ratio >= 1, the old code returned None and only the driver ran
+    # v1 _run_engine — other ranks never joined a matching collective, which can
+    # deadlock NCCL / leave workers unsynchronized (hang at "Processed prompts:").
+    if not _should_use_kvprune_compactor_path(comps):
+        if not uses_tp:
+            return None
+        comps = [CompressionParams(compression_ratio=1.0) for _ in plist]
+
+    model_config = llm.llm_engine.vllm_config.model_config
+    if not bool(getattr(model_config, "enforce_eager", False)):
+        logger.warning(
+            "KV-prune compression: v1 CUDA graphs are still enabled on this LLM. "
+            "The compactor does not reuse v1 graphs; capture wastes VRAM. "
+            "Set kvprune_compression=True, enforce_eager=True, or "
+            "VLLM_KVPRUNE_COMPRESSION_DEFAULT=1 before import vllm."
+        )
+
+    if pc.tensor_parallel_size > 1:
+        return _compressed_generate_tp_collective(llm, plist, sps, comps)
+
+    ensure_inprocess_engine_for_weight_sharing()
+    # The optional v1 KV release only runs before the engine is first created.
+    if llm._kvprune_compactor_engine is None:
+        _maybe_release_v1_kv_for_compactor(llm)
+    engine = _ensure_compactor_engine(llm)
+
+    _, seqs = engine.generate(
+        [_prompt_to_compactor_input(p) for p in plist],
+        sampling_params=[_to_compactor_sampling(sp) for sp in sps],
+        batch_compression_params=_batch_compression_from_comps(comps),
+        per_sequence_compression_params=[_to_sequence_compression(c) for c in comps],
+        return_sequences=True,
+    )
+
+    return _sequences_to_request_outputs(seqs, engine)
+
+
+def _sequences_to_request_outputs(seqs: list[Any], engine: LLMEngine) -> list[RequestOutput]:
+    """Wrap finished compactor sequences as vLLM ``RequestOutput`` objects.
+
+    Text is decoded with ``engine.tokenizer``. Each output has one completion at
+    index 0, a request id of the form ``kvprune-<i>``, no logprobs, and
+    ``finish_reason="stop"``.
+    """
+    tok = engine.tokenizer
+
+    def _decode(seq: Any) -> str:
+        text = tok.decode(seq.completion_token_ids, skip_special_tokens=True)
+        # If every emitted id is “special” (e.g. EOS / chat boundary), the stripped
+        # string is empty while ``completion_token_ids`` is non-empty — avoid
+        # presenting a blank answer so users can see boundary tokens / debug.
+        if text.strip() or not seq.completion_token_ids:
+            return text
+        return tok.decode(seq.completion_token_ids, skip_special_tokens=False)
+
+    out: list[RequestOutput] = []
+    for i, seq in enumerate(seqs):
+        completion = CompletionOutput(
+            index=0,
+            text=_decode(seq),
+            token_ids=list(seq.completion_token_ids),
+            cumulative_logprob=None,
+            logprobs=None,
+            finish_reason="stop",
+        )
+        out.append(
+            RequestOutput(
+                request_id=f"kvprune-{i}",
+                prompt=None,
+                prompt_token_ids=list(seq.prompt_token_ids),
+                prompt_logprobs=None,
+                outputs=[completion],
+                finished=True,
+            )
+        )
+    return out
diff --git a/vllm/kvprune_legacy_save/integration/compression_params.py b/vllm/kvprune_legacy_save/integration/compression_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26511afb5445522fc759e8acfbe379f0ff59936
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/compression_params.py
@@ -0,0 +1,52 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Per-request KV compression for :meth:`vllm.LLM.generate` (``compression=`` kwarg)."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+@dataclass
+class CompressionParams:
+    """Per-prompt compression intent for :meth:`vllm.LLM.generate`.
+
+    If **any** prompt in the batch has ``compression_ratio < 1.0``, the **whole** batch
+    is run on the compactor ``LLMEngine`` (same stack as standalone compactor-vllm:
+    ``PagedKVCache`` + pruning kernels). If all prompts have ``compression_ratio >= 1.0``,
+    the batch stays on standard vLLM.
+
+    ``compression_method`` follows :mod:`vllm.kvprune.core.compression_bridge` aliases:
+    ``none``, ``criticaladakv``, ``compactor``, ``snapkv`` (ignored when
+    ``compression_ratio`` is effectively 1).
+
+    ``protected_*`` map to compactor :class:`~vllm.kvprune.compression.compression_config.SequenceCompressionParams`
+    (defaults match standalone compactor-vllm-style usage).
+
+    Normalization performed by ``__post_init__``:
+
+    * ``compression_method`` is stripped and lower-cased; a falsy value falls
+      back to ``"compactor"``.
+    * A ratio within ``1e-9`` of 1.0 is treated as "no compression": the method
+      becomes ``"none"`` **and** the ratio is snapped to exactly ``1.0`` so that
+      consumers testing ``compression_ratio < 1.0`` agree with the method.
+
+    Raises:
+        ValueError: If ``compression_ratio`` is outside ``(0, 1]``, if the method
+            is not a known alias, or if a real compression ratio is combined
+            with method ``"none"``.
+    """
+
+    compression_ratio: float = 1.0
+    compression_method: str = "compactor"
+    protected_first_tokens: int = 16
+    protected_last_tokens: int = 64
+
+    def __post_init__(self) -> None:
+        if not 0.0 < self.compression_ratio <= 1.0:
+            raise ValueError(
+                f"compression_ratio must be in (0, 1], got {self.compression_ratio}"
+            )
+        self.compression_method = (
+            self.compression_method or "compactor"
+        ).strip().lower()
+        # Imported lazily, exactly as before, so importing this module does not
+        # pull in the compression bridge.
+        from vllm.kvprune.core.compression_bridge import VALID_ALIASES_FOR_SAMPLING
+
+        if self.compression_method not in VALID_ALIASES_FOR_SAMPLING:
+            raise ValueError(
+                f"compression_method must be one of {sorted(VALID_ALIASES_FOR_SAMPLING)}, "
+                f"got {self.compression_method!r}"
+            )
+        if self.compression_ratio >= 1.0 - 1e-9:
+            # Keep ratio and method consistent: without snapping, a value such as
+            # 1 - 1e-10 would still compare ``< 1.0`` downstream while the method
+            # says "none".
+            self.compression_ratio = 1.0
+            self.compression_method = "none"
+        elif self.compression_method == "none":
+            raise ValueError(
+                "When compression_ratio < 1.0, compression_method cannot be 'none'."
+            )
diff --git a/vllm/kvprune_legacy_save/integration/config_adapter.py b/vllm/kvprune_legacy_save/integration/config_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a054096545f6c9b1218c6cec5676579d23cb92d3
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/config_adapter.py
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Build :class:`vllm.kvprune.config.engine_config.LLMConfig` from :class:`VllmConfig`."""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+from vllm.config import VllmConfig
+from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def _attention_schedule_from_env() -> KvpruneAttentionSchedule:
+    """Resolve :class:`KvpruneAttentionSchedule` from env.
+
+    Primary (``VLLM_KVPRUNE_ATTENTION_SCHEDULE``):
+
+    - ``fa_triton`` — FA prefill, Triton decode (default). Aliases: ``fa_prefill``,
+      ``default``.
+    - ``pdtriton`` — Triton prefill + Triton decode. Aliases: ``triton``,
+      ``triton_prefill``, ``compactor_prefill``, ``pd_triton``.
+    - ``pdfa`` — FA prefill + FA decode (KV stores still Triton). Aliases:
+      ``fa_full``, ``fa_both``.
+
+    An unrecognized non-empty primary value logs a warning and yields the default.
+
+    Legacy: only consulted when the primary variable is unset or blank.
+    ``VLLM_KVPRUNE_ATTENTION_BACKEND`` maps ``flash``/``fa``/``flash_attention``/
+    ``flashattention`` → ``fa_triton`` and ``compactor``/``triton``/
+    ``compactor_triton`` → ``pdtriton``. If both variables are unset the default
+    (``fa_triton``) is returned; an unrecognized legacy value logs a warning and
+    also yields the default.
+
+    Returns:
+        The schedule selected by the environment.
+    """
+    default = KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+
+    s = os.environ.get("VLLM_KVPRUNE_ATTENTION_SCHEDULE", "").strip().lower()
+    if s in ("fa_triton", "fa_prefill", "default"):
+        return default
+    if s in ("pdtriton", "pd_triton", "triton", "triton_prefill", "compactor_prefill"):
+        return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+    if s in ("pdfa", "fa_full", "fa_both"):
+        return KvpruneAttentionSchedule.PDFA
+    if s:
+        logger.warning(
+            "Unknown VLLM_KVPRUNE_ATTENTION_SCHEDULE=%r; using FA_PREFILL_TRITON_DECODE",
+            s,
+        )
+        return default
+
+    # The primary variable is unset/blank: fall back to the legacy variable.
+    # Previously the empty string was matched by the first branch above, which
+    # made this fallback unreachable.
+    v = os.environ.get("VLLM_KVPRUNE_ATTENTION_BACKEND", "").strip().lower()
+    if v in ("", "flash", "fa", "flash_attention", "flashattention"):
+        # Empty means neither variable was set; keep the documented default.
+        return default
+    if v in ("compactor", "triton", "compactor_triton"):
+        return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+    logger.warning(
+        "Unknown VLLM_KVPRUNE_ATTENTION_BACKEND=%r; using FA_PREFILL_TRITON_DECODE", v
+    )
+    return default
+
+
+def _compactor_kvcache_page_size(vllm_block_size: int | None) -> int:
+    """Pick the compactor KV page size (tokens per page) for a vLLM block size.
+
+    A missing or non-positive block size yields 128. Any other value is returned
+    unchanged when it is already a multiple of 32 and is otherwise rounded up to
+    the next multiple of 32.
+
+    Args:
+        vllm_block_size: vLLM ``block_size`` (converted with ``int``), or ``None``.
+
+    Returns:
+        A positive page size that is a multiple of 32.
+    """
+    fallback_page = 128
+    granule = 32
+
+    tokens = 0 if vllm_block_size is None else int(vllm_block_size)
+    if tokens <= 0:
+        return fallback_page
+
+    overshoot = tokens % granule
+    if overshoot:
+        # Pad up to the next granule boundary.
+        tokens += granule - overshoot
+    return tokens
+
+
+def vllm_config_to_llm_config(vc: VllmConfig) -> LLMConfig:
+    """Translate a vLLM engine config into a compactor :class:`LLMConfig`.
+
+    Values copied from ``vc``: model name, ``max_model_len``, ``hf_config``,
+    ``gpu_memory_utilization``, ``tensor_parallel_size`` and ``max_num_seqs``
+    (256 when the scheduler config lacks it). The KV page size is derived from
+    the cache ``block_size`` (16 when unset) and the attention schedule is read
+    from the environment. EOS fields are left as placeholders (``eos=-1``,
+    ``eos_token_ids=None``).
+
+    Args:
+        vc: The vLLM configuration to convert.
+
+    Returns:
+        A freshly built ``LLMConfig``.
+    """
+    model_cfg = vc.model_config
+    cache_cfg = vc.cache_config
+    parallel_cfg = vc.parallel_config
+    sched_cfg = vc.scheduler_config
+
+    configured_block = cache_cfg.block_size
+    page_source = 16 if configured_block is None else configured_block
+    seq_limit = getattr(sched_cfg, "max_num_seqs", 256)
+
+    # ``enforce_eager`` is intentionally fixed to False below instead of being
+    # copied from ``model_config.enforce_eager``: the v1 flag and the compactor
+    # flag are independent switches (the compactor one only governs compactor
+    # decode CUDA graphs).
+
+    # When the model name is a local checkpoint directory containing
+    # ``config.json``, pass its resolved path along so the compactor can skip
+    # redundant Hub fetches. Filesystem errors simply leave the path unset.
+    model_name = str(model_cfg.model)
+    local_checkpoint: str | None = None
+    try:
+        if (
+            model_name
+            and Path(model_name).is_dir()
+            and (Path(model_name) / "config.json").is_file()
+        ):
+            local_checkpoint = str(Path(model_name).resolve())
+    except OSError:
+        pass
+
+    settings = dict(
+        model=model_name,
+        path=local_checkpoint,
+        nccl_port=1218,
+        max_num_seqs=seq_limit,
+        max_model_len=model_cfg.max_model_len,
+        gpu_memory_utilization=cache_cfg.gpu_memory_utilization,
+        tensor_parallel_size=parallel_cfg.tensor_parallel_size,
+        enforce_eager=False,
+        hf_config=model_cfg.hf_config,
+        eos=-1,
+        eos_token_ids=None,
+        kvcache_page_size=_compactor_kvcache_page_size(page_source),
+        leverage_sketch_size=48,
+        attention_schedule=_attention_schedule_from_env(),
+        attention_backend=None,
+    )
+    return LLMConfig(**settings)
diff --git a/vllm/kvprune_legacy_save/integration/v1_tp_runner.py b/vllm/kvprune_legacy_save/integration/v1_tp_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc8fc6b5245f9316e574133fbe4f5f91533e02d
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/v1_tp_runner.py
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""TP>1: one kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner` per vLLM worker.
+
+Invoked via v1 ``collective_rpc("kvprune_v1_compressed_generate", ...)`` so every tensor-
+parallel rank participates in the same compactor forward/broadcast sequence as the
+standalone multi-process compactor.
+
+Compactor decode CUDA graphs (when not ``enforce_eager``) capture the full decode step
+including ``compute_logits``. To force eager on embedded TP workers, set
+``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` or ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``.
+
+Peer/master session boundaries use TP-group ``broadcast``/``barrier`` (see
+``ModelRunner.maybe_release_peers``), not ``multiprocessing.Event`` — RPC payloads must
+be picklable across worker processes.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from vllm.kvprune.compression.compression_config import (
+ BatchCompressionParams,
+ SequenceCompressionParams,
+)
+from vllm.kvprune.config.sampling_params import SamplingParams as CompactorSamplingParams
+from vllm.kvprune.core.compression_bridge import (
+ compression_method_id_to_enum,
+ compression_method_str_to_id,
+)
+from vllm.kvprune.core.model_runner import ModelRunner
+from vllm.kvprune.integration.config_adapter import vllm_config_to_llm_config
+from vllm.kvprune.utils.kv_dist import barrier_sync
+from vllm.kvprune.integration.weight_tie import (
+ delegate_kvprune_compute_logits_to_vllm,
+ delegate_kvprune_embed_tokens_to_vllm,
+ tie_kvprune_rope_buffers_from_vllm,
+ tie_kvprune_weights_from_vllm,
+)
+from vllm.kvprune.models import MODEL_REGISTRY
+from vllm.kvprune.utils.sequence import Sequence
+
+_ATTR = "_kvprune_tp_embedded_runner"
+
+
+def _apply_compactor_env_overrides(cfg: Any) -> None:
+    """Apply environment-variable overrides to a compactor config in place.
+
+    * ``VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS`` (default ``"32"``): when it parses
+      to a positive integer, ``cfg.max_num_seqs`` is lowered to at most that
+      value. A blank value or a non-positive number leaves it untouched; a
+      non-integer value raises ``ValueError`` from ``int``.
+    * ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``: ``1/true/yes`` forces
+      ``cfg.enforce_eager = True``; ``0/false/no`` forces ``False``.
+    * Otherwise ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` (default ``"1"``) decides:
+      eager mode is enabled only when it is ``0/false/no``.
+
+    Args:
+        cfg: Object with writable ``max_num_seqs`` and ``enforce_eager`` attributes.
+    """
+    affirmative = ("1", "true", "yes")
+    negative = ("0", "false", "no")
+
+    raw_cap = os.environ.get("VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS", "32").strip()
+    if raw_cap:
+        cap = int(raw_cap)
+        if cap > 0:
+            cfg.max_num_seqs = min(cfg.max_num_seqs, cap)
+
+    eager_flag = os.environ.get("VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER", "").strip().lower()
+    if eager_flag in affirmative:
+        cfg.enforce_eager = True
+        return
+    if eager_flag in negative:
+        cfg.enforce_eager = False
+        return
+
+    # No explicit eager directive: derive it from the CUDA-graph switch.
+    graph_flag = os.environ.get("VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH", "1").strip().lower()
+    cfg.enforce_eager = graph_flag in negative
+
+
+def _build_sequences(payload: dict[str, Any]) -> list[Sequence]:
+    """Rebuild compactor ``Sequence`` objects from an RPC payload.
+
+    The payload carries three parallel lists under ``"prompt_token_ids"``,
+    ``"sampling_params"`` and ``"compression_params"``; entry ``i`` of each list
+    describes prompt ``i``. Missing ``protected_*`` keys default to 16 (first)
+    and 64 (last). When the protected head and tail together cover the whole
+    prompt, the sequence's ``compression_ratio`` is reset to ``1.0``.
+
+    Args:
+        payload: Dictionary with the three lists described above.
+
+    Returns:
+        One ``Sequence`` per prompt, in payload order.
+    """
+    token_lists: list[list[int]] = payload["prompt_token_ids"]
+    sampling_dicts: list[dict[str, Any]] = payload["sampling_params"]
+    compression_dicts: list[dict[str, Any]] = payload["compression_params"]
+
+    def _one(position: int, token_ids: list[int]) -> Sequence:
+        sampling_src = sampling_dicts[position]
+        sampling = CompactorSamplingParams(
+            temperature=float(sampling_src["temperature"]),
+            max_new_tokens=int(sampling_src["max_new_tokens"]),
+        )
+        compression_src = compression_dicts[position]
+        compression = SequenceCompressionParams(
+            compression_ratio=float(compression_src["compression_ratio"]),
+            protected_first_tokens=int(compression_src.get("protected_first_tokens", 16)),
+            protected_last_tokens=int(compression_src.get("protected_last_tokens", 64)),
+        )
+        protected_total = (
+            compression.protected_first_tokens + compression.protected_last_tokens
+        )
+        if protected_total >= len(token_ids):
+            # Nothing outside the protected windows is left to prune.
+            compression.compression_ratio = 1.0
+        return Sequence(
+            prompt_token_ids=list(token_ids),
+            sampling_params=sampling,
+            compression_params=compression,
+        )
+
+    return [_one(position, ids) for position, ids in enumerate(token_lists)]
+
+
+def _batch_compression_from_payload(payload: dict[str, Any]) -> BatchCompressionParams:
+    """Derive batch-level compression settings from an RPC payload.
+
+    The first entry of ``payload["compression_params"]`` whose
+    ``compression_ratio`` is below 1.0 decides the batch compression method (its
+    ``compression_method`` key, ``"none"`` when absent). If no entry compresses,
+    a default-constructed ``BatchCompressionParams`` is returned.
+    """
+    first_compressing = next(
+        (
+            entry
+            for entry in payload["compression_params"]
+            if float(entry["compression_ratio"]) < 1.0
+        ),
+        None,
+    )
+    if first_compressing is None:
+        return BatchCompressionParams()
+
+    method_name = str(first_compressing.get("compression_method", "none"))
+    method_id = compression_method_str_to_id(method_name)
+    return BatchCompressionParams(
+        compression_method=compression_method_id_to_enum(method_id)
+    )
+
+
+def _get_or_create_runner(worker: Any, payload: dict[str, Any]) -> ModelRunner:
+    """Return the kvprune ``ModelRunner`` cached on ``worker``, building it once.
+
+    On first use this validates the parallel layout (PP and DP must both be 1
+    and the TP world size must match the config), instantiates the kvprune model
+    for the worker's ``model_type``, shares the vLLM model's weights with it,
+    moves it to the vLLM model's device/dtype, copies RoPE buffers, delegates
+    embedding and logits computation to vLLM, and finally creates the runner
+    for this TP rank. The runner is stored on the worker under ``_ATTR``.
+
+    Args:
+        worker: vLLM worker exposing ``vllm_config`` and ``get_model()``.
+        payload: RPC payload; only ``"eos_token_ids"`` is read here.
+
+    Returns:
+        The cached or newly created ``ModelRunner``.
+
+    Raises:
+        NotImplementedError: If pipeline or data parallelism is enabled.
+        RuntimeError: If the TP world size disagrees with the config.
+        ValueError: If the model type is not in ``MODEL_REGISTRY``.
+    """
+    cached = getattr(worker, _ATTR, None)
+    if cached is not None:
+        return cached
+
+    from vllm.distributed.parallel_state import (
+        get_tensor_model_parallel_rank,
+        get_tensor_model_parallel_world_size,
+    )
+
+    vllm_cfg = worker.vllm_config
+    parallel = vllm_cfg.parallel_config
+    if not (parallel.pipeline_parallel_size == 1 and parallel.data_parallel_size == 1):
+        raise NotImplementedError(
+            "KV-prune TP compressed generate requires pipeline_parallel_size=1 and "
+            f"data_parallel_size=1; got PP={parallel.pipeline_parallel_size}, "
+            f"DP={parallel.data_parallel_size}."
+        )
+
+    world_size = get_tensor_model_parallel_world_size()
+    if world_size != parallel.tensor_parallel_size:
+        raise RuntimeError(
+            f"parallel_state TP world size {world_size} != config.tensor_parallel_size "
+            f"{parallel.tensor_parallel_size}"
+        )
+
+    hf_config = vllm_cfg.model_config.hf_config
+    architecture = getattr(hf_config, "model_type", None)
+    if architecture not in MODEL_REGISTRY:
+        raise ValueError(
+            f"KV-prune TP path: unsupported model_type={architecture!r}; "
+            f"registry has {sorted(MODEL_REGISTRY)}"
+        )
+
+    llm_cfg = vllm_config_to_llm_config(vllm_cfg)
+    # Deduplicate and sort the EOS ids; the smallest id becomes the primary EOS.
+    llm_cfg.eos_token_ids = sorted({int(token) for token in payload["eos_token_ids"]})
+    llm_cfg.eos = int(llm_cfg.eos_token_ids[0])
+    _apply_compactor_env_overrides(llm_cfg)
+
+    host_model = worker.get_model()
+    pruned_model: nn.Module = MODEL_REGISTRY[architecture](hf_config)
+    tie_kvprune_weights_from_vllm(host_model, pruned_model)
+
+    # Use the first vLLM parameter as the reference for placement and precision.
+    reference_param = next(host_model.parameters())
+    pruned_model.to(device=reference_param.device, dtype=reference_param.dtype)
+    # RoPE buffers are copied only after ``.to(...)`` has placed the model.
+    tie_kvprune_rope_buffers_from_vllm(host_model, pruned_model)
+    delegate_kvprune_embed_tokens_to_vllm(host_model, pruned_model)
+    delegate_kvprune_compute_logits_to_vllm(host_model, pruned_model)
+
+    rank = get_tensor_model_parallel_rank()
+    cuda_device = torch.device(f"cuda:{torch.cuda.current_device()}")
+
+    # Rank 0 is constructed with an empty ``peer_events`` list; every other rank
+    # is constructed with ``batch_ready=None``.
+    role_kwargs: dict[str, Any] = (
+        {"peer_events": []} if rank == 0 else {"batch_ready": None}
+    )
+    runner = ModelRunner(
+        llm_cfg,
+        rank=rank,
+        external_model=pruned_model,
+        embedded_in_vllm_worker=True,
+        device=cuda_device,
+        **role_kwargs,
+    )
+
+    setattr(worker, _ATTR, runner)
+    return runner
+
+
+def run_kvprune_tp_compressed_generate(worker: Any, payload: dict[str, Any]) -> dict[str, Any]:
+    """Run one compressed-generation session on this worker.
+
+    Every tensor-parallel rank obtains (or lazily builds) its runner, rebuilds
+    the sequences and batch compression settings from ``payload`` and then waits
+    on a TP-group barrier. Rank 0 drives generation and returns the prompt and
+    completion token ids; all other ranks run a peer session and return an
+    acknowledgement.
+
+    Args:
+        worker: The vLLM worker this call executes on.
+        payload: Picklable request description (prompts, sampling, compression,
+            EOS ids).
+
+    Returns:
+        On rank 0: ``{"tensor_parallel_rank": 0, "prompt_token_ids": [...],
+        "completion_token_ids": [...]}``. On other ranks:
+        ``{"tensor_parallel_rank": <rank>, "ok": True}``.
+    """
+    from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
+
+    rank = get_tensor_model_parallel_rank()
+    runner = _get_or_create_runner(worker, payload)
+    batch = _build_sequences(payload)
+    batch_params = _batch_compression_from_payload(payload)
+
+    # All ranks line up here before the session starts.
+    barrier_sync(use_tp_group=True)
+
+    if rank != 0:
+        runner.run_peer_session()
+        return {"tensor_parallel_rank": int(rank), "ok": True}
+
+    runner.generate(batch, batch_params)
+    prompts = [list(seq.prompt_token_ids) for seq in batch]
+    completions = [list(seq.completion_token_ids) for seq in batch]
+    return {
+        "tensor_parallel_rank": 0,
+        "prompt_token_ids": prompts,
+        "completion_token_ids": completions,
+    }
diff --git a/vllm/kvprune_legacy_save/integration/vllm_model_access.py b/vllm/kvprune_legacy_save/integration/vllm_model_access.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b23c91f958c5376061c4bb499e1868491350f1c
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/vllm_model_access.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Access the in-process vLLM model weights for compactor weight sharing."""
+
+from __future__ import annotations
+
+import torch.nn as nn
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def extract_vllm_causal_lm(llm: object) -> nn.Module:
+    """Fetch the model object from an in-process vLLM ``LLM``.
+
+    Follows ``llm.llm_engine.model_executor.driver_worker.worker`` and returns
+    the result of that worker's ``get_model()``. ``model_executor`` is only
+    reachable when the engine runs in-process (the error message points at
+    ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
+
+    Args:
+        llm: Object with an ``llm_engine`` attribute (e.g. ``vllm.LLM``).
+
+    Returns:
+        Whatever ``worker.get_model()`` returns.
+
+    Raises:
+        RuntimeError: If any attribute along the chain is missing or ``None``,
+            or the worker has no callable ``get_model``.
+    """
+    # (attribute to follow, error raised when it is absent) in traversal order.
+    hops = (
+        (
+            "llm_engine",
+            "Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``).",
+        ),
+        (
+            "model_executor",
+            "model_executor is unavailable (multiprocess engine mode). "
+            "Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
+            "in-process weight sharing.",
+        ),
+        (
+            "driver_worker",
+            "Executor has no driver_worker (unexpected executor type for weight sharing).",
+        ),
+        ("worker", "Worker wrapper has no worker loaded."),
+    )
+
+    node: object = llm
+    for attribute, complaint in hops:
+        node = getattr(node, attribute, None)
+        if node is None:
+            raise RuntimeError(complaint)
+
+    model_getter = getattr(node, "get_model", None)
+    if not callable(model_getter):
+        raise RuntimeError("Worker does not expose get_model().")
+    return model_getter()
diff --git a/vllm/kvprune_legacy_save/integration/weight_tie.py b/vllm/kvprune_legacy_save/integration/weight_tie.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d2356e763dd970e8d0002ae5d7df205d56e5f13
--- /dev/null
+++ b/vllm/kvprune_legacy_save/integration/weight_tie.py
@@ -0,0 +1,192 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Share vLLM parameter storage with compactor ``MODEL_REGISTRY`` models (TP=1)."""
+
+from __future__ import annotations
+
+import types
+
+import torch
+import torch.nn as nn
+
+from vllm.kvprune.utils.context import get_context
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def tie_kvprune_weights_from_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+    *,
+    strict: bool = True,
+) -> int:
+    """Alias kvprune parameters onto the vLLM tensors that share their name.
+
+    For every parameter of ``kvprune_model`` whose name also exists in
+    ``vllm_model``, the kvprune parameter's ``.data`` is rebound to the vLLM
+    parameter's ``.data`` so both refer to the same tensor.
+
+    Args:
+        vllm_model: Model returned by ``worker.get_model()``.
+        kvprune_model: Instance from ``vllm.kvprune.models.MODEL_REGISTRY``.
+        strict: When True, a kvprune parameter name that is absent from
+            ``vllm_model`` raises; when False such parameters are skipped.
+
+    Returns:
+        The number of parameters that were tied.
+
+    Raises:
+        ValueError: If a name is missing (strict mode only), if a matching pair
+            has different shapes (regardless of ``strict``), or if nothing could
+            be tied at all.
+    """
+    source_params = dict(vllm_model.named_parameters())
+    target_params = dict(kvprune_model.named_parameters())
+
+    alias_count = 0
+    for param_name, target in target_params.items():
+        if param_name in source_params:
+            source = source_params[param_name]
+            if source.shape != target.shape:
+                raise ValueError(
+                    f"Shape mismatch for {param_name}: vllm {source.shape} vs kvprune {target.shape}"
+                )
+            # Rebind storage instead of copying so both models share memory.
+            target.data = source.data
+            alias_count += 1
+        elif strict:
+            raise ValueError(
+                f"kvprune parameter {param_name!r} not found in vLLM model; "
+                "architecture/layout may differ (disable strict tying only "
+                "for expert debugging)."
+            )
+
+    if alias_count == 0:
+        raise ValueError(
+            "No parameters were tied — check that vLLM and kvprune model types match "
+            "and use the same state_dict names."
+        )
+    logger.info(
+        "Tied %d parameters from vLLM into compactor model (shared storage).", alias_count
+    )
+    return alias_count
+
+
+def tie_kvprune_rope_buffers_from_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+) -> int:
+    """Copy RoPE ``cos_sin_cache`` buffers from the vLLM model into kvprune.
+
+    Only kvprune buffers whose name contains ``cos_sin_cache`` are considered.
+    Each is overwritten in place (``copy_``) with the vLLM buffer of the same
+    name. Two layouts are accepted:
+
+    * identical shapes, copied directly;
+    * kvprune ``[max_len, 1, rotary_dim]`` vs vLLM ``[max_len, rotary_dim]``,
+      where the vLLM tensor is ``unsqueeze(1)``-ed before copying.
+
+    A kvprune buffer with no vLLM counterpart is left untouched and a warning is
+    logged.
+
+    Returns:
+        The number of buffers copied.
+
+    Raises:
+        ValueError: If the 3-D/2-D pair has incompatible sizes, or the two
+            layouts are not one of the supported combinations.
+    """
+    host_buffers = dict(vllm_model.named_buffers())
+    n_copied = 0
+
+    for buf_name, dst in kvprune_model.named_buffers():
+        if "cos_sin_cache" not in buf_name:
+            continue
+        if buf_name not in host_buffers:
+            logger.warning(
+                "kvprune RoPE buffer %r not found in vLLM; leaving kvprune cache",
+                buf_name,
+            )
+            continue
+
+        src = host_buffers[buf_name]
+        if src.shape == dst.shape:
+            dst.copy_(src)
+        elif dst.dim() == 3 and src.dim() == 2:
+            sizes_agree = (
+                dst.shape[0] == src.shape[0]
+                and dst.shape[2] == src.shape[1]
+                and dst.shape[1] == 1
+            )
+            if not sizes_agree:
+                raise ValueError(
+                    f"cos_sin_cache shape mismatch for {buf_name!r}: "
+                    f"vLLM {tuple(src.shape)} vs kvprune {tuple(dst.shape)}"
+                )
+            # Insert the singleton middle axis expected by the kvprune layout.
+            dst.copy_(src.unsqueeze(1))
+        else:
+            raise ValueError(
+                f"Unsupported cos_sin_cache layout for {buf_name!r}: "
+                f"vLLM {tuple(src.shape)} vs kvprune {tuple(dst.shape)}"
+            )
+        n_copied += 1
+
+    if n_copied:
+        logger.info(
+            "Copied %d RoPE cos_sin_cache buffer(s) from vLLM into kvprune model.",
+            n_copied,
+        )
+    return n_copied
+
+
+def delegate_kvprune_embed_tokens_to_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+) -> bool:
+    """Make kvprune's ``model.embed_tokens`` call vLLM's embedding module.
+
+    The ``forward`` of ``kvprune_model.model.embed_tokens`` is replaced by a
+    bound method that ignores its own module and simply evaluates
+    ``vllm_model.model.embed_tokens`` on the input, so token lookup follows
+    vLLM's own embedding logic instead of kvprune's.
+
+    Returns:
+        ``True`` when the forward was replaced. ``False`` when either model has
+        no ``model`` attribute (silently) or either ``embed_tokens`` is missing
+        (a warning is logged).
+    """
+    if not (hasattr(vllm_model, "model") and hasattr(kvprune_model, "model")):
+        return False
+
+    host_embed = getattr(vllm_model.model, "embed_tokens", None)
+    local_embed = getattr(kvprune_model.model, "embed_tokens", None)
+    if host_embed is None or local_embed is None:
+        logger.warning(
+            "delegate_kvprune_embed_tokens_to_vllm: embed_tokens missing; skipped"
+        )
+        return False
+
+    def _delegated_forward(_ignored_module: nn.Module, x):
+        # The bound kvprune module is deliberately unused; vLLM does the lookup.
+        return host_embed(x)
+
+    local_embed.forward = types.MethodType(_delegated_forward, local_embed)
+    logger.info(
+        "kvprune model.embed_tokens forward delegated to vLLM (correct vocab-parallel masks)."
+    )
+    return True
+
+
+def delegate_kvprune_compute_logits_to_vllm(
+    vllm_model: nn.Module,
+    kvprune_model: nn.Module,
+) -> bool:
+    """Route ``kvprune_model.compute_logits`` through ``vllm_model.compute_logits``.
+
+    The installed replacement first narrows the hidden states during prefill:
+    when the current context is a prefill and carries ``cu_seqlens_q``, only the
+    row of the last token of each packed sequence is kept. The (contiguous)
+    result is then handed to vLLM's ``compute_logits``.
+
+    Returns:
+        ``True`` when the method was installed, ``False`` (with a warning) when
+        ``vllm_model`` has no callable ``compute_logits``.
+    """
+    if not callable(getattr(vllm_model, "compute_logits", None)):
+        logger.warning(
+            "delegate_kvprune_compute_logits_to_vllm: vLLM model has no compute_logits; skipped"
+        )
+        return False
+
+    def _compute_logits(_self: nn.Module, hidden_states):
+        ctx = get_context()
+        if ctx.is_prefill and ctx.cu_seqlens_q is not None:
+            # ``cu_seqlens_q[1:] - 1`` is the index of each sequence's final token
+            # in the packed layout.
+            final_rows = (ctx.cu_seqlens_q[1:] - 1).to(torch.long)
+            total_rows = hidden_states.shape[0]
+            if total_rows > 0:
+                # Keep indices inside the packed tensor.
+                final_rows = final_rows.clamp(min=0, max=total_rows - 1)
+            hidden_states = hidden_states[final_rows].contiguous()
+        # vLLM's logits path is always fed a contiguous tensor.
+        return vllm_model.compute_logits(hidden_states.contiguous())
+
+    kvprune_model.compute_logits = types.MethodType(_compute_logits, kvprune_model)
+    return True
diff --git a/vllm/kvprune_legacy_save/kv_cache/__init__.py b/vllm/kvprune_legacy_save/kv_cache/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ddb214b7e6e1b22d8a1e2892643accb28d38d3
--- /dev/null
+++ b/vllm/kvprune_legacy_save/kv_cache/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Paged KV cache helpers and Triton KV store."""
+
+from vllm.kvprune.kv_cache.store_kv_cache import (
+ decode_store_kv,
+ prefill_store_all_kv,
+ prefill_store_topk_kv,
+)
+
+__all__ = [
+ "decode_store_kv",
+ "prefill_store_all_kv",
+ "prefill_store_topk_kv",
+]
diff --git a/vllm/kvprune_legacy_save/kv_cache/page_table.py b/vllm/kvprune_legacy_save/kv_cache/page_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..aed26907feb316e4b3b02618181c41cf18b00835
--- /dev/null
+++ b/vllm/kvprune_legacy_save/kv_cache/page_table.py
@@ -0,0 +1,313 @@
+import heapq
+import logging
+from enum import Enum, auto
+from typing import List, Optional, Union
+
+import torch
+from vllm.kvprune.config.constants import RESERVED_BATCH
+from vllm.kvprune.kv_cache.write_page_table import scatter_to_page_table
+
+logger = logging.getLogger(__name__)
+
+
+def cdiv(a, b):
+    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for ``a >= 0, b > 0``.
+
+    Works on anything supporting ``+``, ``-`` and ``//`` (plain ints as well as
+    integer tensors). For negative ``a`` the result follows floor division and
+    can be negative.
+    """
+    padded = a + b - 1
+    return padded // b
+
+
+def next_multiple(a, b):
+    """Return the smallest multiple of ``b`` that is ``>= a`` (for ``a >= 0, b > 0``)."""
+    groups = (a + b - 1) // b  # ceiling of a / b
+    return groups * b
+
+
+class KVAllocationStatus(Enum):
+    """Outcome of a KV-cache allocation request (see ``PagedKVCache.reserve_tokens``)."""
+
+    # Some (layer, head) would need more pages than ``max_pages_per_head`` allows.
+    EXCEEDS_MAX_SEQUENCE_LENGTH = 1
+    # At least one layer does not currently have enough free physical pages.
+    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = 2
+    # NOTE(review): not returned by the page-table code in this file; presumably
+    # signals that no batch slot is free — confirm against callers.
+    EXCEEDS_MAX_NUM_BATCHES = 3
+    # The requested capacity is available.
+    SUCCESS = 4
+
+
+class PagedKVCache(torch.nn.Module):
+    """
+    Global paged KV cache.
+    This module manages:
+      * A global K/V backing buffer for all layers:
+        ``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
+        where the first dimension indexes K vs V.
+      * A per-layer page table:
+        ``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
+        mapping logical (batch, kv-head, logical_page) to a physical page ID
+        in the global K/V buffer.
+      * Per-layer, per-(batch, kv-head) logical sequence lengths
+        ``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
+        the number of allocated pages ``bh_num_pages`` for each (layer, batch,
+        head).
+      * A page allocator implemented as a min-heap of free physical pages
+        per layer, plus free batch indices.
+    Pages are of fixed size ``page_size`` tokens.
+    Args:
+        :param num_layers:
+            Number of transformer layers that will use this cache.
+        :param max_logical_pages_per_head:
+            Maximum number of logical pages that can be assigned to a single
+            (batch, kv-head) pair.
+        :param num_pages:
+            Total number of physical pages available in the global cache per
+            layer. The global K/V buffers are of length
+            ``num_pages * page_size`` along the token dimension.
+        :param page_size:
+            Number of tokens stored per page.
+        :param H_kv:
+            Number of KV heads per layer.
+        :param head_dim:
+            Head dimension for K/V.
+        :param max_num_batches:
+            Maximum number of concurrent batches / sequences supported. One
+            additional slot is allocated internally because one batch index
+            (``RESERVED_BATCH``) is reserved and never handed out.
+        :param dtype:
+            Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
+        :param device:
+            Device on which to allocate the cache (string, torch.device, or
+            int; defaults to ``"cuda"``).
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        max_logical_pages_per_head: int,
+        num_pages: int,
+        page_size: int,  # tokens per page
+        H_kv: int,
+        head_dim: int,
+        max_num_batches: int,
+        dtype: torch.dtype,
+        device: Union[str, torch.device, int] = "cuda",
+    ):
+        super().__init__()
+        self.n_pages = num_pages
+        self.num_layers = num_layers
+        self.page_size: int = int(page_size)
+        self.H_kv = int(H_kv)
+        self.max_pages_per_head = max_logical_pages_per_head
+        # One extra row so that the reserved batch index does not reduce the
+        # number of usable slots.
+        max_num_batches += 1
+        self.max_num_batches = max_num_batches
+        self.head_dim = head_dim
+        cache_shape = (2, num_layers, num_pages * page_size, head_dim)
+        self.kv_cache = torch.empty(cache_shape, dtype=dtype, device=device)
+
+        # Contents are uninitialized; only entries below ``bh_num_pages`` are
+        # meaningful for a given (layer, batch, head).
+        self.page_table = torch.empty(
+            (num_layers, max_num_batches, H_kv, self.max_pages_per_head),
+            device=device,
+            dtype=torch.int32,
+        )
+
+        # Per-(batch, head) logical seq length (tokens)
+        self.bh_seq_lens = torch.zeros(
+            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
+        )
+        # Per-(batch, head) number of allocated pages
+        self.bh_num_pages = torch.zeros(
+            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
+        )
+
+        # Page allocator (min-heap of free physical pages), one heap per layer
+        self.free_pages: List[List[int]] = [
+            list(range(num_pages)) for _ in range(num_layers)
+        ]
+        for free_pages in self.free_pages:
+            heapq.heapify(free_pages)
+        # Free batch slots, popped from the end (lowest index first); the
+        # reserved index is never handed out.
+        self.free_batches: List[int] = list(reversed(range(max_num_batches)))
+        self.free_batches.remove(RESERVED_BATCH)
+        # Record of physical page ids owned by a batch (for freeing),
+        # indexed as [batch][layer]
+        self.pages_indices_per_batch: List[List[set[int]]] = [
+            [set() for _ in range(num_layers)] for _ in range(max_num_batches)
+        ]
+
+    def new_batch(self) -> Optional[int]:
+        """
+        Reserve a new batch slot.
+        A batch slot corresponds to a row in ``bh_seq_lens`` /
+        ``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
+        heads. This method checks whether a free batch index is available, and
+        whether each layer has at least ``H_kv`` free pages remaining.
+        If both checks pass, it returns a batch index and removes it from
+        ``free_batches``. Otherwise, it returns ``None``.
+
+        Returns:
+            :return Optional[int]:
+                Newly reserved batch index, or ``None`` if no capacity is
+                available.
+        """
+        if self.free_batches and all(self.H_kv <= len(fp) for fp in self.free_pages):
+            return self.free_batches.pop()
+        return None
+
+    def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
+        """
+        Ensure enough pages are allocated to handle ``add_tokens`` new tokens.
+        Args:
+            :param batch_index:
+                Batch index to reserve space for.
+            :param add_tokens:
+                Number of additional tokens to reserve capacity for.
+                All heads in this batch and all layers reserve
+                the same number of extra tokens.
+        Returns:
+            :return KVAllocationStatus:
+                ``SUCCESS`` when capacity is (now) available;
+                ``EXCEEDS_MAX_SEQUENCE_LENGTH`` when some (layer, head) would
+                need more than ``max_pages_per_head`` pages;
+                ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` when a layer does not
+                have enough free pages. Nothing is allocated on failure.
+        """
+        cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
+        curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
+        curr_cap_tokens = curr_pages * self.page_size  # [L, H]
+        need_tokens = cur_bh_lens + add_tokens  # [L, H]
+        if (need_tokens <= curr_cap_tokens).all():
+            return KVAllocationStatus.SUCCESS
+        # Heads may hold different numbers of tokens, so some heads can already
+        # have spare capacity while others need more. Clamp the deficit at zero:
+        # a surplus of a full page or more would otherwise become a *negative*
+        # page count under floor division, corrupting the per-layer totals and
+        # shrinking ``bh_num_pages`` for that head.
+        missing_tokens = (need_tokens - curr_cap_tokens).clamp(min=0)
+        add_pages = cdiv(missing_tokens, self.page_size)
+        new_total_pages = curr_pages + add_pages
+        if (new_total_pages > self.max_pages_per_head).any():
+            return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH
+        # CPU work
+        pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
+        # Check every layer first so that a failure leaves the allocator untouched.
+        if any(
+            needed > len(free)
+            for needed, free in zip(pages_per_layer_cpu, self.free_pages)
+        ):
+            return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES
+        new_phys_pages = []
+        for layer_index, needed in enumerate(pages_per_layer_cpu):
+            layer_heap = self.free_pages[layer_index]
+            this_layer_pages = [heapq.heappop(layer_heap) for _ in range(needed)]
+            self.pages_indices_per_batch[batch_index][layer_index].update(
+                this_layer_pages
+            )
+            new_phys_pages.extend(this_layer_pages)
+
+        # Allocate on the cache's own device rather than a hard-coded "cuda",
+        # so a cache placed on a specific GPU index keeps everything together.
+        new_phys_pages = torch.tensor(
+            new_phys_pages, dtype=torch.int32, device=self.page_table.device
+        )
+
+        # Writes the new physical page ids into this batch's page-table rows
+        # (helper defined in ``write_page_table``).
+        scatter_to_page_table(
+            add_pages=add_pages,
+            new_phys_pages=new_phys_pages,
+            curr_pages=curr_pages,
+            page_table=self.page_table[:, batch_index],
+            max_pages_per_head=self.max_pages_per_head,
+        )
+
+        self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
+            self.bh_num_pages.dtype
+        )
+        return KVAllocationStatus.SUCCESS
+
+    def reclaim_pages(
+        self,
+        batch_index: int,
+        future_reserve_tokens: int = 0,
+    ):
+        """
+        Reclaim unused pages for a single batch index. This shrinks the KV
+        allocation for the batch down to the minimum number of pages needed
+        to hold the current (plus optional future) sequence length.
+
+        Args:
+            :param batch_index:
+                Batch index whose pages should be compacted.
+            :param future_reserve_tokens:
+                Optional number of extra tokens to keep capacity for, beyond
+                the current sequence length. This can reduce churn when
+                sequences are expected to grow slightly in the near future.
+
+        Returns:
+            :return int:
+                Approximate number of bytes freed across both K and V
+                (``0`` when nothing could be reclaimed).
+        """
+        device = self.bh_seq_lens.device
+        L, B, H = self.bh_seq_lens.shape
+        assert 0 <= batch_index < B
+
+        seq = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens  # [L, H]
+        alloc = self.bh_num_pages[:, batch_index, :]  # [L, H]
+        pt = self.page_table[:, batch_index, :, :].reshape(-1)  # [L * H * P], flattened
+
+        # Compute used pages: ceil_div(seq, page_size), clamped into [0, alloc]
+        used_pages = cdiv(seq, self.page_size)
+        used_pages = torch.minimum(used_pages, alloc)
+
+        # page indices [0..P-1], broadcasted over [L, H, P]
+        p = torch.arange(
+            self.max_pages_per_head, device=device, dtype=torch.int32
+        ).view(1, 1, self.max_pages_per_head)
+
+        # allocated: p < alloc
+        alloc_mask = p < alloc.unsqueeze(-1)  # [L, H, P]
+        # to free: allocated and p in [used_pages, alloc)
+        free_mask = alloc_mask & (p >= used_pages.unsqueeze(-1))
+        free_mask_flat = free_mask.view(-1)  # [L*H*P]
+        if not free_mask_flat.any():
+            return 0
+
+        idx = free_mask_flat.nonzero(as_tuple=False).squeeze(
+            -1
+        )  # indices of freed slots
+
+        # Freed physical page ids
+        freed_pages = pt[idx]
+        # Compute layer index for each freed slot:
+        # layout is [L, H, P] → flat index = ((l * H) + h) * P + p
+        freed_layers = (idx // (H * self.max_pages_per_head)).to(torch.int32)
+        freed_pages = freed_pages.tolist()
+        layer_mapping = freed_layers.tolist()
+        self.bh_num_pages[:, batch_index, :] = used_pages
+        # Hand each freed page back to its layer's heap and drop it from the
+        # batch's ownership record.
+        for page, layer in zip(freed_pages, layer_mapping):
+            self.pages_indices_per_batch[batch_index][layer].remove(page)
+            heapq.heappush(self.free_pages[layer], page)
+        approximate_bytes_freed = (
+            len(freed_pages)
+            * (self.page_size * self.head_dim * self.kv_cache.element_size())
+            * 2
+        )  # multiply for two for K + V
+        return approximate_bytes_freed
+
+    def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
+        """
+        Return every page owned by ``batch_index`` in ``layer_index`` to that
+        layer's free heap and clear the ownership record.
+        """
+        # Return pages to the global heap
+        for phys in self.pages_indices_per_batch[batch_index][layer_index]:
+            heapq.heappush(self.free_pages[layer_index], int(phys))
+
+        self.pages_indices_per_batch[batch_index][layer_index] = set()
+
+    def free_batch(self, batch_index: int) -> None:
+        """
+        Free all resources associated with a batch index: its pages in every
+        layer are returned to the allocator, its sequence lengths and page
+        counts are zeroed, and the index becomes available again.
+        Args:
+            :param batch_index:
+                Batch index to release. Must have been previously allocated
+                via :meth:`new_batch`.
+        """
+        for layer in range(self.num_layers):
+            self._free_batch_layer(layer, batch_index)
+        self.bh_seq_lens[:, batch_index].zero_()
+        self.bh_num_pages[:, batch_index].zero_()
+        self.free_batches.append(batch_index)
+
+    def layer_slices(self, layer: int):
+        """
+        Return layer-local views needed by the attention module.
+
+        For a given ``layer`` index, this method returns the slices of the
+        global K/V cache, page table, and per-(batch, head) sequence lengths
+        corresponding to that layer.
+        Args:
+            :param layer:
+                Layer index ``l`` in ``[0, num_layers)``.
+
+        Returns:
+            :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+                ``(k, v, pt, bh)`` where ``k = kv_cache[0, layer]``,
+                ``v = kv_cache[1, layer]``, ``pt = page_table[layer]`` and
+                ``bh = bh_seq_lens[layer]``.
+        """
+        assert 0 <= layer < self.num_layers
+        k = self.kv_cache[0, layer]
+        v = self.kv_cache[1, layer]
+        pt = self.page_table[layer]
+        bh = self.bh_seq_lens[layer]
+        return k, v, pt, bh
diff --git a/vllm/kvprune_legacy_save/kv_cache/store_kv_cache.py b/vllm/kvprune_legacy_save/kv_cache/store_kv_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d07c00c98d7a203d8415b423fee3bc62ef0a72
--- /dev/null
+++ b/vllm/kvprune_legacy_save/kv_cache/store_kv_cache.py
@@ -0,0 +1,468 @@
+import torch
+import triton
+import triton.language as tl
+from vllm.kvprune.config.constants import (
+ TRITON_RESERVED_BATCH as _TRITON_RESERVED_BATCH,
+)
+
+
+# Copy the retained (token, head) K/V vectors of a prefill into the paged cache.
+# Grid: axis 0 = local batch row, axis 1 = tile of K_TILE ranked selections.
+# Each selection atomically reserves the next slot of its (batch, head) length
+# counter, so the order in which selections land inside a head is not fixed.
+@triton.jit
+def _prefill_store_topk_kv_kernel(
+ key,
+ value, # [N_total, H, D] (D stride assumed 1)
+ batch_mapping, # [B] int32 (local b -> true batch)
+ num_tokens_to_retain, # [B] int32
+ indices_topk, # [B, MAX_SEL] int32 (across all heads)
+ # Lengths & page table:
+ bh_lens, # [B, H] int32 (contiguous)
+ page_table, # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
+ k_cache,
+ v_cache, # [N_PAGES * PAGE_SIZE, D]
+ sk_n,
+ sk_h, # strides for key,value. D stride assumed 1
+ sv_n,
+ sv_h,
+ # Runtime ints
+ MAX_SEL, # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
+ HKV: tl.constexpr,
+ N_LOGICAL_PAGES_MAX: tl.constexpr,
+ D: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ K_TILE: tl.constexpr, # how many selected tokens each program processes
+ TRITON_RESERVED_BATCH: tl.constexpr,
+):
+ b_local = tl.program_id(0)
+ tile_id = tl.program_id(1)
+ offs = tl.arange(0, D)
+ # how many tokens we actually keep for this batch
+ k_total = tl.load(num_tokens_to_retain + b_local)
+ if k_total == 0:
+ return
+ # map to true batch row in the page table
+ b_true = tl.load(batch_mapping + b_local)
+ # rows mapped to the reserved batch id are skipped entirely
+ if b_true == TRITON_RESERVED_BATCH:
+ return
+ base = tile_id * K_TILE
+ # process up to K_TILE tokens
+ for j in tl.range(0, K_TILE):
+ sel_idx = base + j
+ # only the first k_total ranked entries (bounded by the row width) are stored
+ if sel_idx < k_total and sel_idx < MAX_SEL:
+ # flattened selection: sel = token * H + head
+ sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
+ tok = sel // HKV
+ head = sel - (tok * HKV)
+ # atomically reserve one position in (b_local, head)
+ # i.e the KV cache is scrambled when storing
+ len_ptr = bh_lens + b_local * HKV + head
+ pos = tl.atomic_add(len_ptr, 1) # old length (int32)
+ lp = pos // PAGE_SIZE
+ off = pos - lp * PAGE_SIZE
+ # translate logical page to physical page
+ # (lengths are indexed by the local row, the page table by the true row)
+ pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
+ phys = tl.load(page_table + pt_base + lp).to(tl.int64)
+ # destination row and element offset
+ dst_row = phys * PAGE_SIZE + off
+ dst_off = dst_row * D + offs
+ # load one vector from [N_total, H, D]
+ k_src = key + tok * sk_n + head * sk_h + offs
+ v_src = value + tok * sv_n + head * sv_h + offs
+ tl.store(
+ k_cache + dst_off,
+ tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
+ eviction_policy="evict_first",
+ )
+ tl.store(
+ v_cache + dst_off,
+ tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
+ eviction_policy="evict_first",
+ )
+
+
+def prefill_store_topk_kv(
+ *,
+ new_keys: torch.Tensor, # [N_total, H, D]
+ new_vals: torch.Tensor, # [N_total, H, D]
+ indices_topk: torch.Tensor, # [B, MAX_SEL] int32 (global flattened token*H + head)
+ num_tokens_to_retain: torch.Tensor, # [B] int32
+ page_table: torch.Tensor, # [B_total, H, N_LOGICAL_PAGES_MAX] int32
+ batch_mapping: torch.Tensor, # [B] int32 (local -> true batch rows)
+ bh_lens: torch.Tensor, # [B, H] int32 (contiguous), UPDATED atomically
+ k_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
+ v_cache: torch.Tensor, # [N_PAGES * PAGE_SIZE, D]
+ PAGE_SIZE: int,
+ PAD_TO_PAGE_SIZE: bool = True,
+ cu_seqlens_k: torch.Tensor | None = None,
+ K_TILE: int = 16,
+ TRITON_RESERVED_BATCH: int = None,
+):
+ assert new_keys.shape == new_vals.shape
+ N_total, H, D = new_keys.shape
+ B = indices_topk.shape[0]
+ assert page_table.shape[1] == H
+ assert bh_lens.shape == (B, H)
+ assert new_keys.device == k_cache.device == v_cache.device
+ assert page_table.is_contiguous(), "page table must be contiguous."
+ assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
+ assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
+ assert k_cache.is_contiguous() and v_cache.is_contiguous()
+ assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
+ "new_keys/new_vals last dim must be contiguous."
+ )
+ assert (D & (D - 1)) == 0, "D must be a power of 2"
+ page_table = page_table.to(torch.int32)
+ bh_lens = bh_lens.to(torch.int32)
+ batch_mapping = batch_mapping.to(torch.int32)
+ indices_topk = indices_topk.to(torch.int32)
+ num_tokens_to_retain = num_tokens_to_retain.to(torch.int32)
+
+ # strides (elements) for [N_total, H, D]
+ sk_n, sk_h, _ = new_keys.stride()
+ sv_n, sv_h, _ = new_vals.stride()
+
+ # tile second grid dim
+ MAX_SEL = indices_topk.shape[-1]
+ N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
+ grid = (B, max(1, N_TILES))
+ if TRITON_RESERVED_BATCH is None:
+ TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH
+ _prefill_store_topk_kv_kernel[grid](
+ key=new_keys,
+ value=new_vals,
+ batch_mapping=batch_mapping,
+ num_tokens_to_retain=num_tokens_to_retain,
+ indices_topk=indices_topk,
+ bh_lens=bh_lens,
+ page_table=page_table,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ sk_n=sk_n,
+ sk_h=sk_h,
+ sv_n=sv_n,
+ sv_h=sv_h,
+ MAX_SEL=int(MAX_SEL),
+ HKV=H,
+ N_LOGICAL_PAGES_MAX=page_table.shape[2],
+ D=D,
+ PAGE_SIZE=PAGE_SIZE,
+ K_TILE=K_TILE,
+ TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
+ )
+ if PAD_TO_PAGE_SIZE:
+ assert cu_seqlens_k is not None
+ assert indices_topk.is_contiguous()
+ assert page_table.is_contiguous()
+ _prefill_store_topk_pad_kernel[(B, H)](
+ key=new_keys,
+ value=new_vals,
+ batch_mapping=batch_mapping,
+ num_tokens_to_retain=num_tokens_to_retain,
+ indices=indices_topk,
+ local_lens=bh_lens,
+ page_table_flat=page_table,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ cu_seqlens_k=cu_seqlens_k,
+ sk_n=sk_n,
+ sk_h=sk_h,
+ sv_n=sv_n,
+ sv_h=sv_h,
+ MAX_SEL=int(MAX_SEL),
+ H=H, # type: ignore
+ N_LOGICAL_PAGES_MAX=page_table.shape[2], # type: ignore
+ D=D, # type: ignore
+ PAGE_SIZE=PAGE_SIZE, # type: ignore
+ TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
+ )
+
+
+# Top up each (local batch, head) to a page boundary after the top-k store.
+# Grid: axis 0 = local batch row, axis 1 = head. Starting at rank
+# num_tokens_to_retain[b], the ranked candidates are scanned in order and those
+# belonging to this head are appended until the page is full, the candidate
+# list is exhausted, or (context length - current length) tokens were added.
+@triton.jit
+def _prefill_store_topk_pad_kernel(
+ key, # [N_total, H, D]
+ value, # [N_total, H, D]
+ batch_mapping, # [B] int32 (local b -> true batch)
+ num_tokens_to_retain, # [B] int32
+ indices, # [B, MAX_SEL] int32 (across all heads)
+ local_lens, # [B, H] int32 (contiguous)
+ page_table_flat, # [B_total*H*N_LOGICAL_PAGES_MAX] int32
+ k_cache,
+ v_cache, # [N_PAGES*PAGE_SIZE, D]
+ cu_seqlens_k,
+ sk_n,
+ sk_h,
+ sv_n,
+ sv_h,
+ MAX_SEL,
+ # Constexprs
+ H: tl.constexpr, # number of KV heads
+ N_LOGICAL_PAGES_MAX: tl.constexpr,
+ D: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ TRITON_RESERVED_BATCH: tl.constexpr,
+):
+ b_local = tl.program_id(0)
+ h = tl.program_id(1)
+ offs_d = tl.arange(0, D)
+ L = tl.load(local_lens + b_local * H + h)
+ # L mod PAGE_SIZE: zero means the last page is already full, nothing to pad
+ modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
+ if modulo_page_size == 0:
+ return
+ # free slots remaining in the last, partially filled page
+ need = PAGE_SIZE - modulo_page_size
+ b_true = tl.load(batch_mapping + b_local)
+ if b_true == TRITON_RESERVED_BATCH:
+ return
+ pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
+ written_tokens = 0
+ # scanning starts right after the entries the top-k kernel consumed
+ idx = tl.load(num_tokens_to_retain + b_local)
+ this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
+ cu_seqlens_k + b_local
+ )
+ # upper bound on extra tokens: context length minus this head's length
+ max_additional = this_batch_ctx_len - L
+ while (written_tokens < need and idx < MAX_SEL) and (
+ written_tokens < max_additional
+ ):
+ # candidate head
+ cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
+ cand_h = cand_idx % H
+ # candidates of other heads are skipped; idx still advances below
+ if cand_h == h:
+ tok = cand_idx // H
+ pos = L + written_tokens
+ lp = pos // PAGE_SIZE
+ off = pos - lp * PAGE_SIZE
+ phys = tl.load(page_table_flat + pt_base + lp).to(tl.int32)
+
+ dst_row = phys * PAGE_SIZE + off
+ # widen before scaling by D to get the element offset
+ dst_off = dst_row.to(tl.int64) * D + offs_d
+
+ k_src = key + tok * sk_n + h * sk_h + offs_d
+ v_src = value + tok * sv_n + h * sv_h + offs_d
+
+ tl.store(
+ k_cache + dst_off,
+ tl.load(k_src),
+ )
+ tl.store(
+ v_cache + dst_off,
+ tl.load(v_src),
+ )
+
+ written_tokens += 1
+ idx += 1
+ # publish the new length for this (batch, head)
+ tl.store(local_lens + b_local * H + h, L + written_tokens)
+
+
+# Append every new (token, head) K/V vector of a prefill to the paged cache.
+# Grid: axis 0 = local batch row, axis 1 = tile of K_TILE (token, head) pairs.
+# Token t of a sequence is written at cache position bh_lens[b, h] + t.
+# This kernel only reads bh_lens; it never writes the lengths back.
+@triton.jit
+def _prefill_store_all_kv_kernel(
+ key,
+ value, # [N, H, D] (D contiguous)
+ cu_seqlens_k, # [B + 1] int32
+ batch_mapping, # [B] int32 (local b -> true batch index)
+ bh_lens, # [B * HKV] int32 (read here as the base length; NOTE(review): the length update is presumably done by the caller - confirm)
+ pt_flat, # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
+ k_cache,
+ v_cache, # [N_PAGES * PAGE_SIZE, D]
+ # source strides (elements)
+ sk_n,
+ sk_h,
+ sv_n,
+ sv_h,
+ # constexpr
+ HKV: tl.constexpr,
+ N_LOGICAL_PAGES_MAX: tl.constexpr,
+ D: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ K_TILE: tl.constexpr, # number of (token, head) pairs processed per program
+):
+ pid_b = tl.program_id(0)
+ pid_blk = tl.program_id(1)
+
+ # token range [start, end) of this sequence inside the packed key/value
+ start = tl.load(cu_seqlens_k + pid_b)
+ end = tl.load(cu_seqlens_k + pid_b + 1)
+ num_toks_this_batch = end - start
+ if num_toks_this_batch <= 0:
+ return
+
+ total_elems = num_toks_this_batch * HKV
+
+ # base linear index in (token, head) grid for this program
+ base = pid_blk * K_TILE
+
+ offs_d = tl.arange(0, D)
+
+ # Iterate K_TILE elements in this tile
+ for i in tl.range(0, K_TILE):
+ idx = base + i
+ # tiles past the end of this sequence do nothing
+ if idx < total_elems:
+ # map linear idx -> (t, h)
+ t = idx // HKV
+ h = idx - t * HKV
+
+ len_idx = pid_b * HKV + h
+ L0 = tl.load(bh_lens + len_idx)
+
+ token_idx_in_cache = L0 + t
+ lp = token_idx_in_cache // PAGE_SIZE # logical page
+ off_in_pg = token_idx_in_cache - lp * PAGE_SIZE # pos in page
+
+ # physical page
+ # NOTE(review): unlike the sibling store kernels there is no
+ # reserved-batch check on b_true here - confirm this is intended.
+ b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
+ pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
+ phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)
+
+ row = phys * PAGE_SIZE + off_in_pg
+ dst_off = row * D + offs_d
+
+ n_global = (start + t).to(tl.int64)
+
+ # Use strides for non-contiguous [N, H, D] (D stride == 1)
+ k_src = key + n_global * sk_n + h * sk_h + offs_d
+ v_src = value + n_global * sv_n + h * sv_h + offs_d
+
+ tl.store(k_cache + dst_off, tl.load(k_src))
+ tl.store(v_cache + dst_off, tl.load(v_src))
+
+
+def prefill_store_all_kv(
+ *,
+ new_keys: torch.Tensor,
+ new_values: torch.Tensor, # [N, H_kv, D]
+ cu_seqlens_k: torch.Tensor, # [B + 1] int32
+ max_seqlen_k: int,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ page_table: torch.Tensor, # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
+ bh_lens: torch.Tensor, # [B, H_kv] int32 (UPDATED)
+ batch_mapping: torch.Tensor, # [B] int32 (local->true)
+ PAGE_SIZE: int,
+ K_TILE: int = 32, # how many (token, head) pairs per program
+):
+ assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
+ "last dim must be contiguous"
+ )
+ assert page_table.is_contiguous(), "page table must be contiguous"
+ assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
+ assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
+ assert k_cache.is_contiguous() and v_cache.is_contiguous()
+
+ N, HKV, D = new_keys.shape
+ B = batch_mapping.shape[0]
+ assert (D & (D - 1)) == 0, "D must be a power of 2"
+
+ sk_n, sk_h, _ = new_keys.stride()
+ sv_n, sv_h, _ = new_values.stride()
+ n_tiles = (max_seqlen_k * HKV + K_TILE - 1) // K_TILE
+ grid = (B, n_tiles)
+ _prefill_store_all_kv_kernel[grid](
+ new_keys,
+ new_values,
+ cu_seqlens_k,
+ batch_mapping,
+ bh_lens,
+ page_table,
+ k_cache,
+ v_cache,
+ sk_n=sk_n,
+ sk_h=sk_h,
+ sv_n=sv_n,
+ sv_h=sv_h,
+ HKV=HKV,
+ N_LOGICAL_PAGES_MAX=page_table.shape[-1],
+ D=D,
+ PAGE_SIZE=PAGE_SIZE,
+ K_TILE=K_TILE,
+ )
+ bh_lens += cu_seqlens_k.diff()[:, None]
+
+
+# Append one decode-step K/V vector per (batch row, head) to the paged cache.
+# Grid: axis 0 = local batch row, axis 1 = head. The vector is written at the
+# current length of that (row, head) and the length is then incremented by one.
+@triton.jit
+def _decode_store_kv_kernel(
+ key,
+ value,
+ batch_mapping, # [B] int32
+ bh_lens, # [B*HKV] int32
+ page_table, # [B_total*HKV*N_LOGICAL_PAGES_MAX]
+ k_cache,
+ v_cache, # [N_PAGES*PAGE_SIZE, D]
+ sk_b,
+ sk_h,
+ sv_b,
+ sv_h,
+ HKV: tl.constexpr,
+ N_LOGICAL_PAGES_MAX: tl.constexpr,
+ D: tl.constexpr,
+ PAGE_SIZE: tl.constexpr,
+ TRITON_RESERVED_BATCH: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ h = tl.program_id(1)
+ mapped_b = tl.load(batch_mapping + pid_b)
+ # rows mapped to the reserved batch id are skipped: nothing stored, length untouched
+ if mapped_b == TRITON_RESERVED_BATCH:
+ return
+ offs_d = tl.arange(0, D)
+
+ # the current length is the insertion position of the new token
+ length = tl.load(bh_lens + pid_b * HKV + h)
+ logical_page = length // PAGE_SIZE
+ internal_offset = length - logical_page * PAGE_SIZE
+
+ # lengths are indexed by the local row, the page table by the mapped row
+ pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
+ physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)
+
+ dst_row = physical_page * PAGE_SIZE + internal_offset
+
+ # Source addressing using strides (D stride == 1)
+ k_src = key + pid_b * sk_b + h * sk_h + offs_d
+ v_src = value + pid_b * sv_b + h * sv_h + offs_d
+
+ dst_off = dst_row * D + offs_d
+ tl.store(k_cache + dst_off, tl.load(k_src))
+ tl.store(v_cache + dst_off, tl.load(v_src))
+ # each (row, head) is owned by exactly one program, so a plain store is used
+ tl.store(bh_lens + pid_b * HKV + h, length + 1)
+
+
+def decode_store_kv(
+ *,
+ key: torch.Tensor, # [B, HKV, D]
+ value: torch.Tensor, # [B, HKV, D]
+ batch_mapping: torch.Tensor, # [B] int32
+ bh_lens: torch.Tensor, # [B, HKV] or flattened [B*HKV] int32
+ page_table: torch.Tensor, # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor, # [N_PAGES*PAGE_SIZE, D]
+ PAGE_SIZE: int,
+ TRITON_RESERVED_BATCH: int = None,
+):
+ assert key.shape == value.shape and key.ndim == 3, "key/value must be [B, HKV, D]"
+ B, HKV, D = key.shape
+ assert key.stride(-1) == 1 and value.stride(-1) == 1, (
+ "key/value last dim must be contiguous."
+ )
+ assert page_table.is_contiguous(), "page table must be contiguous."
+ assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
+ assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
+ assert k_cache.is_contiguous() and v_cache.is_contiguous()
+ assert (D & (D - 1)) == 0, "D must be a power of 2"
+ sk_b, sk_h, _ = key.stride()
+ sv_b, sv_h, _ = value.stride()
+ grid = (
+ int(batch_mapping.shape[0]),
+ HKV,
+ )
+ _decode_store_kv_kernel[grid](
+ key=key,
+ value=value,
+ batch_mapping=batch_mapping,
+ bh_lens=bh_lens,
+ page_table=page_table,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ sk_b=sk_b,
+ sk_h=sk_h,
+ sv_b=sv_b,
+ sv_h=sv_h,
+ HKV=HKV,
+ N_LOGICAL_PAGES_MAX=page_table.shape[2],
+ D=D,
+ PAGE_SIZE=PAGE_SIZE,
+ TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH
+ if TRITON_RESERVED_BATCH is not None
+ else _TRITON_RESERVED_BATCH,
+ )
diff --git a/vllm/kvprune_legacy_save/kv_cache/write_page_table.py b/vllm/kvprune_legacy_save/kv_cache/write_page_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99c4e1f566af65c4586c47c727ae671f9c801d7
--- /dev/null
+++ b/vllm/kvprune_legacy_save/kv_cache/write_page_table.py
@@ -0,0 +1,110 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def scatter_to_page_table(
+ add_pages: torch.Tensor, # [L, H] int32
+ new_phys_pages: torch.Tensor, # [N]
+ curr_pages: torch.Tensor, # [L, H] int32
+ page_table: torch.Tensor, # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
+ max_pages_per_head: int,
+):
+ """
+ Append newly allocated physical pages into a layered page table via Triton.
+ For each (layer ``l``, head ``h``):
+ Args:
+ :param add_pages:
+ Tensor of shape ``[L, H]`` (int32) indicating how many pages to
+ append for each (layer, head).
+ :param new_phys_pages:
+ 1D tensor of shape ``[N]`` (int32) containing physical page IDs
+ for all (layer, head) pairs, concatenated in row-major (L, H)
+ order. ``N`` must equal ``add_pages.sum()``.
+ :param curr_pages:
+ Tensor of shape ``[L, H]`` (int32) with the current logical page
+ counts per (layer, head) before this update.
+ :param page_table:
+ Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
+ the logical to physical page mapping. The last dimension is
+ logically indexed as logical_page ∈ [0, max_pages_per_head).
+ :param max_pages_per_head:
+ Maximum number of logical pages permitted per (layer, head). The
+ kernel skips writes beyond this bound.
+ Returns:
+ None. The function updates ``page_table`` in-place.
+ """
+ L, H = add_pages.shape
+ if L == 0 or H == 0:
+ return
+ add_flat = add_pages.to(torch.int32).contiguous().view(-1)
+ curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)
+ cum_page_heads = torch.empty(L * H + 1, device="cuda", dtype=torch.int32)
+ cum_page_heads[0] = 0
+ torch.cumsum(add_flat, 0, out=cum_page_heads[1:])
+ stride_pl, stride_ph, stride_pp = page_table.stride()
+ grid = (L, H)
+ _scatter_pages_kernel_lh[grid](
+ add_flat,
+ cum_page_heads,
+ new_phys_pages,
+ curr_flat,
+ page_table,
+ stride_pl,
+ stride_ph,
+ stride_pp,
+ L=L,
+ H=H,
+ max_pages_per_head=max_pages_per_head,
+ )
+
+
+# Write newly allocated physical page ids into the page table.
+# Grid: axis 0 = layer, axis 1 = head; one program appends all pages of its
+# (layer, head) pair sequentially.
+@triton.jit
+def _scatter_pages_kernel_lh(
+ add_pages, # int32 [L*H]
+ cum_page_heads, # int32, indexed by lh: base offset in flat_new_phys per (l,h)
+ flat_new_phys, # int32 [total_pages]
+ curr_pages, # int32 [L*H], existing logical pages per (l,h)
+ page_table_ptr, # int32* base pointer to page_table
+ stride_pl, # int, stride for layer dim
+ stride_ph, # int, stride for head dim
+ stride_pp, # int, stride for page dim
+ L: tl.constexpr,
+ H: tl.constexpr,
+ max_pages_per_head: tl.constexpr,
+):
+ layer_idx = tl.program_id(0)
+ h = tl.program_id(1)
+ # guard against a grid larger than (L, H)
+ if layer_idx >= L or h >= H:
+ return
+
+ lh = layer_idx * H + h
+ ap = tl.load(add_pages + lh)
+ # nothing to append for this (layer, head)
+ if ap <= 0:
+ return
+
+ base = tl.load(cum_page_heads + lh)
+ cp = tl.load(curr_pages + lh)
+
+ # Append ap pages: logical pages [cp .. cp+ap)
+ for i in tl.range(0, ap):
+ phys = tl.load(flat_new_phys + base + i)
+ lp = cp + i
+ # pages that would land beyond the table width are silently dropped
+ if lp < max_pages_per_head:
+ offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
+ tl.store(page_table_ptr + offset, phys)
+
+
+# TODO: write reclaim kernel
+# Placeholder: the kernel takes no arguments and has an empty body.
+@triton.jit
+def reclaim_page_kernel():
+ pass
+
+
+def reclaim_pages(
+ batch_index: int,
+ bh_seq_lens: torch.Tensor,
+ bh_num_pages: torch.Tensor,
+ page_table: torch.Tensor,
+):
+ """Placeholder for page reclamation; not implemented.
+
+ The body is empty: none of the arguments are read or modified and the
+ function returns ``None``.
+ """
+ pass
diff --git a/vllm/kvprune_legacy_save/kvprune_to_vllm.md b/vllm/kvprune_legacy_save/kvprune_to_vllm.md
new file mode 100644
index 0000000000000000000000000000000000000000..e045d1a0bb4bd9d5e220ae02477c5081ee7d2c67
--- /dev/null
+++ b/vllm/kvprune_legacy_save/kvprune_to_vllm.md
@@ -0,0 +1,56 @@
+# KV-prune 与上游 vLLM 的集成说明
+
+本文说明:**剪枝/压缩(Compactor)功能**在「官网 vLLM 主仓库」里改动了哪些位置、是否只有少量文件、以及随 vLLM 版本升级时如何预期合并成本。
+
+## 1. 是否「仅仅」改了少数几个脚本?
+
+**核心运行时接线**确实集中在少数几个**非** `vllm/kvprune/` 下的文件;功能主体在 `vllm/kvprune/` 包内独立维护。
+
+| 路径 | 作用简述 |
+|------|-----------|
+| `vllm/env_override.py` | 在 `import vllm` 最早阶段设置与 kvprune 相关的默认环境变量(如 v1 多进程默认、压缩默认开关、可选释放 v1 KV 等)。 |
+| `vllm/__init__.py` | 对外导出 `CompressionParams`(懒加载至 `vllm.kvprune.integration.compression_params`)。 |
+| `vllm/entrypoints/llm.py` | `kvprune_compression` 参数、`generate(..., compression=...)`、v1 `enforce_eager` / `num_gpu_blocks_override` 策略、懒加载 compactor、委托 `compressed_generate`。 |
+| `vllm/v1/worker/gpu_worker.py` | `kvprune_v1_compressed_generate`:供 `collective_rpc` 调用的 TP 多卡压缩生成入口。 |
+| `tests/conftest.py` | 测试在导入 vLLM 前覆盖部分 `VLLM_KVPRUNE_*` 默认值,避免全量测试默认走压缩路径。 |
+| `vllm/envs.py` | 集中注册所有 `VLLM_KVPRUNE_*` 环境变量。 |
+
+**此外(可选/示例,非引擎必需):**
+
+- `examples/offline_inference/` 下若干 `*kvprune*` 示例脚本:演示用法,不参与核心引擎加载。
+
+**结论:**
+- **「官网 vLLM 主包」里与 kvprune 强相关的改动,主要就是上表中的 5 个主包文件 + 测试根配置 `tests/conftest.py`**(若把测试也算进「集成面」,共 6 处)。
+- **算法、Compactor、TP 内嵌 runner 等**均在 `vllm/kvprune/`(及该目录下的 `integration/`)中,与上游 diff 相对隔离。
+
+## 2. 随 vLLM 版本更新,是否「很容易」同步剪枝压缩功能?
+
+**相对容易的部分:**
+
+- **集成面小**:合并冲突主要出现在上述少数文件,而不是遍布整个 executor / attention / model 层。
+- **逻辑内聚**:大量代码在 `vllm/kvprune/`,可整体移植或 `git` 三方合并时以子树为主处理。
+
+**仍需人工跟进的点(不能假设「自动无痛」):**
+
+- **`entrypoints/llm.py` 属于高频变更文件**:上游每次大版本可能重构 `LLM` 构造参数、`generate` 签名或引擎初始化;需要**逐次解决冲突**并回归压缩路径。
+- **`v1/worker/gpu_worker.py`** 同样会随 executor / RPC 接口变动;`collective_rpc` 方法名或 worker 基类若有变化,需对齐。
+- **`env_override.py`** 若上游调整导入顺序或新增全局默认环境变量,需避免覆盖冲突或行为打架。
+- **vLLM v1 内部 API**(如 `worker.get_model()`、`vllm_config` 结构)若变更,`vllm/kvprune/integration/*` 也可能要跟着改——这类改动**不在**上表列出的 6 个文件里,但仍是**集成层**维护成本。
+
+**建议同步流程(简版):**
+
+1. 在新上游 tag 上先合并/应用 `vllm/kvprune/` 目录。
+2. 再手动合并上表中的 5 个主包文件 + `tests/conftest.py`。
+3. 跑与 kvprune 相关的测试与至少一条离线 `compression` 示例。
+4. 关注发行说明中 `LLM`、`EngineArgs`、`gpu_worker`、多进程默认的破坏性变更。
+
+## 3. 与「深度改内核」方案的区别
+
+当前设计**没有**在 `model_executor` 的统一注意力路径上大规模插入 kvprune 钩子(相关辅助逻辑主要在 `vllm/kvprune` 内部)。因此:
+
+- **上游同步时**,通常不必与 FlashAttention / 每层模型代码逐文件对打;
+- **代价是**:功能边界以「共享权重 + compactor 引擎 + 可选 TP RPC」为主,与「原生 KV 算子级一体化」的改动面不同。
+
+---
+
+*文档随仓库维护;若集成文件列表有增删,请同步更新第 1 节的表格以及文中引用的文件数量。*
diff --git a/vllm/kvprune_legacy_save/layers/__init__.py b/vllm/kvprune_legacy_save/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b10a0da49360a96886ab956e04e8977f0ebd842f
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Layers from upstream compactor (attention, linear, MoE, …).
+
+Prefer importing concrete modules, e.g. ``from vllm.kvprune.layers.attention import ...``.
+"""
+
+__all__: list[str] = []
diff --git a/vllm/kvprune_legacy_save/layers/activation.py b/vllm/kvprune_legacy_save/layers/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a19e488cf3f5d25670fcdc8f4a17161ca64e1010
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/activation.py
@@ -0,0 +1,13 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SiluAndMul(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ # @torch.compile
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, y = x.chunk(2, -1)
+ return F.silu(x) * y
diff --git a/vllm/kvprune_legacy_save/layers/attention.py b/vllm/kvprune_legacy_save/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d4368bf53efd41a430be8770c4320d63389e225
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/attention.py
@@ -0,0 +1,212 @@
+from typing import Optional
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from torch import nn
+
+from vllm.kvprune.attention.fa_paged_bridge import (
+ flash_decode_from_paged,
+ flash_prefill_from_paged,
+)
+from vllm.kvprune.attention.sparse_decode_kernel import head_sparse_decode_attention
+from vllm.kvprune.attention.sparse_varlen_kernel import (
+ causal_sparse_varlen_with_cache,
+)
+from vllm.kvprune.compression.common import extract_and_store_top_kv
+from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
+from vllm.kvprune.kv_cache.store_kv_cache import decode_store_kv, prefill_store_all_kv
+from vllm.kvprune.utils.context import Context, get_context
+from vllm.kvprune.utils.helpers import maybe_execute_in_stream
+
+
+class Attention(nn.Module):
+ """Attention layer that stores K/V in a paged cache and dispatches to a backend.
+
+ The cache tensors (``k_cache``, ``v_cache``, ``page_table``,
+ ``bh_seq_lens``, ``page_size``) start as ``None`` and are expected to be
+ assigned from outside this class before decoding. The backend is chosen
+ from ``context.attention_schedule`` on every call.
+ """
+
+ def __init__(
+ self,
+ num_heads,
+ head_dim,
+ scale,
+ num_kv_heads,
+ ):
+ super().__init__()
+ self.num_heads: int = num_heads
+ self.head_dim = head_dim
+ self.scale: float = scale
+ self.num_kv_heads = int(num_kv_heads)
+
+ # Cache state; not allocated here.
+ self.k_cache: Optional[torch.Tensor] = None
+ self.v_cache: Optional[torch.Tensor] = None
+ self.page_table: Optional[torch.Tensor] = None
+ self.bh_seq_lens: Optional[torch.Tensor] = None
+ self.page_size: Optional[int] = None
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ scores: Optional[torch.Tensor] = None,
+ ):
+ """Store the new K/V and return the attention output.
+
+ Prefill: K/V are stored either pruned (when a cache exists,
+ ``context.do_compression`` is set and ``scores`` is given) or in full
+ (when a cache exists), then attention runs through the Triton kernel,
+ the paged FlashAttention bridge, or plain ``flash_attn_varlen_func``.
+ Decode: the new token is stored and attention runs over the cache.
+ In both cases the gathered lengths are finally copied back into
+ ``self.bh_seq_lens`` at the rows given by ``context.batch_mapping``.
+ """
+ context: Context = get_context()
+ batch_mapping = context.batch_mapping
+ # Gathered copy of the length rows of the active batches; it is written
+ # back with index_copy_ at the end of this method.
+ seq_lens = (
+ None
+ if self.bh_seq_lens is None
+ else self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
+ )
+ sched = context.attention_schedule
+ use_triton_prefill_attn = (
+ sched == KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
+ )
+ use_fa_decode = sched == KvpruneAttentionSchedule.PDFA
+
+ if context.is_prefill:
+ # Snapshot taken before the store calls below; the paged prefill
+ # backends receive it as the lengths "before" this step.
+ # NOTE(review): presumably the store helpers advance seq_lens in place - confirm.
+ seq_lens_copy = seq_lens.clone() if seq_lens is not None else None
+ if (
+ self.k_cache is not None
+ and context.do_compression
+ and scores is not None
+ ):
+ compression_context = context.compression_context
+ assert scores is not None
+ assert compression_context is not None
+ maybe_execute_in_stream(
+ extract_and_store_top_kv,
+ scores=scores,
+ cu_seqlens_k=context.cu_seqlens_k,
+ max_k_len=context.max_seqlen_k,
+ top_k=compression_context.max_tokens_to_retain,
+ H=int(self.num_kv_heads),
+ new_keys=k,
+ new_vals=v,
+ num_tokens_to_retain=compression_context.batch_tokens_to_retain,
+ page_table=self.page_table,
+ batch_mapping=batch_mapping,
+ bh_lens=seq_lens,
+ k_cache=self.k_cache,
+ v_cache=self.v_cache,
+ PAGE_SIZE=self.page_size,
+ PAD_TO_PAGE_SIZE=True,
+ STORE_STREAM=context.STORE_STREAM,
+ )
+ elif self.k_cache is not None:
+ maybe_execute_in_stream(
+ prefill_store_all_kv,
+ new_keys=k,
+ new_values=v,
+ cu_seqlens_k=context.cu_seqlens_k,
+ max_seqlen_k=context.max_seqlen_k,
+ k_cache=self.k_cache,
+ v_cache=self.v_cache,
+ page_table=self.page_table,
+ bh_lens=seq_lens,
+ batch_mapping=batch_mapping,
+ PAGE_SIZE=self.page_size,
+ STORE_STREAM=context.STORE_STREAM,
+ )
+
+ if use_triton_prefill_attn:
+ # Under compression, wait for the store stream before attention
+ # reads the cache.
+ if context.do_compression and context.STORE_STREAM is not None:
+ torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+ assert seq_lens_copy is not None
+ o = causal_sparse_varlen_with_cache(
+ q,
+ k,
+ v,
+ self.k_cache,
+ self.v_cache,
+ seq_lens_bh=seq_lens_copy,
+ global_page_table=self.page_table,
+ batch_mapping=batch_mapping,
+ cu_seqlens_q=context.cu_seqlens_q,
+ max_seqlen_q=context.max_seqlen_q,
+ max_seqlen_k_cache=context.max_bh_len,
+ HKV=int(self.num_kv_heads),
+ PAGE_SIZE=self.page_size,
+ sm_scale=self.scale,
+ )
+ elif context.do_compression:
+ if context.STORE_STREAM is not None:
+ torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
+ assert seq_lens_copy is not None
+ o = flash_prefill_from_paged(
+ q,
+ k,
+ v,
+ self.k_cache,
+ self.v_cache,
+ seq_lens_bh_before=seq_lens_copy,
+ global_page_table=self.page_table,
+ batch_mapping=batch_mapping,
+ cu_seqlens_q=context.cu_seqlens_q,
+ max_seqlen_q=context.max_seqlen_q,
+ PAGE_SIZE=self.page_size,
+ HKV=int(self.num_kv_heads),
+ sm_scale=self.scale,
+ )
+ else:
+ # No compression: attend over the freshly computed K/V only;
+ # the cache is not read on this path.
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ max_seqlen_q=context.max_seqlen_q,
+ cu_seqlens_q=context.cu_seqlens_q,
+ max_seqlen_k=context.max_seqlen_k,
+ cu_seqlens_k=context.cu_seqlens_k,
+ softmax_scale=self.scale,
+ causal=True,
+ )
+ else:
+ assert self.k_cache is not None, "KV Cache must be initialized for decoding"
+ # Decode: the store is issued directly (not through
+ # maybe_execute_in_stream) before the attention call.
+ decode_store_kv(
+ key=k,
+ value=v,
+ batch_mapping=batch_mapping,
+ bh_lens=seq_lens,
+ page_table=self.page_table,
+ k_cache=self.k_cache,
+ v_cache=self.v_cache,
+ PAGE_SIZE=self.page_size,
+ )
+
+ if use_fa_decode:
+ assert seq_lens is not None
+ o = flash_decode_from_paged(
+ q,
+ self.k_cache,
+ self.v_cache,
+ seq_lens_bh=seq_lens,
+ global_page_table=self.page_table,
+ batch_mapping=batch_mapping,
+ PAGE_SIZE=self.page_size,
+ HKV=int(self.num_kv_heads),
+ sm_scale=self.scale,
+ )
+ else:
+ o = head_sparse_decode_attention(
+ q,
+ self.k_cache,
+ self.v_cache,
+ seq_lens,
+ self.page_table,
+ batch_mapping,
+ int(self.num_kv_heads),
+ self.page_size,
+ self.scale,
+ key_split=context.key_split,
+ )
+ # Match compactor_vllm ``Attention``: ``index_copy_`` into the global
+ # ``bh_seq_lens`` table. The Triton masked copy was a CUDA fast path but
+ # disagreed with decode_store_kv / paged attention bookkeeping in edge
+ # cases and could leave lengths stale → garbage logits / immediate EOS.
+ if self.bh_seq_lens is not None:
+ longbm = batch_mapping.to(
+ device=self.bh_seq_lens.device, dtype=torch.long
+ )
+ # The store stream is passed only for prefill; decode passes None.
+ maybe_execute_in_stream(
+ self.bh_seq_lens.index_copy_,
+ 0,
+ longbm,
+ seq_lens,
+ STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
+ )
+ return o
diff --git a/vllm/kvprune_legacy_save/layers/embed_head.py b/vllm/kvprune_legacy_save/layers/embed_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1c19ab17b708dd8d2b0c3e7f17e946cc1983ca
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/embed_head.py
@@ -0,0 +1,111 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from vllm.kvprune.utils.context import get_context
+from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+from torch import nn
+
+
+class VocabParallelEmbedding(nn.Module):
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ ):
+ super().__init__()
+ self.tp_rank = tensor_parallel_rank_for_sharding()
+ self.tp_size = tensor_parallel_world_size_for_sharding()
+ assert num_embeddings % self.tp_size == 0
+ self.num_embeddings = num_embeddings
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
+ self.weight = nn.Parameter(
+ torch.empty(self.num_embeddings_per_partition, embedding_dim)
+ )
+ self.weight.weight_loader = self.weight_loader
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ param_data = param.data
+ shard_size = param_data.size(0)
+ start_idx = self.tp_rank * shard_size
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+ param_data.copy_(loaded_weight)
+
+ def forward(self, x: torch.Tensor):
+ if self.tp_size > 1:
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
+ x = mask * (x - self.vocab_start_idx)
+ y = F.embedding(x, self.weight)
+ if self.tp_size > 1:
+ y = mask.unsqueeze(1) * y
+ tensor_parallel_all_reduce(y)
+ return y
+
+
+class ParallelLMHead(VocabParallelEmbedding):
+ """LM head with TP vocab sharding.
+
+ When embedded in a vLLM worker, logits must be gathered on the **tensor-
+ parallel** process group (see :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
+ not the default :func:`torch.distributed.gather` — otherwise shard order / group
+ mismatch yields garbage logits and decoded gibberish.
+
+ After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer vocab),
+ matching :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
+ removal of padded vocabulary columns.
+ """
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ bias: bool = False,
+ *,
+ org_vocab_size: int | None = None,
+ ):
+ assert not bias
+ super().__init__(num_embeddings, embedding_dim)
+ # Original (unpadded) vocab size for logits truncation; defaults to num_embeddings.
+ self.org_vocab_size = (
+ int(org_vocab_size) if org_vocab_size is not None else num_embeddings
+ )
+
+ def forward(self, x: torch.Tensor):
+ context = get_context()
+ if context.is_prefill:
+ cu = context.cu_seqlens_q
+ last_indices = (cu[1:] - 1).to(torch.long)
+ n_tok = x.shape[0]
+ if n_tok > 0:
+ last_indices = last_indices.clamp(min=0, max=n_tok - 1)
+ x = x[last_indices].contiguous()
+ logits = F.linear(x, self.weight)
+ if self.tp_size > 1:
+ logits = self._gather_logits_tp(logits)
+ if logits is not None and logits.shape[-1] > self.org_vocab_size:
+ logits = logits[..., : self.org_vocab_size]
+ return logits
+
+ def _gather_logits_tp(self, logits: torch.Tensor) -> torch.Tensor | None:
+ try:
+ from vllm.distributed.parallel_state import model_parallel_is_initialized
+ from vllm.distributed.communication_op import (
+ tensor_model_parallel_gather,
+ )
+
+ if model_parallel_is_initialized():
+ return tensor_model_parallel_gather(logits, dst=0, dim=-1)
+ except Exception:
+ pass
+ all_logits = (
+ [torch.empty_like(logits) for _ in range(self.tp_size)]
+ if self.tp_rank == 0
+ else None
+ )
+ dist.gather(logits, all_logits, 0)
+ return torch.cat(all_logits, -1) if self.tp_rank == 0 else None
diff --git a/vllm/kvprune_legacy_save/layers/layernorm.py b/vllm/kvprune_legacy_save/layers/layernorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dabaad38ce9dec79b9e7c40a1405809c9235f3c
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/layernorm.py
@@ -0,0 +1,49 @@
+import torch
+from torch import nn
+
+
+class RMSNorm(nn.Module):
+    """Root-mean-square normalization with a learned per-channel scale.
+
+    Statistics are computed in float32 and the result is cast back to the
+    input dtype. Inputs are never modified. This matters for float32 inputs,
+    where ``Tensor.float()`` returns the tensor itself rather than a copy, so
+    any in-place arithmetic on it would overwrite the caller's data.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        """Create the layer.
+
+        Args:
+            hidden_size: Size of the last (normalized) dimension.
+            eps: Constant added to the mean square before the reciprocal
+                square root.
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+
+    def _normalize(self, x: torch.Tensor, orig_dtype: torch.dtype) -> torch.Tensor:
+        """Scale float32 ``x`` to unit RMS, cast to ``orig_dtype``, apply weight."""
+        var = x.pow(2).mean(dim=-1, keepdim=True)
+        # Out-of-place: ``x`` may share storage with a tensor the caller keeps.
+        x = x * torch.rsqrt(var + self.eps)
+        # ``x`` is a private temporary from here on, so in-place is safe.
+        return x.to(orig_dtype).mul_(self.weight)
+
+    def rms_forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the normalized, weighted copy of ``x`` in ``x``'s dtype."""
+        return self._normalize(x.float(), x.dtype)
+
+    def add_rms_forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Add ``residual`` to ``x``, then normalize the sum.
+
+        Returns:
+            ``(normalized, new_residual)`` where ``new_residual`` is the
+            un-normalized sum ``x + residual`` in ``x``'s dtype.
+        """
+        orig_dtype = x.dtype
+        # Out-of-place add so neither argument is overwritten.
+        x = x.float() + residual.float()
+        # May alias ``x`` when the dtype is already float32; ``_normalize``
+        # does not write into its argument, so the sum stays intact.
+        residual = x.to(orig_dtype)
+        return self._normalize(x, orig_dtype), residual
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Normalize ``x``; with ``residual`` also return the updated residual."""
+        if residual is None:
+            return self.rms_forward(x)
+        return self.add_rms_forward(x, residual)
diff --git a/vllm/kvprune_legacy_save/layers/linear.py b/vllm/kvprune_legacy_save/layers/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..be86096d2b1694c866f170fc572b6068511dee85
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/linear.py
@@ -0,0 +1,158 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+from torch import nn
+
+
+def divide(numerator, denominator):
+    """Return ``numerator // denominator``, asserting the division is exact."""
+    quotient, remainder = divmod(numerator, denominator)
+    assert remainder == 0
+    return quotient
+
+
+class LinearBase(nn.Module):
+    """Common state for the tensor-parallel linear layers in this module.
+
+    Records the sharded dimension together with this rank's index and the
+    world size, allocates an uninitialized weight (and optional bias), and
+    tags each parameter with ``weight_loader``. ``weight_loader`` is not
+    defined here, so only subclasses that provide it can be constructed.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+        tp_dim: int | None = None,
+    ):
+        """Allocate parameters of shape ``(output_size, input_size)``.
+
+        Args:
+            input_size: Number of input features held by this rank.
+            output_size: Number of output features held by this rank.
+            bias: Whether to allocate a bias of size ``output_size``.
+            tp_dim: Weight dimension that is sharded, or ``None``.
+        """
+        super().__init__()
+        self.tp_dim = tp_dim
+        self.tp_rank = tensor_parallel_rank_for_sharding()
+        self.tp_size = tensor_parallel_world_size_for_sharding()
+
+        weight = nn.Parameter(torch.empty(output_size, input_size))
+        weight.weight_loader = self.weight_loader
+        self.weight = weight
+
+        if not bias:
+            self.register_parameter("bias", None)
+            return
+        bias_param = nn.Parameter(torch.empty(output_size))
+        bias_param.weight_loader = self.weight_loader
+        self.bias = bias_param
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Abstract; concrete layers implement the projection."""
+        raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+    """Linear layer whose full weight is kept on every rank (no sharding)."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        """Allocate a full ``(output_size, input_size)`` weight."""
+        super().__init__(input_size, output_size, bias)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy the checkpoint tensor into ``param`` unchanged."""
+        target = param.data
+        target.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return ``x @ weight.T`` plus the bias when one exists."""
+        weight, bias = self.weight, self.bias
+        return F.linear(x, weight, bias)
+
+
+class ColumnParallelLinear(LinearBase):
+    """Linear layer sharded along the output dimension (dim 0 of the weight).
+
+    Each rank produces its own slice of the output features; ``forward`` does
+    no communication.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        """``output_size`` is the unsharded size; it must divide evenly."""
+        world_size = tensor_parallel_world_size_for_sharding()
+        local_output_size = divide(output_size, world_size)
+        super().__init__(input_size, local_output_size, bias, 0)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy this rank's slice of the full checkpoint tensor into ``param``."""
+        local = param.data
+        width = local.size(self.tp_dim)
+        shard = loaded_weight.narrow(self.tp_dim, self.tp_rank * width, width)
+        local.copy_(shard)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return this rank's slice of the output features."""
+        weight, bias = self.weight, self.bias
+        return F.linear(x, weight, bias)
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+    """Several column-parallel projections fused into one weight.
+
+    The local weight stores each projection's per-rank slice back to back in
+    the order given by ``output_sizes``.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = False,
+    ):
+        """``output_sizes`` lists the unsharded size of every fused projection."""
+        self.output_sizes = output_sizes
+        super().__init__(input_size, sum(output_sizes), bias)
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int
+    ):
+        """Load one fused projection.
+
+        Args:
+            param: Fused parameter to write into.
+            loaded_weight: Full checkpoint tensor of a single projection.
+            loaded_shard_id: Index of that projection in ``output_sizes``.
+        """
+        sizes = self.output_sizes
+        local_offset = sum(sizes[:loaded_shard_id]) // self.tp_size
+        local_size = sizes[loaded_shard_id] // self.tp_size
+        target = param.data.narrow(self.tp_dim, local_offset, local_size)
+        rank_chunks = loaded_weight.chunk(self.tp_size, self.tp_dim)
+        target.copy_(rank_chunks[self.tp_rank])
+
+
+class QKVParallelLinear(ColumnParallelLinear):
+    """Fused query/key/value projection sharded by attention head.
+
+    The local weight is laid out as ``[q | k | v]`` along dim 0, each part
+    covering this rank's heads only.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: int | None = None,
+        bias: bool = False,
+    ):
+        """Size the fused projection from the global head counts.
+
+        Args:
+            hidden_size: Input feature size.
+            head_size: Features per attention head.
+            total_num_heads: Query heads across all ranks.
+            total_num_kv_heads: Key/value heads across all ranks; a falsy
+                value means "same as ``total_num_heads``".
+            bias: Whether to allocate a bias.
+        """
+        world_size = tensor_parallel_world_size_for_sharding()
+        kv_heads_total = total_num_kv_heads or total_num_heads
+        self.head_size = head_size
+        self.num_heads = divide(total_num_heads, world_size)
+        self.num_kv_heads = divide(kv_heads_total, world_size)
+        fused_size = (total_num_heads + 2 * kv_heads_total) * self.head_size
+        super().__init__(hidden_size, fused_size, bias)
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str
+    ):
+        """Load the ``"q"``, ``"k"`` or ``"v"`` part of the fused parameter.
+
+        ``loaded_weight`` is the full checkpoint tensor of that projection;
+        only this rank's chunk of it is copied.
+        """
+        assert loaded_shard_id in ["q", "k", "v"]
+        q_width = self.num_heads * self.head_size
+        kv_width = self.num_kv_heads * self.head_size
+        if loaded_shard_id == "q":
+            offset, width = 0, q_width
+        elif loaded_shard_id == "k":
+            offset, width = q_width, kv_width
+        else:
+            offset, width = q_width + kv_width, kv_width
+        target = param.data.narrow(self.tp_dim, offset, width)
+        rank_chunks = loaded_weight.chunk(self.tp_size, self.tp_dim)
+        target.copy_(rank_chunks[self.tp_rank])
+
+
+class RowParallelLinear(LinearBase):
+    """Linear layer sharded along the input dimension (dim 1 of the weight).
+
+    Every rank multiplies its slice of the input features and the partial
+    results are summed across ranks. The bias is not sharded and is added on
+    rank 0 only, so it is counted once in the sum.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        """``input_size`` is the unsharded size; it must divide evenly."""
+        tp_size = tensor_parallel_world_size_for_sharding()
+        super().__init__(divide(input_size, tp_size), output_size, bias, 1)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Copy this rank's part of the checkpoint tensor into ``param``.
+
+        The weight is sliced along ``tp_dim``. The bias is 1-D and has no such
+        dimension, so it is copied whole instead of indexing a dimension that
+        does not exist.
+        """
+        param_data = param.data
+        if param_data.dim() <= self.tp_dim:
+            param_data.copy_(loaded_weight)
+            return
+        shard_size = param_data.size(self.tp_dim)
+        start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return the projection summed over all tensor-parallel ranks."""
+        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
+        if self.tp_size > 1:
+            # NOTE(review): the return value is ignored, so this relies on
+            # tensor_parallel_all_reduce reducing ``y`` in place; confirm.
+            tensor_parallel_all_reduce(y)
+        return y
diff --git a/vllm/kvprune_legacy_save/layers/moe.py b/vllm/kvprune_legacy_save/layers/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d487e650eb9b58396a6eb1998a7202c1baf3986
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/moe.py
@@ -0,0 +1,177 @@
+import torch
+import torch.distributed as dist
+from vllm.kvprune.triton_kernels.matmul_ogs import matmul_ogs
+from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+from torch import nn
+
+
+def divide(numerator, denominator):
+    """Exact integer division; asserts that nothing is left over."""
+    whole, leftover = divmod(numerator, denominator)
+    assert leftover == 0
+    return whole
+
+
+class TritonFusedMoeLinearBase(nn.Module):
+    """Shared state for the per-expert linear layers fed to ``matmul_ogs``.
+
+    Storage is allocated as ``(num_experts, in_features, out_features)`` and
+    exposed through a transposed view. The parameter therefore has shape
+    ``(num_experts, out_features, in_features)``, and transposing its last two
+    dimensions again yields a contiguous tensor. ``weight_loader`` is not
+    defined here and must be supplied by a subclass.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+        tp_dim: int | None = None,
+    ) -> None:
+        """Allocate uninitialized per-expert parameters.
+
+        Args:
+            in_features: Input features held by this rank.
+            out_features: Output features held by this rank.
+            num_experts: Number of experts stacked along dim 0.
+            bias: Whether to allocate a ``(num_experts, out_features)`` bias.
+            tp_dim: Parameter dimension that is sharded, or ``None``.
+        """
+        super().__init__()
+        self.tp_dim = tp_dim
+        self.tp_rank = tensor_parallel_rank_for_sharding()
+        self.tp_size = tensor_parallel_world_size_for_sharding()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.num_experts = num_experts
+
+        storage = torch.empty((num_experts, in_features, out_features))
+        weight = nn.Parameter(storage.transpose(-1, -2))
+        weight.weight_loader = self.weight_loader
+        self.weight = weight
+
+        if not bias:
+            self.register_parameter("bias", None)
+            return
+        bias_param = nn.Parameter(torch.empty((num_experts, out_features)))
+        bias_param.weight_loader = self.weight_loader
+        self.bias = bias_param
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Abstract; concrete layers implement the projection."""
+        raise NotImplementedError
+
+
+class ReplicatedTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+    """Per-expert linear layer whose full weights live on every rank."""
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+    ) -> None:
+        """Allocate unsharded parameters for ``num_experts`` experts."""
+        super().__init__(in_features, out_features, num_experts, bias)
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+    ):
+        """Copy one expert's checkpoint tensor into slot ``expert_idx``."""
+        param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Run ``matmul_ogs`` on the contiguous ``(experts, in, out)`` weight.
+
+        Extra keyword arguments are forwarded to ``matmul_ogs`` untouched.
+        """
+        w = self.weight.transpose(-1, -2)
+        assert w.is_contiguous()
+        # Pass the transposed tensor that was just checked, as the row- and
+        # column-parallel variants do, rather than the stored parameter view.
+        return matmul_ogs(
+            x,
+            w,
+            self.bias,
+            **kwargs,
+        )
+
+
+class RowParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+    """Per-expert linear layer sharded along the input features.
+
+    Each rank multiplies its slice of the input features and the partial
+    outputs are summed across ranks. The bias is not sharded and is applied
+    on rank 0 only, matching ``RowParallelLinear``, so the sum contains it
+    once rather than once per rank.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+    ) -> None:
+        """``in_features`` is the unsharded size; it must divide evenly."""
+        tp_size = (
+            tensor_parallel_world_size_for_sharding()
+            if dist.is_initialized()
+            else 1
+        )
+        super().__init__(
+            divide(in_features, tp_size), out_features, num_experts, bias, 2
+        )
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+    ):
+        """Load one expert's tensor into slot ``expert_idx``.
+
+        The weight is sliced along its input-feature columns. The bias has no
+        input dimension, so it is copied whole instead of indexing a dimension
+        that does not exist.
+        """
+        if param.dim() <= self.tp_dim:
+            param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
+            return
+        shard_size = param.size(2)
+        start_idx = self.tp_rank * shard_size
+        local_shard = loaded_weight[:, start_idx : start_idx + shard_size]
+        param.data[expert_idx].copy_(local_shard, non_blocking=True)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Return the expert projection summed over tensor-parallel ranks.
+
+        Extra keyword arguments are forwarded to ``matmul_ogs`` untouched.
+        """
+        w = self.weight.transpose(-1, -2)
+        assert w.is_contiguous()
+        y = matmul_ogs(
+            x,
+            w,
+            self.bias if self.tp_rank == 0 else None,
+            **kwargs,
+        )
+        if self.tp_size > 1:
+            # NOTE(review): the return value is ignored, so this relies on
+            # tensor_parallel_all_reduce reducing ``y`` in place; confirm.
+            tensor_parallel_all_reduce(y)
+        return y
+
+
+class ColumnParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
+    """Per-expert linear layer sharded along the output features.
+
+    Each rank computes its own slice of the outputs; ``forward`` does no
+    communication.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        num_experts: int,
+        bias: bool = False,
+    ) -> None:
+        """``out_features`` is the unsharded size; it must divide evenly."""
+        tp_size = (
+            tensor_parallel_world_size_for_sharding()
+            if dist.is_initialized()
+            else 1
+        )
+        super().__init__(
+            in_features, divide(out_features, tp_size), num_experts, bias, 1
+        )
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
+    ):
+        """Load this rank's output rows of one expert into slot ``expert_idx``.
+
+        ``narrow`` along dim 0 selects the same rows as a 2-D slice for the
+        weight and also works for a 1-D per-expert bias, which a two-index
+        slice cannot address.
+        """
+        shard_size = param.size(1)
+        start_idx = self.tp_rank * shard_size
+        local_shard = loaded_weight.narrow(0, start_idx, shard_size)
+        param.data[expert_idx].copy_(local_shard, non_blocking=True)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Return this rank's slice of the expert outputs.
+
+        Extra keyword arguments are forwarded to ``matmul_ogs`` untouched.
+        """
+        w = self.weight.transpose(-1, -2)
+        assert w.is_contiguous()
+        return matmul_ogs(
+            x,
+            w,
+            self.bias,
+            **kwargs,
+        )
+
+
+class MergedColumnParallelTritonFusedMoeLinear(ColumnParallelTritonFusedMoeLinear):
+    """Several column-parallel expert projections fused into one weight.
+
+    Along the output dimension the local weight stores each projection's
+    per-rank slice back to back, in the order of ``out_feature_list``.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_feature_list: list[int],
+        num_experts: int,
+        bias: bool = False,
+    ):
+        """``out_feature_list`` holds the unsharded size of each projection."""
+        self.out_feature_list = out_feature_list
+        super().__init__(in_features, sum(out_feature_list), num_experts, bias)
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        expert_idx: int,
+        shard_id: int,
+    ):
+        """Load one projection of one expert.
+
+        Args:
+            param: Fused parameter to write into.
+            loaded_weight: Full checkpoint tensor of a single projection of a
+                single expert (no leading expert dimension).
+            expert_idx: Expert slot to fill.
+            shard_id: Index of the projection in ``out_feature_list``.
+        """
+        sizes = self.out_feature_list
+        local_offset = sum(sizes[:shard_id]) // self.tp_size
+        local_size = sizes[shard_id] // self.tp_size
+        target = param.data.narrow(self.tp_dim, local_offset, local_size)
+        # The checkpoint tensor lacks the expert axis, hence ``tp_dim - 1``.
+        rank_chunks = loaded_weight.chunk(self.tp_size, dim=self.tp_dim - 1)
+        target[expert_idx].copy_(rank_chunks[self.tp_rank], non_blocking=True)
diff --git a/vllm/kvprune_legacy_save/layers/rotary_embedding.py b/vllm/kvprune_legacy_save/layers/rotary_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..582be894ee331bc106e27063488dc090b23f8e8d
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/rotary_embedding.py
@@ -0,0 +1,121 @@
+import math
+from functools import lru_cache
+from typing import Any
+
+import torch
+from torch import nn
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    """Rotate the two halves of ``x``'s last dimension by ``(cos, sin)``.
+
+    The arithmetic runs in float32 and the result is cast back to ``x``'s
+    dtype. ``cos`` and ``sin`` must broadcast against half of the last
+    dimension.
+    """
+    first, second = x.float().chunk(2, dim=-1)
+    rotated = torch.cat(
+        (first * cos - second * sin, second * cos + first * sin),
+        dim=-1,
+    )
+    return rotated.to(x.dtype)
+
+
+def rope_theta_from_hf_config(config: Any) -> float:
+    """Return the RoPE base frequency stored on a Hugging Face config.
+
+    ``rope_parameters["rope_theta"]`` wins when ``rope_parameters`` is a dict
+    holding that key. Otherwise the top-level ``rope_theta`` attribute is
+    used, defaulting to ``1_000_000.0`` when it is absent too.
+    """
+    rope_params = getattr(config, "rope_parameters", None)
+    nested = isinstance(rope_params, dict) and "rope_theta" in rope_params
+    if nested:
+        theta = rope_params["rope_theta"]
+    else:
+        theta = getattr(config, "rope_theta", 1_000_000.0)
+    return float(theta)
+
+
+class RotaryEmbedding(nn.Module):
+    """Rotary position embedding backed by a precomputed cos/sin table.
+
+    The table ``cos_sin_cache`` has shape
+    ``(max_position_embeddings, 1, rotary_dim)`` with the cosines in the first
+    half of the last dimension and the sines in the second. It is a
+    non-persistent buffer, so it is left out of ``state_dict``.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        rope_scaling: tuple | None,
+    ) -> None:
+        """Precompute the table.
+
+        Args:
+            head_size: Attention head size (stored, not used in the math).
+            rotary_dim: Number of rotated features per head.
+            max_position_embeddings: Number of table rows.
+            base: RoPE base frequency (theta).
+            rope_scaling: ``None`` or the 5-tuple ``(rope_type, factor,
+                low_freq_factor, high_freq_factor,
+                original_max_position_embeddings)``; only ``"llama3"`` is
+                accepted as ``rope_type``.
+        """
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
+        inv_freq = 1.0 / (base**exponents)
+        if rope_scaling is not None:
+            inv_freq = self._llama3_scaled(inv_freq, rope_scaling)
+
+        steps = torch.arange(max_position_embeddings, dtype=torch.float)
+        angles = torch.einsum("i,j -> ij", steps, inv_freq)
+        table = torch.cat((angles.cos(), angles.sin()), dim=-1).unsqueeze_(1)
+        self.register_buffer("cos_sin_cache", table, persistent=False)
+
+    @staticmethod
+    def _llama3_scaled(inv_freq: torch.Tensor, rope_scaling: tuple) -> torch.Tensor:
+        """Apply Llama-3 style frequency scaling to ``inv_freq``.
+
+        Long wavelengths are divided by ``factor``, short ones are kept, and
+        the band in between is interpolated between the two.
+        """
+        (
+            rope_type,
+            factor,
+            low_freq_factor,
+            high_freq_factor,
+            original_max_position_embeddings,
+        ) = rope_scaling
+        assert rope_type == "llama3"
+        old_context_len = original_max_position_embeddings
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        wavelen = 2 * math.pi / inv_freq
+
+        scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+        smooth = (old_context_len / wavelen - low_freq_factor) / (
+            high_freq_factor - low_freq_factor
+        )
+        blended = (1 - smooth) * scaled / factor + smooth * scaled
+        in_band = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
+        return torch.where(in_band, blended, scaled)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Rotate ``query`` and ``key`` by the table rows at ``positions``.
+
+        Raises:
+            ValueError: If any position lies outside the table. The check is
+                skipped while a CUDA graph is being captured, because reading
+                the extrema back with ``.item()`` is not allowed there.
+        """
+        cache_len = self.cos_sin_cache.shape[0]
+        capturing = (
+            torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
+        )
+        if not capturing and positions.numel() > 0:
+            pmax = int(positions.max().item())
+            pmin = int(positions.min().item())
+            if pmin < 0 or pmax >= cache_len:
+                raise ValueError(
+                    f"RoPE positions out of range: need 0 <= pos < {cache_len}, "
+                    f"got min={pmin}, max={pmax}. "
+                    "Shorten the prompt or increase max_model_len (and align vLLM "
+                    "RoPE cos_sin_cache with tie_kvprune_rope_buffers_from_vllm)."
+                )
+        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
+        rotated_query = apply_rotary_emb(query, cos, sin)
+        rotated_key = apply_rotary_emb(key, cos, sin)
+        return rotated_query, rotated_key
+
+
+@lru_cache(1)
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: float,
+    rope_scaling: tuple | None = None,
+):
+    """Return a ``RotaryEmbedding``, reusing the most recently built one.
+
+    ``lru_cache(1)`` keeps only the module created for the latest distinct
+    argument set, so consecutive calls with identical arguments share a single
+    instance and its cos/sin table. All arguments must be hashable, which is
+    why ``rope_scaling`` is a tuple.
+    """
+    return RotaryEmbedding(
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        max_position_embeddings=max_position,
+        base=base,
+        rope_scaling=rope_scaling,
+    )
diff --git a/vllm/kvprune_legacy_save/layers/sampler.py b/vllm/kvprune_legacy_save/layers/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0761b7c79bc7dc511180078c4d059c3423b3f8f
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/sampler.py
@@ -0,0 +1,27 @@
+import torch
+from torch import nn
+
+
+class Sampler(nn.Module):
+    """Choose one token id per row of logits.
+
+    Rows whose temperature is exactly ``0`` are decoded greedily. All other
+    rows are divided by their temperature and perturbed with the negative log
+    of exponential noise before the argmax.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+        """Return the chosen token index for every row of ``logits``.
+
+        Args:
+            logits: Scores with one row per request.
+            temperatures: One temperature per row (flattened internally).
+        """
+        temps = temperatures.view(-1)
+        scores = logits.float()
+        stochastic = temps != 0.0
+
+        if stochastic.any():
+            row_temps = temps[stochastic].unsqueeze(-1)  # [B_sample, 1]
+            tempered = scores[stochastic].div(row_temps)
+
+            # Clamp before the log so a zero draw cannot produce -inf.
+            noise = torch.empty_like(tempered).exponential_(1).clamp_min_(1e-10).log()
+
+            # ``float()`` may have returned ``logits`` itself; copy before
+            # writing so the caller's tensor stays untouched.
+            scores = scores.clone()
+            scores[stochastic] = tempered - noise
+
+        return scores.argmax(dim=-1)
diff --git a/vllm/kvprune_legacy_save/layers/triton_helpers.py b/vllm/kvprune_legacy_save/layers/triton_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c1a31669bac31c9fcef53259f1211e6de19bc37
--- /dev/null
+++ b/vllm/kvprune_legacy_save/layers/triton_helpers.py
@@ -0,0 +1,101 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _masked_index_select_kernel(
+    X_ptr,
+    IDX_ptr,
+    OUT_ptr,
+    N,
+    stride_xn,
+    stride_xh,
+    stride_ob,
+    stride_oh,
+):
+    # One program per output element: grid axis 0 selects the output row and
+    # axis 1 the column. Element (b, h) of OUT receives element (IDX[b], h) of
+    # X, or zero when IDX[b] lies outside [0, N).
+    b = tl.program_id(0)  # which output row (0..B-1)
+    h = tl.program_id(1)  # which column of that row
+    idx = tl.load(IDX_ptr + b)  # source row; IDX is read with unit stride
+    valid = (idx >= 0) & (idx < N)
+    out_ptrs = OUT_ptr + b * stride_ob + h * stride_oh
+
+    if not valid:
+        # Out-of-range index: emit a zero instead of reading X.
+        tl.store(out_ptrs, 0)
+    else:
+        x_ptrs = X_ptr + idx * stride_xn + h * stride_xh
+        vals = tl.load(x_ptrs)
+        tl.store(out_ptrs, vals)
+
+
+def masked_index_select_triton_dim0(
+    input: torch.Tensor, index: torch.Tensor
+) -> torch.Tensor:
+    """Gather rows of ``input`` by ``index``, writing zeros for invalid rows.
+
+    Args:
+        input: ``[N, H]`` tensor to read from; any strides are accepted.
+        index: ``[B]`` integer tensor on the same device. Entries outside
+            ``[0, N)`` produce an all-zero output row.
+
+    Returns:
+        A new ``[B, H]`` tensor with ``out[b] = input[index[b]]``.
+    """
+    assert input.ndim == 2 and index.ndim == 1
+    N, H = input.shape
+    B = index.numel()
+    # The kernel reads ``index`` with unit stride, so a strided view (for
+    # example a column of a 2-D tensor) must be compacted first. This is a
+    # no-op for an index that is already contiguous.
+    index = index.contiguous()
+    out = torch.empty((B, H), dtype=input.dtype, device=input.device)
+    _masked_index_select_kernel[(B, H)](
+        input,
+        index,
+        out,
+        N,
+        input.stride(0),
+        input.stride(1),
+        out.stride(0),
+        out.stride(1),
+    )
+    return out
+
+
+@triton.jit
+def _masked_index_copy_kernel(
+    DST_ptr,
+    IDX_ptr,
+    SRC_ptr,
+    N,
+    stride_dn,
+    stride_dh,
+    stride_sb,
+    stride_sh,
+):
+    # One program per source element: grid axis 0 selects the source row and
+    # axis 1 the column. Element (b, h) of SRC is written to element
+    # (IDX[b], h) of DST; nothing is written when IDX[b] is outside [0, N).
+    b = tl.program_id(0)
+    h = tl.program_id(1)
+    idx = tl.load(IDX_ptr + b)  # destination row; IDX is read with unit stride
+    valid = (idx >= 0) & (idx < N)
+    if valid:
+        src_ptrs = SRC_ptr + b * stride_sb + h * stride_sh
+        dst_ptrs = DST_ptr + idx * stride_dn + h * stride_dh
+        tl.store(dst_ptrs, tl.load(src_ptrs))
+
+
+def masked_index_copy_triton_dim0(
+    dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor
+):
+    """Masked, in-place ``dst.index_copy_(0, index, src)``.
+
+    Rows whose ``index[b]`` is negative or ``>= dst.shape[0]`` are skipped, so
+    nothing is written for them.
+
+    Args:
+        dst: ``[N, H]`` tensor that is modified in place.
+        index: ``[B]`` integer tensor of destination rows.
+        src: ``[B, H]`` tensor holding the rows to copy.
+    """
+    assert dst.ndim == 2 and src.ndim == 2 and index.ndim == 1
+    N, H = dst.shape
+    B, Hs = src.shape
+    assert Hs == H and index.numel() == B
+    # The kernel reads ``index`` with unit stride, so a strided view must be
+    # compacted first. This is a no-op for an already contiguous index.
+    index = index.contiguous()
+    _masked_index_copy_kernel[(B, H)](
+        dst,
+        index,
+        src,
+        N,
+        dst.stride(0),
+        dst.stride(1),
+        src.stride(0),
+        src.stride(1),
+    )
diff --git a/vllm/kvprune_legacy_save/models/__init__.py b/vllm/kvprune_legacy_save/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05e4040ba4095b8870b6471ac4cc9a7f041ee767
--- /dev/null
+++ b/vllm/kvprune_legacy_save/models/__init__.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import logging
+
+from vllm.kvprune.models.llama3 import LlamaForCausalLM
+from vllm.kvprune.models.qwen3 import Qwen3ForCausalLM
+
+logger = logging.getLogger(__name__)
+
+# Model classes that are always available, keyed by registry name.
+MODEL_REGISTRY = {
+    "llama": LlamaForCausalLM,
+    "qwen3": Qwen3ForCausalLM,
+}
+
+# Qwen3-MoE is optional: it is registered only when its module imports
+# cleanly. Any import failure is logged and the entry is left out.
+try:
+    from vllm.kvprune.models.qwen3_moe import Qwen3MoeForCausalLM
+
+    MODEL_REGISTRY["qwen3_moe"] = Qwen3MoeForCausalLM
+except Exception as exc:
+    logger.warning("Disabling qwen3_moe due to import error: %s", exc)
diff --git a/vllm/kvprune_legacy_save/models/llama3.py b/vllm/kvprune_legacy_save/models/llama3.py
new file mode 100644
index 0000000000000000000000000000000000000000..31a762c72736a44865d54e7f014393cc994e93de
--- /dev/null
+++ b/vllm/kvprune_legacy_save/models/llama3.py
@@ -0,0 +1,299 @@
+import os
+from glob import glob
+
+import torch
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import LlamaConfig
+
+from vllm.kvprune.compression import (
+ CompressionMethod,
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from vllm.kvprune.layers.activation import SiluAndMul
+from vllm.kvprune.layers.attention import Attention
+from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from vllm.kvprune.layers.layernorm import RMSNorm
+from vllm.kvprune.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.kvprune.layers.rotary_embedding import get_rope
+from vllm.kvprune.utils.context import get_context
+from vllm.kvprune.utils.tp_utils import tensor_parallel_world_size_for_sharding
+
+
+class LlamaAttention(nn.Module):
+    """Llama self-attention with optional KV-cache compression during prefill.
+
+    Query and key/value heads are split evenly across tensor-parallel ranks.
+    When the current context is a prefill with compression enabled, the
+    compression hooks run before and after the rotary embedding and the
+    resulting scores are passed on to the attention backend.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: dict | None = None,
+    ) -> None:
+        """Build the projections, rotary embedding and attention backend.
+
+        Args:
+            hidden_size: Model hidden size.
+            num_heads: Query heads across all ranks.
+            num_kv_heads: Key/value heads across all ranks.
+            max_position: Number of positions covered by the RoPE table.
+            head_dim: Per-head size; a falsy value means
+                ``hidden_size // num_heads``.
+            qkv_bias: Whether the fused QKV projection has a bias.
+            rope_theta: RoPE base frequency.
+            rope_scaling: ``None`` or a dict with the keys ``rope_type``,
+                ``factor``, ``low_freq_factor``, ``high_freq_factor`` and
+                ``original_max_position_embeddings``.
+        """
+        super().__init__()
+        tp_size = tensor_parallel_world_size_for_sharding()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+
+        # ``get_rope`` is cached, so the scaling config must be hashable.
+        rope_scaling_tuple = None
+        if rope_scaling is not None:
+            scaling_keys = (
+                "rope_type",
+                "factor",
+                "low_freq_factor",
+                "high_freq_factor",
+                "original_max_position_embeddings",
+            )
+            rope_scaling_tuple = tuple(rope_scaling[key] for key in scaling_keys)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling_tuple,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+
+    def _wo_weight_per_head(self) -> torch.Tensor:
+        """Return ``o_proj.weight`` as float32 ``(num_heads, head_dim, hidden)``."""
+        wo_raw = self.o_proj.weight
+        hidden_size, _ = wo_raw.shape
+        per_head = wo_raw.transpose(0, 1).contiguous()
+        per_head = per_head.view(self.num_heads, self.head_dim, hidden_size)
+        return per_head.to(dtype=torch.float32)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Run attention over ``hidden_states`` and project the result back."""
+        context = get_context()
+        q, k, v = self.qkv_proj(hidden_states).split(
+            [self.q_size, self.kv_size, self.kv_size], dim=-1
+        )
+        q = q.view(-1, self.num_heads, self.head_dim)
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+
+        scores = None
+        if context.is_prefill and context.do_compression:
+            scores = apply_prerope_compression(q, k, v, context)
+
+        q, k = self.rotary_emb(positions, q, k)
+
+        if context.is_prefill and context.do_compression:
+            cc = context.compression_context
+            needs_wo = (
+                cc is not None
+                and cc.compression_method == CompressionMethod.CRITICALADAKV
+            )
+            if needs_wo:
+                # Key step: inject wo_weight into the compression context.
+                cc.wo_weight = self._wo_weight_per_head()
+            scores = apply_postrope_compression(q, k, v, scores, context)
+
+        attn_out = self.attn(q, k, v, scores)
+        return self.o_proj(attn_out.flatten(1, -1))
+
+
+class LlamaMLP(nn.Module):
+    """Gated feed-forward block: fused gate/up projection, SiLU-and-multiply,
+    then the down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        mlp_bias: bool,
+    ) -> None:
+        """Build the projections.
+
+        Args:
+            hidden_size: Model hidden size.
+            intermediate_size: Width of each of the gate and up projections.
+            hidden_act: Must be ``"silu"``.
+            mlp_bias: Whether both projections carry a bias.
+        """
+        super().__init__()
+        fused_widths = [intermediate_size] * 2  # gate first, then up
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            fused_widths,
+            bias=mlp_bias,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=mlp_bias,
+        )
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
+        fused = self.gate_up_proj(x)
+        activated = self.act_fn(fused)
+        return self.down_proj(activated)
+
+
+class LlamaDecoderLayer(nn.Module):
+    """One pre-norm Llama block: attention and MLP, each behind an RMSNorm.
+
+    The residual stream is threaded through the norms: every norm adds the
+    incoming hidden states to the residual and returns both the normalized
+    tensor and the updated residual.
+    """
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+    ) -> None:
+        """Build the layer from a Hugging Face Llama config.
+
+        ``rope_theta`` is taken from ``config.rope_parameters["rope_theta"]``
+        when the config stores it there, and otherwise from the top-level
+        attribute with the previous default of ``500000.0``.
+        """
+        super().__init__()
+        rope_params = getattr(config, "rope_parameters", None)
+        if isinstance(rope_params, dict) and "rope_theta" in rope_params:
+            rope_theta = rope_params["rope_theta"]
+        else:
+            rope_theta = getattr(config, "rope_theta", 500000.0)
+        self.self_attn = LlamaAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            qkv_bias=getattr(config, "attention_bias", False),
+            head_dim=getattr(config, "head_dim", None),
+            rope_theta=rope_theta,
+            rope_scaling=getattr(config, "rope_scaling", None),
+        )
+        self.mlp = LlamaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            mlp_bias=config.mlp_bias,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Apply the block.
+
+        Args:
+            positions: Token positions passed to the attention layer.
+            hidden_states: Output of the previous layer's MLP, or the
+                embeddings for the first layer.
+            residual: Running residual stream, ``None`` for the first layer.
+
+        Returns:
+            ``(mlp_output, residual)``; the MLP output is not yet added to
+            the residual.
+        """
+        if residual is None:
+            # First layer: the raw input itself starts the residual stream.
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(positions, hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class LlamaModel(nn.Module):
+    """Llama decoder stack: token embedding, decoder layers, final RMSNorm."""
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+    ) -> None:
+        """Build ``config.num_hidden_layers`` decoder layers from ``config``."""
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        decoder_layers = []
+        for _ in range(config.num_hidden_layers):
+            decoder_layers.append(LlamaDecoderLayer(config))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the final normalized hidden states for ``input_ids``."""
+        hidden_states, residual = self.embed_tokens(input_ids), None
+        for decoder_layer in self.layers:
+            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
+        # The last norm folds the remaining residual in; only the normalized
+        # tensor is needed from here on.
+        normalized, _ = self.norm(hidden_states, residual)
+        return normalized
+
+
+class LlamaForCausalLM(nn.Module):
+    """Llama decoder with a language-model head and a safetensors loader."""
+
+    # Checkpoint projection name -> (fused parameter name, shard id passed to
+    # that parameter's ``weight_loader``).
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(self, config: LlamaConfig) -> None:
+        """Build the model; the head shares the embedding data when tied."""
+        super().__init__()
+        self.model = LlamaModel(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            org_vocab_size=config.vocab_size,
+        )
+        if config.tie_word_embeddings:
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the final hidden states (not logits)."""
+        return self.model(input_ids, positions)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Project hidden states to vocabulary logits through the LM head."""
+        return self.lm_head(hidden_states)
+
+    def _load_weight(self, weight_name: str, weight_tensor: torch.Tensor) -> None:
+        """Route one checkpoint tensor to the parameter that owns it."""
+        # Projections that are fused in this model go to a shard of the
+        # fused parameter.
+        for source, (fused, shard_id) in self.packed_modules_mapping.items():
+            if source in weight_name:
+                param = self.get_parameter(weight_name.replace(source, fused))
+                param.weight_loader(param, weight_tensor, shard_id)
+                return
+
+        # Everything else maps one-to-one; parameters without a custom
+        # loader receive a plain copy.
+        param = self.get_parameter(weight_name)
+        weight_loader = getattr(param, "weight_loader", None)
+        if weight_loader is None:
+            param.data.copy_(weight_tensor)
+        else:
+            weight_loader(param, weight_tensor)
+
+    def load_model(
+        self,
+        path: str,
+        *,
+        use_tqdm: bool = False,
+    ) -> None:
+        """Load every ``*.safetensors`` shard found directly inside ``path``.
+
+        Args:
+            path: Directory holding the checkpoint shards.
+            use_tqdm: Show a progress bar over the shard files.
+
+        Raises:
+            FileNotFoundError: If ``path`` contains no ``*.safetensors`` file.
+                Without this check the model would silently keep its
+                uninitialized weights.
+            AttributeError: If a checkpoint tensor has no matching parameter
+                (raised by ``get_parameter``).
+        """
+        all_shards = sorted(glob(os.path.join(path, "*.safetensors")))
+        if not all_shards:
+            raise FileNotFoundError(f"No *.safetensors files found in {path!r}")
+
+        shard_iter = (
+            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
+        )
+        for file in shard_iter:
+            with safe_open(file, "pt", "cpu") as f:
+                for weight_name in f.keys():
+                    self._load_weight(weight_name, f.get_tensor(weight_name))
diff --git a/vllm/kvprune_legacy_save/models/qwen3.py b/vllm/kvprune_legacy_save/models/qwen3.py
new file mode 100644
index 0000000000000000000000000000000000000000..5053f71dd8425de4c5420acf1bcae2b55c368b94
--- /dev/null
+++ b/vllm/kvprune_legacy_save/models/qwen3.py
@@ -0,0 +1,296 @@
+import os
+from glob import glob
+
+import torch
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import Qwen3Config
+
+from vllm.kvprune.compression import (
+ CompressionMethod,
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from vllm.kvprune.layers.activation import SiluAndMul
+from vllm.kvprune.layers.attention import Attention
+from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from vllm.kvprune.layers.layernorm import RMSNorm
+from vllm.kvprune.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.kvprune.layers.rotary_embedding import get_rope, rope_theta_from_hf_config
+from vllm.kvprune.utils.context import get_context
+from vllm.kvprune.utils.tp_utils import tensor_parallel_world_size_for_sharding
+
+
+class Qwen3Attention(nn.Module):
+    """Qwen3 self-attention with optional KV-compression hooks.
+
+    The forward pass runs a fused QKV projection, per-head RMSNorm on queries
+    and keys, rotary embedding, the attention layer and the output projection.
+    While ``context.is_prefill and context.do_compression`` holds, the
+    pre-/post-RoPE compression functions are called and their scores are
+    passed to the attention layer.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+    ) -> None:
+        super().__init__()
+        world_size = tensor_parallel_world_size_for_sharding()
+
+        # Query and KV heads must both divide evenly across the sharding ranks.
+        assert num_heads % world_size == 0
+        assert num_kv_heads % world_size == 0
+        self.total_num_heads = num_heads
+        self.num_heads = num_heads // world_size
+        self.total_num_kv_heads = num_kv_heads
+        self.num_kv_heads = num_kv_heads // world_size
+
+        # A missing (or falsy) head_dim falls back to hidden_size / total heads.
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def _wo_weight_per_head(self) -> torch.Tensor:
+        """Return ``o_proj.weight`` transposed and viewed per head, in float32.
+
+        The result has shape ``(num_heads, head_dim, weight.shape[0])``.
+        """
+        weight = self.o_proj.weight
+        out_features, _ = weight.shape
+        per_head = (
+            weight.transpose(0, 1)
+            .contiguous()
+            .view(self.num_heads, self.head_dim, out_features)
+        )
+        return per_head.to(dtype=torch.float32)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        ctx = get_context()
+        fused_qkv = self.qkv_proj(hidden_states)
+        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # Per-head normalisation of queries and keys happens before RoPE.
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+
+        # The pre-RoPE hook receives v before it is reshaped per head.
+        if ctx.is_prefill and ctx.do_compression:
+            scores = apply_prerope_compression(q, k, v, ctx)
+        else:
+            scores = None
+
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        q, k = self.rotary_emb(positions, q, k)
+
+        if ctx.is_prefill and ctx.do_compression:
+            comp_ctx = ctx.compression_context
+            uses_criticaladakv = (
+                comp_ctx is not None
+                and comp_ctx.compression_method == CompressionMethod.CRITICALADAKV
+            )
+            if uses_criticaladakv:
+                # Key step: inject wo_weight into the compression context.
+                comp_ctx.wo_weight = self._wo_weight_per_head()
+
+            scores = apply_postrope_compression(q, k, v, scores, ctx)
+
+        attn_out = self.attn(q, k, v, scores)
+        return self.o_proj(attn_out.flatten(1, -1))
+
+
+class Qwen3MLP(nn.Module):
+    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul, down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        # Gate and up projections live in one merged layer with two equal slices.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size, intermediate_size],
+            bias=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+        )
+        # Only the "silu" activation is supported.
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
+        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
+class Qwen3DecoderLayer(nn.Module):
+    """One Qwen3 layer: RMSNorm -> attention -> RMSNorm -> MLP, with a carried residual."""
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+    ) -> None:
+        super().__init__()
+        # Use the explicit head_dim when the config has one, else derive it.
+        configured_head_dim = getattr(config, "head_dim", None)
+        if configured_head_dim is not None:
+            head_dim = configured_head_dim
+        else:
+            head_dim = config.hidden_size // config.num_attention_heads
+
+        rope_theta = rope_theta_from_hf_config(config)
+
+        # Only a tuple-valued rope_scaling is forwarded; any other value is
+        # replaced by None.
+        rope_scaling_tuple: tuple | None = None
+        raw_rope_scaling = getattr(config, "rope_scaling", None)
+        if isinstance(raw_rope_scaling, tuple):
+            rope_scaling_tuple = raw_rope_scaling
+
+        self.self_attn = Qwen3Attention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, "attention_bias", False),
+            head_dim=head_dim,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling_tuple,
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return ``(mlp_output, residual)`` for this layer.
+
+        ``residual`` is ``None`` on the first layer; in that case the incoming
+        hidden states become the residual. Otherwise the two-argument norm call
+        supplies both the normalised activations and the new residual.
+        """
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(residual)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        attn_out = self.self_attn(positions, hidden_states)
+        normed, residual = self.post_attention_layernorm(attn_out, residual)
+        return self.mlp(normed), residual
+
+
+class Qwen3Model(nn.Module):
+    """Token embedding, ``config.num_hidden_layers`` decoder layers and a final RMSNorm."""
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        decoder_layers = [
+            Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)
+        ]
+        self.layers = nn.ModuleList(decoder_layers)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embed ``input_ids``, run all layers and return the normalised hidden states."""
+        # The residual starts as None and is threaded through every layer.
+        residual = None
+        hidden_states = self.embed_tokens(input_ids)
+        for decoder_layer in self.layers:
+            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
+        normed, _final_residual = self.norm(hidden_states, residual)
+        return normed
+
+
+class Qwen3ForCausalLM(nn.Module):
+    """Qwen3 decoder plus LM head, with a safetensors checkpoint loader."""
+
+    # Checkpoint sub-name -> (fused parameter sub-name, shard id passed to the
+    # parameter's weight_loader).
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(self, config: Qwen3Config) -> None:
+        super().__init__()
+        self.model = Qwen3Model(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            org_vocab_size=config.vocab_size,
+        )
+        if config.tie_word_embeddings:
+            # Tied weights: head and input embedding share one storage.
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the decoder's hidden states; logits come from ``compute_logits``."""
+        return self.model(input_ids, positions)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply the LM head to ``hidden_states``."""
+        return self.lm_head(hidden_states)
+
+    def load_model(
+        self,
+        path: str,
+        *,
+        use_tqdm: bool = False,
+    ) -> None:
+        """Load every ``*.safetensors`` shard found directly under ``path``.
+
+        Checkpoint names containing a key of ``packed_modules_mapping`` are
+        redirected to the fused parameter named by the mapping and loaded
+        through that parameter's ``weight_loader`` together with the shard id.
+        Every other tensor is loaded into the parameter of the same name,
+        using its ``weight_loader`` when present and a plain copy otherwise.
+
+        Args:
+            path: Directory that holds the checkpoint shards.
+            use_tqdm: Show a progress bar over the shard files.
+
+        Raises:
+            FileNotFoundError: If ``path`` contains no ``*.safetensors`` file.
+                Without this check the model would silently keep its
+                uninitialised weights.
+        """
+        # Sorted so shards are visited in a reproducible order.
+        shard_paths = sorted(glob(os.path.join(path, "*.safetensors")))
+        if not shard_paths:
+            raise FileNotFoundError(f"No *.safetensors shards found in {path!r}")
+
+        shard_iter = (
+            tqdm.tqdm(shard_paths, desc="Loading model") if use_tqdm else shard_paths
+        )
+        for shard_path in shard_iter:
+            with safe_open(shard_path, "pt", "cpu") as shard:
+                for weight_name in shard.keys():
+                    weight_tensor = shard.get_tensor(weight_name)
+
+                    # First mapping key that occurs in the checkpoint name wins.
+                    packed_key = next(
+                        (k for k in self.packed_modules_mapping if k in weight_name),
+                        None,
+                    )
+                    if packed_key is not None:
+                        fused_name, shard_id = self.packed_modules_mapping[packed_key]
+                        param = self.get_parameter(
+                            weight_name.replace(packed_key, fused_name)
+                        )
+                        param.weight_loader(param, weight_tensor, shard_id)
+                        continue
+
+                    # Un-fused parameter: custom loader if any, else plain copy.
+                    param = self.get_parameter(weight_name)
+                    weight_loader = getattr(
+                        param,
+                        "weight_loader",
+                        lambda p, loaded_weight: p.data.copy_(loaded_weight),
+                    )
+                    weight_loader(param, weight_tensor)
diff --git a/vllm/kvprune_legacy_save/models/qwen3_moe.py b/vllm/kvprune_legacy_save/models/qwen3_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..032ae3ed804f21db00040ee2e0f6e6d52919a283
--- /dev/null
+++ b/vllm/kvprune_legacy_save/models/qwen3_moe.py
@@ -0,0 +1,406 @@
+import os
+from glob import glob
+
+import torch
+import tqdm
+from safetensors import safe_open
+from torch import nn
+from transformers import Qwen3MoeConfig
+
+from vllm.kvprune.compression import (
+ CompressionMethod,
+ apply_postrope_compression,
+ apply_prerope_compression,
+)
+from vllm.kvprune.layers.activation import SiluAndMul
+from vllm.kvprune.layers.attention import Attention
+from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
+from vllm.kvprune.layers.layernorm import RMSNorm
+from vllm.kvprune.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.kvprune.layers.moe import (
+ MergedColumnParallelTritonFusedMoeLinear,
+ RowParallelTritonFusedMoeLinear,
+)
+from vllm.kvprune.layers.rotary_embedding import get_rope, rope_theta_from_hf_config
+from vllm.kvprune.triton_kernels.routing import routing
+from vllm.kvprune.utils.context import get_context
+from vllm.kvprune.utils.tp_utils import (
+ tensor_parallel_rank_for_sharding,
+ tensor_parallel_world_size_for_sharding,
+)
+
+
+class Qwen3MoeAttention(nn.Module):
+    """Qwen3-MoE self-attention with optional KV-compression hooks.
+
+    The forward pass runs a fused QKV projection, per-head RMSNorm on queries
+    and keys, rotary embedding, the attention layer and the output projection.
+    While ``context.is_prefill and context.do_compression`` holds, the
+    pre-/post-RoPE compression functions are called and their scores are
+    passed to the attention layer.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+        sliding_window: int | None = None,
+    ) -> None:
+        super().__init__()
+        world_size = tensor_parallel_world_size_for_sharding()
+
+        # Query and KV heads must both divide evenly across the sharding ranks.
+        assert num_heads % world_size == 0
+        assert num_kv_heads % world_size == 0
+        self.total_num_heads = num_heads
+        self.num_heads = num_heads // world_size
+        self.total_num_kv_heads = num_kv_heads
+        self.num_kv_heads = num_kv_heads // world_size
+
+        # A missing (or falsy) head_dim falls back to hidden_size / total heads.
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        # Stored only; nothing else in this class reads it.
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def _wo_weight_per_head(self) -> torch.Tensor:
+        """Return ``o_proj.weight`` transposed and viewed per head, in float32.
+
+        The result has shape ``(num_heads, head_dim, weight.shape[0])``.
+        """
+        weight = self.o_proj.weight
+        out_features, _ = weight.shape
+        per_head = (
+            weight.transpose(0, 1)
+            .contiguous()
+            .view(self.num_heads, self.head_dim, out_features)
+        )
+        return per_head.to(dtype=torch.float32)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        ctx = get_context()
+        fused_qkv = self.qkv_proj(hidden_states)
+        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # Per-head normalisation of queries and keys happens before RoPE.
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+
+        # The pre-RoPE hook receives v before it is reshaped per head.
+        if ctx.is_prefill and ctx.do_compression:
+            scores = apply_prerope_compression(q, k, v, ctx)
+        else:
+            scores = None
+
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        q, k = self.rotary_emb(positions, q, k)
+
+        if ctx.is_prefill and ctx.do_compression:
+            comp_ctx = ctx.compression_context
+            uses_criticaladakv = (
+                comp_ctx is not None
+                and comp_ctx.compression_method == CompressionMethod.CRITICALADAKV
+            )
+            if uses_criticaladakv:
+                # Key step: inject wo_weight into the compression context.
+                comp_ctx.wo_weight = self._wo_weight_per_head()
+
+            scores = apply_postrope_compression(q, k, v, scores, ctx)
+
+        attn_out = self.attn(q, k, v, scores)
+        return self.o_proj(attn_out.flatten(1, -1))
+
+
+class Qwen3MoeMLP(nn.Module):
+    """Dense gated feed-forward block: merged gate/up projection, SiLU-and-mul, down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        # Gate and up projections live in one merged layer with two equal slices.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size, intermediate_size],
+            bias=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+        )
+        # Only the "silu" activation is supported.
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
+        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
+
+
+class Qwen3MoeTritonSparseMoeBlock(nn.Module):
+    """Sparse mixture-of-experts feed-forward block built on the fused Triton MoE layers.
+
+    A replicated linear gate produces per-expert logits, ``routing`` turns them
+    into routing data plus gather/scatter indices, and the fused expert
+    projections consume those.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        num_experts_per_tok: int,
+        norm_topk_prob: bool,
+        hidden_act: str,
+    ) -> None:
+        # NOTE: hidden_act is accepted but never read; SiluAndMul is always used.
+        super().__init__()
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        # Stored only; forward() does not consult it.
+        self.norm_topk_prob = norm_topk_prob
+        self.hidden_size = hidden_size
+        self.moe_intermediate_size = intermediate_size
+
+        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False)
+        self.gate_up_proj = MergedColumnParallelTritonFusedMoeLinear(
+            hidden_size, [intermediate_size, intermediate_size], num_experts
+        )
+        self.down_proj = RowParallelTritonFusedMoeLinear(
+            intermediate_size, hidden_size, num_experts
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Route tokens to experts and return the combined expert output.
+
+        An input with zero elements is returned unchanged.
+        """
+        if hidden_states.numel() == 0:
+            return hidden_states
+
+        router_logits = self.gate(hidden_states)
+        routing_data, gather_idx, scatter_idx = routing(
+            router_logits,
+            self.num_experts_per_tok,
+            simulated_ep=1,  # single device, replicated experts
+        )
+        expert_in = self.gate_up_proj(
+            hidden_states, routing_data=routing_data, gather_indx=gather_idx
+        )
+        activated = self.act_fn(expert_in)
+        return self.down_proj(
+            activated,
+            routing_data=routing_data,
+            scatter_indx=scatter_idx,
+            gammas=routing_data.gate_scal,
+        )
+
+
+class Qwen3MoeBlock(Qwen3MoeTritonSparseMoeBlock):
+    """Alias subclass of ``Qwen3MoeTritonSparseMoeBlock``; adds no behaviour."""
+
+
+class Qwen3MoeRMSNorm(RMSNorm):
+    """Alias subclass of ``RMSNorm``; adds no behaviour."""
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    """One Qwen3-MoE layer: pre-norm attention and pre-norm MLP, each with a residual add.
+
+    The MLP is a sparse MoE block when the layer is not listed in
+    ``config.mlp_only_layers``, ``config.num_experts`` is positive and
+    ``layer_idx + 1`` is a multiple of ``config.decoder_sparse_step``;
+    otherwise it is a dense ``Qwen3MoeMLP``.
+    """
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        layer_idx: int,
+    ) -> None:
+        super().__init__()
+        # Use the explicit head_dim when the config has one, else derive it.
+        configured_head_dim = getattr(config, "head_dim", None)
+        if configured_head_dim is not None:
+            head_dim = configured_head_dim
+        else:
+            head_dim = config.hidden_size // config.num_attention_heads
+
+        rope_theta = rope_theta_from_hf_config(config)
+
+        # Only a tuple-valued rope_scaling is forwarded; any other value is
+        # replaced by None.
+        rope_scaling_tuple: tuple | None = None
+        raw_rope_scaling = getattr(config, "rope_scaling", None)
+        if isinstance(raw_rope_scaling, tuple):
+            rope_scaling_tuple = raw_rope_scaling
+
+        self.self_attn = Qwen3MoeAttention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, "attention_bias", False),
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling_tuple,
+            sliding_window=config.sliding_window,
+        )
+
+        # Evaluation order matters: the modulo is only reached for layers that
+        # are not dense-only and when experts exist.
+        use_sparse_moe = (
+            layer_idx not in config.mlp_only_layers
+            and config.num_experts > 0
+            and (layer_idx + 1) % config.decoder_sparse_step == 0
+        )
+        if use_sparse_moe:
+            self.mlp = Qwen3MoeBlock(
+                num_experts=config.num_experts,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                num_experts_per_tok=config.num_experts_per_tok,
+                norm_topk_prob=config.norm_topk_prob,
+                hidden_act=config.hidden_act,
+            )
+        else:
+            self.mlp = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+            )
+
+        self.input_layernorm = Qwen3MoeRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply attention then the MLP, adding the block input back after each."""
+        # Attention sub-block.
+        attn_input = self.input_layernorm(hidden_states)
+        attn_out = self.self_attn(positions, attn_input)
+        hidden_states = hidden_states + attn_out
+
+        # Feed-forward sub-block.
+        mlp_input = self.post_attention_layernorm(hidden_states)
+        mlp_out = self.mlp(mlp_input)
+        return hidden_states + mlp_out
+
+
+class Qwen3MoeModel(nn.Module):
+    """Token embedding, ``config.num_hidden_layers`` MoE decoder layers and a final norm."""
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        # Each layer receives its index so it can pick a dense or sparse MLP.
+        decoder_layers = []
+        for layer_idx in range(config.num_hidden_layers):
+            decoder_layers.append(Qwen3MoeDecoderLayer(config, layer_idx))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embed ``input_ids``, run all layers and return the normalised hidden states."""
+        hidden_states = self.embed_tokens(input_ids)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, position_ids)
+        return self.norm(hidden_states)
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+    """Qwen3-MoE decoder plus LM head, with a safetensors checkpoint loader."""
+
+    # Checkpoint sub-name -> (fused parameter sub-name, shard id passed to the
+    # parameter's weight_loader).
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+    ) -> None:
+        super().__init__()
+        self.model = Qwen3MoeModel(config)
+        self.num_experts = config.num_experts
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            org_vocab_size=config.vocab_size,
+        )
+        if config.tie_word_embeddings:
+            # Tied weights: head and input embedding share one storage.
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Return the decoder's hidden states; logits come from ``compute_logits``."""
+        return self.model(input_ids, position_ids)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply the LM head to ``hidden_states``."""
+        return self.lm_head(hidden_states)
+
+    def load_model(
+        self,
+        path: str,
+        *,
+        use_tqdm: bool = False,
+    ) -> None:
+        """Load every ``*.safetensors`` shard found directly under ``path``.
+
+        Tensors are opened on ``cuda:<device>``, where ``<device>`` is the
+        current CUDA device when CUDA is available and the sharding rank
+        otherwise. Names containing ``mlp.experts`` have their
+        ``.experts.<idx>.`` segment removed, and the expert index is passed to
+        the parameter's ``weight_loader``. Names containing a key of
+        ``packed_modules_mapping`` are redirected to the fused parameter and
+        loaded with the mapped shard id. Everything else is loaded into the
+        parameter of the same name.
+
+        Args:
+            path: Directory that holds the checkpoint shards.
+            use_tqdm: Show a progress bar over the shard files.
+
+        Raises:
+            FileNotFoundError: If ``path`` contains no ``*.safetensors`` file.
+                Without this check the model would silently keep its
+                uninitialised weights.
+        """
+        rank = tensor_parallel_rank_for_sharding()
+        device = torch.cuda.current_device() if torch.cuda.is_available() else rank
+
+        # Sorted so shards are visited in a reproducible order.
+        shard_paths = sorted(glob(os.path.join(path, "*.safetensors")))
+        if not shard_paths:
+            raise FileNotFoundError(f"No *.safetensors shards found in {path!r}")
+
+        shard_iter = (
+            tqdm.tqdm(shard_paths, desc="Loading model") if use_tqdm else shard_paths
+        )
+        for shard_path in shard_iter:
+            with safe_open(shard_path, "pt", f"cuda:{device}") as shard:
+                for weight_name in shard.keys():
+                    weight_tensor = shard.get_tensor(weight_name)
+                    is_expert = "mlp.experts" in weight_name
+                    expert_idx = None
+
+                    if is_expert:
+                        # "<mlp>.experts.<idx>.<proj...>" -> "<mlp>.<proj...>".
+                        # Split off only the leading index so a later "<idx>."
+                        # inside the projection name is left alone.
+                        mlp_module_name, expert_module_name = weight_name.split(
+                            ".experts."
+                        )
+                        expert_idx_text, proj_name = expert_module_name.split(".", 1)
+                        expert_idx = int(expert_idx_text)
+                        weight_name = f"{mlp_module_name}.{proj_name}"
+
+                    # First mapping key that occurs in the checkpoint name wins.
+                    packed_key = next(
+                        (k for k in self.packed_modules_mapping if k in weight_name),
+                        None,
+                    )
+                    if packed_key is not None:
+                        fused_name, shard_id = self.packed_modules_mapping[packed_key]
+                        param = self.get_parameter(
+                            weight_name.replace(packed_key, fused_name)
+                        )
+                        if is_expert:
+                            param.weight_loader(
+                                param, weight_tensor, expert_idx, shard_id
+                            )
+                        else:
+                            param.weight_loader(param, weight_tensor, shard_id)
+                        continue
+
+                    # Un-fused parameter: custom loader if any, else plain copy.
+                    param = self.get_parameter(weight_name)
+                    weight_loader = getattr(
+                        param,
+                        "weight_loader",
+                        lambda p, lw: p.data.copy_(lw, non_blocking=True),
+                    )
+                    if is_expert:
+                        weight_loader(param, weight_tensor, expert_idx)
+                    else:
+                        weight_loader(param, weight_tensor)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d446c6d0a024357f4659d254e327eb4a00e23a1
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Triton kernel utilities (matmul_ogs, MoE, topk, …) plus KV-facing entrypoints.
+
+For KV pruning attention/store, see also ``vllm.kvprune.attention`` and
+``vllm.kvprune.kv_cache``.
+"""
+
+from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
+from vllm.kvprune.kv_cache.store_kv_cache import (
+ decode_store_kv,
+ prefill_store_all_kv,
+ prefill_store_topk_kv,
+)
+
+__all__ = [
+ "causal_sparse_varlen_with_cache",
+ "decode_store_kv",
+ "prefill_store_all_kv",
+ "prefill_store_topk_kv",
+]
diff --git a/vllm/kvprune_legacy_save/triton_kernels/compaction.py b/vllm/kvprune_legacy_save/triton_kernels/compaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d471befd0d710f96f01882fa9e8b8a84059bd9
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/compaction.py
@@ -0,0 +1,76 @@
+import torch
+from .compaction_details._masked_compaction import _masked_compaction
+from .tensor import Bitmatrix
+
+
+def compaction(yv, yi, bitmask, sentinel=-1):
+    """
+    Return compacted copies of *yv* and *yi* based on a per-row bitmask.
+
+    Launches the ``_masked_compaction`` kernel with one program per row of
+    *yi*. Entries whose index has its bit set in *bitmask* are kept; all other
+    positions of the outputs hold *sentinel*.
+
+    Parameters
+    ----------
+    yv : torch.Tensor, shape (B, K)
+        Values tensor.
+    yi : torch.Tensor, shape (B, K)
+        Integer indices associated with *yv*; each one selects a bit of the
+        row's bitmask.
+    bitmask : torch.Tensor or Bitmatrix
+        Per-row mask of active indices. A ``Bitmatrix`` is unwrapped to its
+        ``storage.data``. Both ``stride(0)`` and ``stride(1)`` are read, so
+        the tensor must be at least 2-D.
+    sentinel : int, default -1
+        Value written into dropped positions of the returned tensors.
+
+    Returns
+    -------
+    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor]
+        New tensors allocated with ``torch.empty_like`` from the inputs.
+
+    """
+    num_rows, num_cols = yi.shape
+    if isinstance(bitmask, Bitmatrix):
+        bitmask = bitmask.storage.data
+
+    out_values = torch.empty_like(yv)
+    out_indices = torch.empty_like(yi)
+
+    launch_grid = (num_rows,)
+    _masked_compaction[launch_grid](
+        # inputs
+        yv,
+        yi,
+        bitmask,
+        bitmask.stride(0),
+        bitmask.stride(1),
+        # outputs
+        out_values,
+        out_indices,
+        # fill value for dropped slots
+        sentinel,
+        # compile-time constants
+        K=num_cols,
+    )
+    return out_values, out_indices
+
+
+def compaction_torch(
+    yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1
+):
+    """
+    Pure-PyTorch reference implementation of `masked_compact`.
+
+    Kept entries (those whose index bit is set in *bitmask*) are moved to the
+    front of each row in their original order; the remaining slots are filled
+    with *sentinel*. Returns ``(values, indices)`` as new tensors.
+    """
+    # Unpacking also enforces that yi is 2-D.
+    _num_rows, _num_cols = yi.shape
+
+    # One single-bit weight per bit position, in the mask's own dtype.
+    bit_weights = 1 << torch.arange(32, device=yi.device, dtype=bitmask.dtype)
+    # Expand every mask element into 32 booleans, then merge the last two dims.
+    bit_is_set = torch.bitwise_and(bitmask.unsqueeze(-1), bit_weights).ne(0)
+    active = bit_is_set.flatten(start_dim=-2)
+
+    # Look up, for each index in yi, whether its bit is active.
+    keep = torch.gather(active, 1, yi.long())
+
+    # Stable sort on "dropped" (0 = kept, 1 = dropped) moves kept entries
+    # forward while preserving their relative order.
+    dropped = ~keep
+    perm = torch.argsort(dropped.to(torch.int), dim=1, stable=True)
+
+    out_indices = torch.gather(yi, 1, perm)
+    out_values = torch.gather(yv, 1, perm)
+    dropped_sorted = ~torch.gather(keep, 1, perm)
+
+    # Overwrite the dropped tail with the sentinel.
+    out_indices.masked_fill_(dropped_sorted, sentinel)
+    out_values.masked_fill_(dropped_sorted, sentinel)
+    return out_values, out_indices
diff --git a/vllm/kvprune_legacy_save/triton_kernels/compaction_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/compaction_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/compaction_details/_masked_compaction.py b/vllm/kvprune_legacy_save/triton_kernels/compaction_details/_masked_compaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..58fe2412cf19386dbbe73bea1a5daf75d464ffb2
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/compaction_details/_masked_compaction.py
@@ -0,0 +1,22 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _masked_compaction(
+    Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr
+):
+    # Each program handles one row of K (value, index) pairs; rows of Yv, Yi
+    # and the outputs are addressed as row * K + column.
+    row = tl.program_id(0)
+    cols = tl.arange(0, K)
+    vals = tl.load(Yv + row * K + cols)
+    idxs = tl.load(Yi + row * K + cols)
+    # Index i is looked up in mask element i // 32 of this row, at bit i % 32.
+    word = idxs // 32
+    bit = idxs % 32
+    is_set = (tl.load(BitMask + row * stride_bm + word * stride_bn) >> bit) & 1
+    # Exclusive prefix sum: how many kept entries lie strictly to the left.
+    kept_before = tl.cumsum(is_set, 0) - is_set
+    keep = is_set.to(tl.int1)
+    # Kept entries land at slot kept_before. A dropped entry at column c lands
+    # at kept_before + K - 1 - c, i.e. K - 1 - (dropped entries before c), so
+    # the dropped ones fill the tail of the row from right to left.
+    tail_offset = tl.where(keep, 0, K - 1 - cols)
+    dest = kept_before + tail_offset
+    vals = tl.where(keep, vals, sentinel)
+    idxs = tl.where(keep, idxs, sentinel)
+    tl.store(RetYv + row * K + dest, vals)
+    tl.store(RetYi + row * K + dest, idxs)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a681e67d8c92fe27a4df01eb23ac243451cb4e
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs.py
@@ -0,0 +1,609 @@
+# isort: off
+# fmt: off
+from dataclasses import dataclass
+import itertools
+import sys
+import torch
+import triton
+from enum import Enum, auto
+import math
+# utilities
+from vllm.kvprune.triton_kernels import target_info
+from vllm.kvprune.triton_kernels.numerics import InFlexData, OutFlexData
+from vllm.kvprune.triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from vllm.kvprune.triton_kernels.target_info import is_cuda
+# details
+from .matmul_ogs_details._matmul_ogs import _matmul_ogs
+from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
+from .matmul_ogs_details._reduce_grouped import _reduce_grouped
+from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
+from .specialize import specialize
+from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
+
+
+@dataclass(frozen=True)
+class FnSpecs:
+ name: str
+ fn: "triton.runtime.jit.JITFunction"
+ fn_arg_names: tuple[str]
+ fn_arg_do_not_specialize: tuple[str] = tuple()
+
+ @staticmethod
+ def default():
+ return FnSpecs("dflt", None, tuple())
+
+
+@dataclass(frozen=True)
+class FusedActivation:
+ specs: FnSpecs = FnSpecs.default()
+ fn_args: tuple[object] = tuple()
+ reduction_n: int = 1
+
+
+@dataclass(frozen=True)
+class Epilogue:
+ specs: FnSpecs = FnSpecs.default()
+ fn_arg_values_matmul: tuple[object] = tuple()
+ fn_arg_values_finalize: tuple[object] = tuple()
+ effective_itemsize: float = None
+
+class FnName(Enum):
+ QUANTIZE_MXFP8 = auto()
+
+
+EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated
+
+_kernels = dict()
+
+
+def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
+ global _kernels
+ key = (fused_activation.name, epilogue.name)
+ if key in _kernels:
+ return _kernels[key]
+ spec_constants = {
+ "ACTIVATION_FN": fused_activation.fn,
+ "EPILOGUE_FN": epilogue.fn,
+ }
+ spec_tuples = {
+ "activation_fn_args": fused_activation.fn_arg_names,
+ "epilogue_fn_args": epilogue.fn_arg_names,
+ }
+ do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
+ import types
+
+ module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
+ sys.modules[module.__name__] = module
+ module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
+ do_not_specialize=do_not_specialize)
+ module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
+ do_not_specialize=do_not_specialize)
+ module._reduce_grouped = specialize(_reduce_grouped, module, spec_constants, spec_tuples,
+ do_not_specialize=do_not_specialize)
+ _kernels[key] = module
+ return module
+
+
+# -----------------------------------------------------------------------------
+# Matrix Multiplication + Outer Gather/Scatter
+# -----------------------------------------------------------------------------
+
+
+def can_overflow_int32(tensor: torch.Tensor):
+ max_int32 = (1 << 31) - 1
+ offset = 0
+ for i in range(tensor.ndim):
+ offset += (tensor.shape[i] - 1) * tensor.stride(i)
+ return offset > max_int32
+
+
+def should_upcast_indices(*args):
+ return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
+
+
+# ---------------------
+# Numerics
+# ---------------------
+
+# fmt: off
+
+@dataclass(frozen=True)
+class FlexCtx:
+ """Flexpoint metadata for the two matmul operands and the output."""
+ # Applied to the activations `x` (left-hand side) in `matmul_ogs`.
+ lhs_data: InFlexData = InFlexData()
+ # Applied to the weights `w` (right-hand side) in `matmul_ogs`.
+ rhs_data: InFlexData = InFlexData()
+ # Applied to the output tensor in `matmul_ogs` / `reduce_grouped`.
+ out_data: OutFlexData = OutFlexData()
+
+@dataclass
+class PrecisionConfig:
+ """Numerics options for `matmul_ogs`.
+
+ The instance is mutable: `matmul_ogs` overwrites `out_scale` with the final
+ MX output scale when one is produced.
+ """
+ # Forwarded unchanged to the matmul kernel.
+ max_num_imprecise_acc: int | None = None
+ # Forwarded unchanged to the matmul kernel.
+ allow_tf32: bool = True
+ # Flexpoint metadata for x, w and the output.
+ flex_ctx: FlexCtx = FlexCtx()
+ # NOTE(review): not read by the code visible in this file.
+ acc_scale: float = 1.0
+ # Forwarded to both the matmul and the grouped-reduction kernels.
+ flexpoint_saturate_inf: bool = False
+ # NOTE(review): not read by the code visible in this file.
+ report_quantization_err_fn: callable = None
+ # MX scales for the activations; when set, `x` must be row-major.
+ act_scale: Tensor | None = None
+ # MX scales for the weights; their presence marks `w` as MX-quantized.
+ weight_scale: Tensor| None = None
+ # MX scales for the output; may be replaced by `matmul_ogs`.
+ out_scale: Tensor | None = None
+ # Output dtype; `matmul_ogs` falls back to `x.dtype` when this is None.
+ out_dtype: torch.dtype | None = None
+ # NOTE(review): not read by the code visible in this file.
+ enforce_bitwise_invariance: bool = False
+
+
+# TODO: merge in opt_flags
+def get_swap_xw(precision_config, opt_flags):
+ if target_info.cuda_capability_geq(10, 0):
+ return precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
+ return False
+
+# ---------------------
+# Allocation
+# ---------------------
+
+@dataclass
+class MatmulAllocation:
+ """Plan of the buffers a `matmul_ogs` call needs (shapes/dtypes only, no memory)."""
+ # Device on which the buffers are created; `init_allocation` stores `x.device` here.
+ device: str
+ # (shape, dtype) of the final output tensor.
+ output: tuple[tuple[int], torch.dtype]
+ # Scratch buffer name -> (shape, dtype); `init_allocation` uses the names
+ # "matmul" and "mx_out_scale".
+ scratchpads: dict[str, tuple]
+
+def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags):
+ # ---- output ------
+ N = w.shape[-1]
+ # by default - M is number of rows in the activations
+ M = x.shape[-2]
+ # if the activations are gathered, then M is number of gather indices
+ if gather_indx is not None:
+ M = gather_indx.src_indx.shape[0]
+ # final output
+ if routing_data.n_expts_act == 1 or scatter_indx is None:
+ y_rows = M
+ else:
+ Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
+ y_rows = Mc
+ batch_dim = x.shape[0] if x.ndim == 3 else 1
+ out_shape = (batch_dim, y_rows, N // fused_activation.reduction_n)
+ out_dtype = precision_config.out_dtype or x.dtype
+ output = (out_shape, out_dtype)
+ # ---- scratchpad -----#
+ scratchpad = dict()
+ if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
+ scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
+ scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype)
+ if "matmul" in scratchpad and precision_config.out_scale is not None:
+ scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
+ return MatmulAllocation(x.device, output, scratchpad)
+
+def apply_allocation(allocation: MatmulAllocation, output):
+ ret = dict()
+ if output is None:
+ output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
+ else:
+ assert output.shape == allocation.output[0]
+ ret["output"] = output[None, :, :]
+ ret["scratchpad"] = {
+ k: torch.empty(v[0], device=allocation.device, dtype=v[1])
+ for k, v in allocation.scratchpads.items()
+ }
+ return ret
+
+# -----------------------------------------------------------------------------
+# Canonicalize
+# -----------------------------------------------------------------------------
+# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
+# we can canonicalize storages to make the implementation more uniform
+
+def _canonicalize_storage(storage, out_ndim, flex_data):
+ assert out_ndim >= storage.data.ndim
+ # Need to use as_strided instead of view because for a tensor with
+ # shape[-2] == 1 can have ambuiguity related to col-wise. Fo example,
+ # > t = torch.randn(2, 5, 1).mT
+ # > t_view = t.view(t.shape)
+ # > t.stride(), t_view.stride()
+ # ((5, 1, 1), (5, 5, 1))
+ # Our check t_view is col-wise fails since t_view.stride(-2) != 1
+ # This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
+ new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
+ new_storage_view = storage.data.view(new_storage_shape)
+ new_storage_stride = [new_storage_view.stride(0)] * (out_ndim - storage.data.ndim) + list(storage.data.stride())
+ new_storage_data = storage.data.as_strided(new_storage_shape, new_storage_stride)
+ if flex_data is not None:
+ new_storage_data = flex_data.reinterpret(new_storage_data)
+ return Storage(new_storage_data, storage.layout)
+
+#
+
+def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor,
+ fused_activation, epilogue,
+ x_flex: InFlexData | None = None,
+ out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None,
+ out_dtype: bool = None, flexpoint_saturate_inf: bool = False):
+ """
+ In-place grouped row reduction.
+
+ Arguments
+ - x: Tensor[AnyFloat] of shape [(num_groups * K), N]
+ - indx: Tensor[Int] of shape [num_groups, K]
+
+ Description
+ For each group g in [0, num_groups), this routine sums the K rows of `x`
+ specified by `indx[g, :]` and overwrites the row corresponding to the first
+ valid (non-negative) index with the per-group sum. Accumulation is performed
+ in float32 for numerical stability, and the result is written back in the
+ dtype of `x`.
+
+ Behavior and edge cases
+ - Invalid (-1) entries are skipped during accumulation and do not generate
+ memory traffic. If a group has no valid entries, nothing is written for
+ that group.
+ - Reduction is performed tile-by-tile along the N dimension within a single
+ kernel launch (persistent along N) to minimize launch overhead.
+
+ Performance notes
+ - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x),
+ plus index reads. With no invalid entries, this becomes (K + 1) reads/writes
+ of length N per group.
+
+ Returns
+ - The input tensor `x` (modified in place).
+ """
+ if indx is None and x.shape[0] == 1:
+ return x.squeeze(0), None
+ if indx is not None:
+ num_groups = indx.shape[0]
+ else:
+ num_groups = x.shape[-2]
+ if x_flex is None:
+ x_flex = InFlexData()
+ if out_flex is None:
+ out_flex = OutFlexData()
+ K = 1 if indx is None else indx.shape[1]
+ out_dtype = x.dtype if out_dtype is None else out_dtype
+ assert x.shape[-1] % fused_activation.reduction_n == 0
+ BLOCK_N = 512
+ # Resolve scalar flex scales (may be None)
+ x_expected_scale = None if x_flex is None else x_flex.scale
+ out_expected_scale = None if out_flex is None else out_flex.expected_scale
+ out_actual_scale = None if out_flex is None else out_flex.actual_scale
+ out_checksum_scale = None if out_flex is None else out_flex.checksum_scale
+ # Resolve MXFP output scale row stride
+ stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0)
+ stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1)
+ stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0)
+ kernels = get_kernels(epilogue.specs, fused_activation.specs)
+ kernels._reduce_grouped[(num_groups, )](
+ x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), #
+ x_expected_scale, # scalar input scale
+ out_flex.reinterpret(out), out.stride(1), out.stride(2), #
+ out_expected_scale, out_actual_scale, out_checksum_scale, indx, #
+ x.shape[0], x.shape[-1], #
+ x_mx_scale, stride_mxb, stride_mxs, #
+ out_mx_scale, stride_omxs, #
+ *fused_activation.fn_args, fused_activation.reduction_n,
+ *epilogue.fn_arg_values_finalize,
+ HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None,
+ FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf, #
+ BLOCK_N=BLOCK_N, K=K, #
+ num_warps=1, #
+ )
+ return out, out_mx_scale
+
+# -----------------------------------------------------------------------------
+# Triton Implementation
+# -----------------------------------------------------------------------------
+
+def matmul_ogs_set_idle_sms(num_idle_sms):
+ """
+ persistent kernels will leave `num_idle_sms` idle
+ """
+ update_opt_flags_constraints({"idle_sms": num_idle_sms})
+
+def matmul_ogs(x, w, bias,
+ routing_data: RoutingData | None = None,
+ gather_indx: GatherIndx | None = None,
+ scatter_indx: ScatterIndx | None = None,
+ precision_config: PrecisionConfig | None = None,
+ betas: torch.Tensor | None = None,
+ gammas: torch.Tensor | None = None,
+ out_alpha: float | None = None,
+ y: torch.Tensor | None = None,
+ fused_activation: FusedActivation | None = None,
+ epilogue: Epilogue | None = None,
+ ):
+ """
+ Matrix multiplication with optional outer gather/scatter of rows.
+
+ Y[:, :] = 0.
+ for e in num_experts:
+ Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
+
+ Overview of the steps performed here:
+ 1. canonicalize optional arguments and wrap `x`, `w` and scales as `Tensor`;
+ 2. pick optimization flags (`make_opt_flags`);
+ 3. plan and allocate the output and scratch buffers;
+ 4. launch the persistent or non-persistent matmul kernel;
+ 5. run `reduce_grouped` to produce the final output.
+
+ Batched mode (`x.ndim == 3`) does not support gather, scatter or routing.
+ If `y` is given it is used as the output buffer and must have the planned
+ output shape. Side effect: when a final MX output scale is produced,
+ `precision_config.out_scale` is overwritten with it.
+ """
+ is_input_batched = x.ndim == 3
+ if is_input_batched:
+ assert gather_indx is None, "gather not supported in batched mode"
+ assert scatter_indx is None, "scatter not supported in batched mode"
+ assert routing_data is None, "routing not supported in batched mode"
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
+ # canonicalize inputs
+ if precision_config is None:
+ precision_config = PrecisionConfig()
+ if fused_activation is None:
+ fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
+ if epilogue is None:
+ epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
+ if routing_data is None:
+ routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
+ # unpack scales
+ w_scale = precision_config.weight_scale
+ w_has_mx = w_scale is not None
+ is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
+ if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
+ if not isinstance(w, Tensor):
+ # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
+ dtype = FP4 if w.dtype == torch.uint8 else w.dtype
+ w = wrap_torch_tensor(w, dtype=dtype)
+ if w_scale is not None and not isinstance(w_scale, Tensor):
+ w_scale = Tensor(w_scale)
+ if w_scale is not None:
+ # weight scales are always handled as raw uint8 from here on
+ w_scale.storage.data = w_scale.data.view(torch.uint8)
+ w_scale.dtype = torch.uint8
+ x_scale = precision_config.act_scale
+ x_has_mx = x_scale is not None
+ if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
+ if x_scale is not None and not isinstance(x_scale, Tensor):
+ x_scale = Tensor(x_scale)
+ if not isinstance(x, Tensor):
+ x = Tensor(x, dtype=x.dtype)
+ # determine shapes
+ has_gather = gather_indx is not None
+ has_scatter = scatter_indx is not None
+ is_ragged = routing_data.expt_hist is not None
+ M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
+ batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
+ K, N = w.shape[-2:]
+ assert K == x.shape[-1]
+ if x.ndim == 3 and w.ndim == 3:
+ assert x.shape[0] == w.shape[0]
+ # compute optimization flags
+ out_dtype = precision_config.out_dtype or x.dtype
+ can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
+ w.numel() > 0 and w.storage.is_tma_compliant() and \
+ (w_scale is None or w_scale.storage.is_tma_compliant())
+ # hopper w/ mxfp4 doesn't support TMA
+ # NOTE(review): this queries the current CUDA device, not necessarily x.device
+ can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+ can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
+ opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
+ M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+ )
+ if not can_use_fused_scatter and opt_flags.fused_scatter:
+ raise InapplicableConstraint("Fused scatter is not supported")
+ if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
+ raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
+ if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
+ raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
+ # fused activation
+ # When the matmul writes to a scratchpad (split-K or unfused scatter), the
+ # activation is applied by the reduction kernel instead of the matmul kernel.
+ matmul_fused_activation = fused_activation
+ reduce_fused_activation = FusedActivation()
+ if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
+ matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation
+ # allocate output/scratchpad memory
+ allocation = init_allocation(x, w, precision_config, fused_activation,
+ routing_data, gather_indx, scatter_indx, opt_flags)
+ memory = apply_allocation(allocation, y)
+ # early exit
+ # (nothing to compute; the returned output buffer is left uninitialized
+ # unless the caller passed `y`)
+ if batch_size * M * N == 0:
+ ret = memory["output"].squeeze(0)
+ if not is_input_batched:
+ ret = ret.squeeze(0)
+ return ret
+ # TMA descriptors require a global memory allocation
+ if opt_flags.is_persistent:
+ triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
+ # Intermediate tensors and postprocess kernels for each situation
+ has_scratchpad = "matmul" in memory["scratchpad"]
+ # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
+ out_matmul = memory["scratchpad"].get("matmul", memory["output"])
+ out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data
+ # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
+ out_matmul_scale = precision_config.out_scale
+ if out_matmul_scale is not None:
+ out_matmul_scale = out_matmul_scale.data.view(torch.uint8)
+ if has_scratchpad and "mx_out_scale" in memory["scratchpad"]:
+ out_matmul_scale = memory["scratchpad"]["mx_out_scale"]
+ out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1
+ # matrix multiplication
+ flex = precision_config.flex_ctx
+ bias_stride = None if bias is None else bias.stride(0)
+ num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
+ # moe metadata
+ expt_data = routing_data.expt_data
+ block_m = opt_flags.block_m
+ expt_hist = None if expt_data is None else expt_data.hist
+ expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
+ expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
+ expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
+ # spmd grid
+ grid_m = triton.cdiv(M, opt_flags.block_m)
+ if expt_block_pid_map is not None:
+ grid_m = routing_data.n_blocks(M, opt_flags.block_m)
+ grid_n = triton.cdiv(N, opt_flags.block_n)
+ max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
+ # persistent kernels launch at most one program per non-idle SM
+ grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
+ # canonicalize storage
+ has_gather_tma = has_gather and target_info.has_tma_gather()
+ has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
+ y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
+ x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
+ w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
+ y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
+ # create tma descriptor for x
+ x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
+ x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
+ x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
+ x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
+ # create tma descriptor for y
+ y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter)
+ block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n
+ y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
+ y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
+ y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
+ # create tma descriptor for w
+ w_has_tma = opt_flags.is_persistent
+ w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
+ # create tma descriptor for w_scale
+ # NOTE(review): the next assignment is overwritten two statements below
+ w_scale_tensor_or_tma = w_scale
+ w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
+ w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
+ # canonicalize strides
+ # (each stride tuple is left-padded with zeros to a fixed length; absent
+ # tensors contribute a tuple of Nones)
+ x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
+ x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
+ x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
+ w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
+ w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
+ out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
+ out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
+ # launch kernel
+ kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
+ # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
+ # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
+ # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
+ # w_transpose = w_storage.data.stride()[-1] != 1
+ w_transpose = w_storage.data.stride()[-2] == 1
+ # The positional arguments below must stay in the order of the kernel signature.
+ (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
+ y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
+ *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
+ *out_matmul_scale_strides[-3:],
+ x_tensor_or_tma, x_storage.data, *x_strides,
+ flex.lhs_data.scale,
+ None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
+ w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
+ flex.rhs_data.scale,
+ w_scale_tensor_or_tma, *w_scale_strides,
+ bias, bias_stride,
+ x.shape[-2],
+ x.shape[-2] if routing_data.expt_hist is None else None,
+ N, K,
+ betas, gammas,
+ None if gather_indx is None else gather_indx.src_indx,
+ None if scatter_indx is None else scatter_indx.src_indx,
+ num_indx,
+ None if not opt_flags.fused_scatter else scatter_indx.dst_indx,
+ None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0],
+ expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
+ batch_size, grid_m, grid_n,
+ out_alpha,
+ *matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n,
+ *epilogue.fn_arg_values_matmul,
+ routing_data.n_expts_tot, routing_data.n_expts_act,
+ precision_config.max_num_imprecise_acc,
+ precision_config.allow_tf32,
+ precision_config.flexpoint_saturate_inf,
+ flex.rhs_data.is_per_batch,
+ opt_flags.block_m,
+ opt_flags.block_n,
+ opt_flags.block_k,
+ opt_flags.group_m,
+ XCD_SWIZZLE=opt_flags.xcd_swizzle,
+ SWIZZLE_MX_VALUE=w.storage.layout.name,
+ SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
+ EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
+ SPLIT_K=opt_flags.split_k,
+ EVEN_K=K % opt_flags.block_k == 0,
+ W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
+ num_warps=opt_flags.num_warps,
+ num_stages=opt_flags.num_stages,
+ arch=opt_flags.arch,
+ UPCAST_INDICES=should_upcast_indices(x, w, out_matmul),
+ X_TMA_MODE=x_tma_mode,
+ Y_TMA_MODE=y_tma_mode,
+ SWAP_XW=get_swap_xw(precision_config, opt_flags),
+ IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name,
+ NUM_SMS = grid if opt_flags.is_persistent else 0,
+ **opt_flags.target_kernel_kwargs)
+ # Build grouped reduction inputs in a uniform way
+ # (group_indx is None when there is no scatter or the scatter was fused;
+ # reduce_grouped then returns early if out_matmul has a single leading slice)
+ group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act)
+ out_final, out_final_mx_scale = reduce_grouped(
+ out_matmul,
+ group_indx,
+ memory["output"].squeeze(0),
+ precision_config.out_scale,
+ reduce_fused_activation,
+ epilogue,
+ x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale),
+ out_flex=precision_config.flex_ctx.out_data,
+ x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None,
+ out_dtype=memory["output"].dtype,
+ flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
+ )
+ if not is_input_batched:
+ out_final = out_final.squeeze(0)
+ # NOTE: this mutates the caller's precision_config
+ if out_final_mx_scale is not None:
+ precision_config.out_scale = out_final_mx_scale
+ return out_final
+
+# -----------------------------------------------------------------------------
+# Reference Implementation
+# -----------------------------------------------------------------------------
+
+def matmul_ogs_torch(x, w, bias,
+ routing_data: RoutingData = None,
+ gather_indx: GatherIndx = None,
+ scatter_indx: ScatterIndx = None,
+ precision_config: PrecisionConfig = None,
+ betas = None,
+ gammas = None,
+ round_x = None, round_y = None,
+ ):
+ is_input_batched = x.ndim == 3
+ assert x.dtype.itemsize > 1
+ assert w.dtype.itemsize > 1
+ if is_input_batched:
+ assert gather_indx is None, "gather not supported in batched mode"
+ assert scatter_indx is None, "scatter not supported in batched mode"
+ assert routing_data is None, "routing not supported in batched mode"
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
+ if round_x is None:
+ round_x = lambda x, idx: x
+ if round_y is None:
+ round_y = lambda x: x
+ if bias is not None and bias.ndim == 1:
+ bias = bias.view(1, *bias.shape)
+ if w.ndim == 2:
+ w = w.view(1, *w.shape)
+ if x.ndim == 2:
+ x = x.view(1, *x.shape)
+ if routing_data is None:
+ routing_data = RoutingData(None, None, w.shape[0], 1)
+ n_expts_act = routing_data.n_expts_act
+ # memory offsets
+ if routing_data.n_expts_tot > 1 and not is_input_batched:
+ sizes = routing_data.expt_hist
+ off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
+ off[1:] = torch.cumsum(sizes, 0)
+ offs = list(itertools.pairwise(off))
+ else:
+ offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
+ # compute
+ n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
+ y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
+ for i, (lo, hi) in enumerate(offs):
+ if gather_indx is None:
+ idx = torch.arange(lo, hi, device=x.device)
+ else:
+ idx = gather_indx.src_indx[lo:hi] // n_expts_act
+ batch = i if is_input_batched else 0
+ out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+ w[i].float())
+ if bias is not None:
+ out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
+ if gammas is not None:
+ out *= gammas[lo:hi, None]
+ y[batch, lo:hi, :] = round_y(out)
+ if not is_input_batched:
+ y = y.view(y.shape[1], y.shape[2])
+ if scatter_indx is None:
+ return y
+ # accumulate output from all experts
+ n_rows = y.shape[0] // n_expts_act
+ out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
+ for i, (lo, hi) in enumerate(offs):
+ dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
+ msk = dst_idx != -1
+ out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
+ return out
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_common.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d5c99493872d779643aff2a9f7293685d8c4f2b
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_common.py
@@ -0,0 +1,179 @@
+import torch
+
+import triton
+import triton.language as tl
+
+# -----------------------------------------------------------------------------
+# Utilities
+# -----------------------------------------------------------------------------
+
+
+@triton.constexpr_function
+def get_scaled_dot_format_string(dtype: tl.dtype):
+ mapping = {
+ tl.float16: "fp16",
+ tl.bfloat16: "bf16",
+ tl.uint8: "e2m1",
+ tl.float8e4nv: "e4m3",
+ tl.float8e5: "e5m2",
+ }
+ return mapping[dtype]
+
+
+@triton.jit
+def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
+ """
+ Swizzle the program id based on integer XCD_SWIZZLE.
+ This is useful for reording how blocks are ordered. A scheduler may, for example,
+ assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
+ This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
+ becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
+ the same hardware unit.
+ """
+ # Number of pids per group in the new arrangement
+ pids_per_group = domain_size // XCD_SWIZZLE
+ extra_pid_groups = domain_size % XCD_SWIZZLE
+
+ # Compute current current and local pid within the group
+ group = pid % XCD_SWIZZLE
+ local_pid = pid // XCD_SWIZZLE
+
+ # Calculate new pid based on the new grouping
+ new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
+ return new_pid
+
+
+@triton.jit
+def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ tl.assume(group_size >= 0)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ return pid_m, pid_n
+
+
+def make_matmul_repr(base_name, order):
+ def matmul_repr(specialization):
+ signature = specialization.signature
+ constants = specialization.constants
+ reorder = lambda L: [L[i] for i in order]
+ layout = lambda stride: "N" if stride in constants else "T"
+
+ def convert_dtype(dtype):
+ if "tensordesc" in dtype:
+ ret = convert_dtype(dtype.split("<")[1].split("[")[0])
+ return ret
+ elif "u8" in dtype:
+ return "mxfp4"
+ elif dtype[0] == "*":
+ return dtype[1:]
+ else:
+ return dtype
+
+ dtypes = "x".join(
+ [convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])]
+ )
+ layouts = "".join(
+ [
+ f"{layout(i)}"
+ for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])
+ ]
+ )
+ blocks = "x".join(
+ [f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]]
+ )
+ # mode = []
+ # if "GatherIndx" not in constants:
+ # mode += ['g']
+ # if "ScatterSrcIndx" not in constants:
+ # mode += ['s']
+ # suffix = "" if not mode else "_o" + (''.join(mode))
+ # if base_name.startswith("_p"):
+ # suffix += "_ptma"
+ return f"{base_name}_{layouts}_{dtypes}_{blocks}"
+
+ return matmul_repr
+
+
+def matmul_launch_metadata(grid, kernel, args):
+ # Build the profiling metadata dict for one kernel launch.
+ # The result always has "name"; "flops<nbits>" and "bytes" are added unless
+ # the launch is ragged (ExptHist given) and the token count cannot be determined.
+ # Reading values out of `hist` / `gindx` is only done when
+ # launch_metadata_allow_sync() permits it.
+ from ..proton_opts import launch_metadata_allow_sync
+
+ ret = dict()
+ M, N, K = args["M"], args["N"], args["K"]
+ Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"]
+ tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+ hist = args["ExptHist"]
+ if hist is not None:
+ # If annotation is given, use that to generate name for profiling.
+ if tokens_per_expt is not None:
+ n_rows = f"{tokens_per_expt}*"
+ elif launch_metadata_allow_sync():
+ n_rows = int(hist.float().mean())
+ else:
+ n_rows = "unknown"
+
+ if launch_metadata_allow_sync():
+ n_tokens = float(hist.sum())
+ # weight bytes of the experts that received at least one token
+ n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (
+ hist > 0
+ ).sum()
+ elif tokens_per_expt is not None:
+ n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+ # This may not be totally correct (e.g., we might not be using all experts)
+ # but it's better than nothing.
+ n_w_bytes = W.numel() * W.element_size()
+ else:
+ n_tokens = None
+ n_w_bytes = 0
+
+ # If annotation is given, use that to generate name for profiling.
+ # NOTE(review): this repeats the lookup and the n_rows choice made above.
+ tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+ n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
+ else:
+ n_tokens = None
+ n_w_bytes = W.numel() * W.element_size()
+ # Formats one dimension; a None value is shown as a per-expert quantity.
+ # NOTE(review): shadows the builtin `repr`; the None branch reads `hist` and
+ # `n_rows`, which are only usable when ExptHist was given.
+ repr = (
+ lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
+ )
+ nbits = X.dtype.itemsize * 8
+ batch_repr = ""
+ if "batch_size" in args and args["batch_size"] > 1:
+ batch_repr = repr("B", args["batch_size"]) + ", "
+ ret["name"] = (
+ f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
+ )
+ ep_subtile = args["EPILOGUE_SUBTILE"]
+ if ep_subtile is not None and ep_subtile > 1:
+ ret["name"] += f" ep/{ep_subtile}"
+
+ if hist is not None and n_tokens is None:
+ return ret # Don't fill metadata because we can't compute them properly.
+
+ # A missing M or K is replaced by the token count.
+ fM = M if M is not None else n_tokens
+ fK = K if K is not None else n_tokens
+ ret[f"flops{nbits}"] = 2.0 * fM * N * fK
+
+ gindx = args.get("GatherIndx", None)
+ # sindx = args.get("WriteBackIndx", None)
+ n_x_bytes = X.numel() * X.element_size()
+ n_y_bytes = Y.numel() * Y.element_size()
+ if hist is not None:
+ assert n_tokens is not None
+ n_expts_act = args["N_EXPTS_ACT"]
+
+ if (gindx is not None) and launch_metadata_allow_sync():
+ # recreate inverse GatherIndx.
+ dst = torch.full_like(gindx, -1)
+ idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
+ mask = gindx != -1
+ dst[gindx[mask]] = idx[mask]
+ # count groups of n_expts_act entries with at least one valid entry
+ n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
+ else:
+ n_read_rows = n_tokens
+ n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
+ n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
+ ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
+
+ return ret
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed22c1549eac7863910cf08ffb4b8e4a5851932
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_matmul_ogs.py
@@ -0,0 +1,429 @@
+# isort: off
+# fmt: off
+import triton
+import triton.language as tl
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
+from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+
+
+@triton.jit
+def _zero_masked_rows(
+ pid_m, pid_n,
+ Y, stride_y_m, stride_y_n,
+ N,
+ ScatterSrcIndx, num_idxs,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+ offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
+ offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
+ YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+ mask_n = offs_n < N
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
+ tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask)
+
+
+_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])  # repr callable handed to triton.jit on the next line
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+def _matmul_ogs(  # One program = one (expert/batch, M-block, N-block, split-K slice) tile: acc = x @ w, then bias / scales / activation, then a masked store to Y.
+ Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,  # output and its strides (k = split-K slice, z = batch, m = row, n = column); YPtr is not read in this body
+ YExpectedScale, YActualScale, YChecksumScale,  # forwarded to float_to_flex; YActualScale is also the destination of the output mx scales when the output is microscaled
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,  # strides of the output mx scales; a non-None stride_y_mx_z is what marks the output as microscaled
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,  # left operand and its strides; XPtr is not read in this body
+ XScale,  # read through load_scale() and multiplied into the accumulator
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,  # optional mx scales for X (None => X is not microscaled)
+ W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,  # right operand, indexed by expert (e), k and n; WPtr is not read in this body
+ WScale,  # read through load_scale(); indexed by expert when PER_BATCH_SCALE is set
+ WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,  # optional uint8 mx scales for W (None => W is not microscaled)
+ B, stride_b_e, # Bias
+ NRows, M, N, K, # shapes; M must be None when ExptData is given (it is then loaded from ExptHist)
+ # expt data
+ Betas, Gammas,  # optional per-row factors: betas scale the bias term, gammas scale the final output
+ GatherIndx,  # optional row indirection applied when reading X
+ ScatterSrcIndx, num_idxs,  # consulted only by _zero_masked_rows (rows whose entry is -1 get zeroed)
+ WriteBackIndx, writeback_size,  # optional row indirection applied when writing Y; an entry of -1 drops the row
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,  # per-expert M, per-expert first row, value subtracted from grid_m to get the padding, packed per-tile descriptor
+ # true grid size
+ batch_size, grid_m, grid_n,
+ # Out scale
+ out_alpha,
+ # fused activation function
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
+ # epilogue transform
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
+ # MoE config
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+ # precision config
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
+ PER_BATCH_SCALE: tl.constexpr,
+ # optimization config
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+ # "HOPPER_VALUE" or None (enforced by a static_assert in the body)
+ SWIZZLE_MX_VALUE: tl.constexpr,
+ # One of "BLACKWELL_SCALE", "HOPPER_SCALE", "CDNA4_SCALE" or None (the values the body branches on)
+ SWIZZLE_MX_SCALE: tl.constexpr,
+ EPILOGUE_SUBTILE: tl.constexpr,
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,  # EVEN_K skips the K-tail masks; SPLIT_K is the number of K slices
+ W_CACHE_MODIFIER: tl.constexpr,
+ NUM_SMS: tl.constexpr,
+ X_TMA_MODE: tl.constexpr,
+ Y_TMA_MODE: tl.constexpr,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
+ UPCAST_INDICES: tl.constexpr = False,  # selects int64 instead of int32 for index arithmetic
+ SWAP_XW: tl.constexpr = False,
+ IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
+
+ tl.assume(stride_y_k >= 0)  # compiler hints (this line through grid_n): strides and sizes are non-negative
+ tl.assume(stride_y_z >= 0)
+ tl.assume(stride_y_m >= 0)
+ tl.assume(stride_y_n >= 0)
+ tl.assume(stride_x_z >= 0)
+ tl.assume(stride_x_m >= 0)
+ tl.assume(stride_x_k >= 0)
+ tl.assume(stride_w_e >= 0)
+ tl.assume(stride_w_k >= 0)
+ tl.assume(stride_w_n >= 0)
+ if stride_w_mx_e is not None:
+ tl.assume(stride_w_mx_e >= 0)
+ if stride_w_mx_k is not None:
+ tl.assume(stride_w_mx_k >= 0)
+ if stride_w_mx_n is not None:
+ tl.assume(stride_w_mx_n >= 0)
+ if B is not None:
+ tl.assume(stride_b_e >= 0)
+ tl.assume(batch_size >= 0)
+ tl.assume(grid_m >= 0)
+ tl.assume(grid_n >= 0)
+
+ is_w_microscaled: tl.constexpr = WMxScale is not None
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE  # K elements covered by one mx scale (see BLOCK_K // MX_PACK_DIVISOR below)
+ if is_w_microscaled:
+ w_type: tl.constexpr = W.dtype.element_ty
+ is_mxfp4: tl.constexpr = w_type == tl.uint8  # uint8 storage is treated as fp4 packed two values per byte
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
+ "mx_weight_ptr must be uint8 or fp8")
+ tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
+ else:
+ tl.static_assert(SWIZZLE_MX_VALUE is None)
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
+ is_x_microscaled: tl.constexpr = XMxScale is not None
+ if is_x_microscaled:
+ x_type: tl.constexpr = X.dtype.element_ty
+ tl.static_assert(is_w_microscaled)  # microscaled X is only accepted together with microscaled W
+ tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
+ tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ is_out_microscaled: tl.constexpr = stride_y_mx_z is not None
+
+ OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N  # output tile width after the activation's N reduction
+ yN = N // ACTIVATION_REDUCTION_N  # output column count after the activation's N reduction
+
+ pid = tl.program_id(0)
+ if ExptOffsSum is not None and XCD_SWIZZLE > 1:
+ # Determine how much padding there is on the expert data. This allows us to
+ # know the true grid size and avoid processing padding tiles.
+ padding_m = grid_m - tl.load(ExptOffsSum)
+ else:
+ padding_m: tl.constexpr = 0
+
+ HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
+ index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
+
+ unpadded_m = grid_m - padding_m
+ tl.assume(unpadded_m >= 0)
+ total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K  # programs with pid at or above this value only handle padding
+ if padding_m > 0 and pid >= total_actual_tiles:
+ tl.device_assert(batch_size == 0)  # NOTE(review): with batch_size == 0 total_actual_tiles is 0 as well -- confirm this is the intended invariant
+ pid_mn = pid - total_actual_tiles
+ if pid_mn < padding_m * grid_n:
+ pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
+
+ # set masked out rows to 0
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+ _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
+ return
+
+ # swizzle program ids
+ pid_emnk = pid
+ if XCD_SWIZZLE != 1:
+ pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
+ pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K)  # expert / batch index
+ pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
+ pid_k = pid_mnk % SPLIT_K  # the split-K slice is the fastest-varying part of the id
+ pid_mn = pid_mnk // SPLIT_K
+ pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M)
+ # For split-k, advance to the output k slice
+ if SPLIT_K > 1:
+ Y += pid_k.to( index_type) * stride_y_k
+ if is_out_microscaled:
+ YActualScale += pid_k.to(index_type) * stride_x_mx_k  # NOTE(review): output scale pointer is advanced with an X-scale stride (no stride_y_mx_k exists) -- confirm
+ # set masked out rows to 0
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+ _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
+ # unpack expert data
+ if ExptData is None:
+ tl.static_assert(M is not None)
+ expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m  # pid_e serves as both the expert id and the z (batch) offset
+ else:
+ tl.static_assert(M is None)
+ expt_data = tl.load(ExptData + pid_m)
+ if expt_data == -1:  # a descriptor of -1 means this tile has no work
+ return
+ expt_id = expt_data & 0x0000FFFF  # low 16 bits: expert id
+ block_id = expt_data >> 16  # remaining high bits: M-block index within the expert
+ M = tl.load(ExptHist + expt_id)  # per-expert value used as M from here on
+ start_m = tl.load(ExptOffs + expt_id)  # first row of this expert's segment
+ start_z = 0
+ expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
+ start_m, start_z = start_m.to(index_type), start_z.to(index_type)
+ pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
+ # A pointers
+ offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
+ offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)  # rows wrap modulo M, so loads stay in range; extra rows are masked at store time
+ X += start_z * stride_x_z
+ if GatherIndx is None:
+ X += start_m * stride_x_m
+ else:
+ GatherIndx += start_m
+ # no need to bounds-check here because `offs_x_m` wraps around M dim
+ offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT  # the gathered index divided by N_EXPTS_ACT is the row of X
+ offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
+ XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
+
+ # TODO: refactor if/else when triton front end improves
+ if is_w_microscaled:
+ if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+ tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
+ tl.static_assert(not is_x_microscaled)
+ # We pack 2 fp4 values in a byte but we divide the dimension by 2
+ # when swizzling
+ W_K_DIVISOR: tl.constexpr = 1
+ W_K_MULTIPLIER: tl.constexpr = 2
+ W_N_DIVISOR: tl.constexpr = 4
+ else:
+ # We pack 2 fp4 values in a byte
+ W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
+ W_K_MULTIPLIER: tl.constexpr = 1
+ W_N_DIVISOR: tl.constexpr = 1
+
+ if W_TRANSPOSE:
+ # When weight is transposed, 2 fp4 values are packed per Byte along
+ # the contiguous dimension, K.
+ PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
+ else:
+ # When weight is not transposed, fp4 values are *not* packed along
+ # the contiguous dimension, N.
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR  # number of scales along K for one BLOCK_K tile
+
+ WMxScale += expt_id * stride_w_mx_e
+
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ # TODO: support non W_TRANSPOSE with blackwell swizzling
+ tl.static_assert(W_TRANSPOSE)
+ tl.static_assert(BLOCK_N % 128 == 0)
+ tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
+ PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
+ stride_scale_k: tl.constexpr = 1
+ elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+ # TODO: support non W_TRANSPOSE with Hopper swizzling
+ tl.static_assert(W_TRANSPOSE)
+ n_warps: tl.constexpr = tl.extra.cuda.num_warps()
+ tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
+ tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
+ stride_scale_k = stride_w_mx_k
+ elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+ tl.static_assert(stride_w_mx_k is not None)
+ tl.static_assert(stride_w_mx_n is not None)
+ NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
+ stride_scale_k = stride_w_mx_k
+ else:
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K  # unswizzled scales: one entry per (n, k-group)
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N
+ stride_scale_k = stride_w_mx_k
+ offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
+ offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
+ # K dimension must be the last dimension for the scales
+ offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
+ WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
+ else:
+ WMxScalePtrs = None
+ offs_k_scale = None
+ W_K_DIVISOR: tl.constexpr = 1
+ W_K_MULTIPLIER: tl.constexpr = 1
+ W_N_DIVISOR: tl.constexpr = 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
+
+ # B pointers
+ offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
+ offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)  # columns wrap like the rows do; out-of-range columns are masked at store time
+
+ if is_x_microscaled:
+ XMxScale += start_z.to(index_type) * stride_x_mx_z
+ if GatherIndx is None:
+ XMxScale += start_m * stride_x_mx_m
+ offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
+ XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
+ else:
+ XMxScalePtrs = None
+
+ offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
+ W += expt_id * stride_w_e
+ WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
+ # compute output
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):  # k shrinks by the same amount the pointers advance, so `offs_k < k` masks the K tail
+ if EVEN_K:
+ mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
+ mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
+ if is_w_microscaled and SWIZZLE_MX_SCALE is None:  # scale masks exist only for the unswizzled layout; swizzled scale loads below are unmasked
+ mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
+ if is_x_microscaled:
+ mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
+ else:
+ mask_k = offs_k < k
+ mask_k_w = offs_w_k < ((k // (W_K_DIVISOR if W_TRANSPOSE else 1)) * W_K_MULTIPLIER)  # same bound expressed in packed-W units
+ if is_w_microscaled and SWIZZLE_MX_SCALE is None:
+ mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
+ if is_x_microscaled:
+ mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k
+
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+ w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
+ if is_w_microscaled:
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
+ w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+
+ if is_x_microscaled:
+ x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
+ elif x_format == "fp16" or x_format == "bf16":
+ x_scales: tl.constexpr = None
+ else:
+ # Scale of 1 in E8M0 format
+ x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
+
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
+ elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+ # Handshake with the swizzling code
+ num_warps: tl.constexpr = tl.extra.cuda.num_warps()
+ w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
+ elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+ w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K)
+ else:
+ w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])
+
+ if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+ # Handshake with the swizzling code
+ tl.static_assert(x_format == "bf16")
+ tl.static_assert(w_format == "e2m1")
+ w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)  # w is converted to bf16 here, then a plain tl.dot is used instead of tl.dot_scaled
+ tl.static_assert(w.dtype == tl.bfloat16)
+ acc = acc.trans()  # accumulate in transposed form (dot(w, x)), then transpose back below
+ x = x.trans()
+ # w = w.trans()
+ acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ acc = acc.trans()
+ else:
+ rhs_k_pack: tl.constexpr = W_TRANSPOSE or not is_w_microscaled or W_K_DIVISOR != 2  # False only for non-transposed packed fp4 weights
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True, rhs_k_pack=rhs_k_pack)
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
+ else:
+ WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
+ if is_x_microscaled:
+ XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
+ else:
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k  # step over the K blocks owned by the other split-K slices
+ WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
+ # bias + scale
+ offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)  # un-wrapped row offsets, used for masking and the write-back
+ offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
+ mask_m = offs_m < M
+ mask_n = offs_y_n < N
+ if B is not None:
+ BPtrs = B + expt_id * stride_b_e + offs_y_n
+ if pid_k == 0:  # only the first split-K slice contributes the bias
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ if Betas is not None:
+ betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
+ else:
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ if Gammas is not None:
+ gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
+ else:
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ # flexpoint
+ x_scale = load_scale(XScale)
+ if PER_BATCH_SCALE:
+ w_scale = load_scale(WScale + expt_id)
+ else:
+ w_scale = load_scale(WScale)
+ acc *= x_scale * w_scale
+ acc = acc + bias[None, :] * betas[:, None]  # bias is broadcast over rows and weighted per row by betas
+ if out_alpha is not None:
+ acc *= out_alpha
+ if ACTIVATION_FN is not None:
+ out = ACTIVATION_FN(acc, *activation_fn_args)
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
+ offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)  # recompute the column offsets and mask for the narrower output tile
+ mask_n = offs_y_n < yN
+ else:
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
+ out = acc
+ out *= gammas[:, None]  # per-row output scaling
+ # write-back
+ Y += start_z.to(index_type) * stride_y_z
+ if WriteBackIndx is not None:
+ WriteBackIndx += start_m
+ dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
+ mask_m = mask_m & (dst_idx != -1)  # rows whose destination is -1 are not written
+ offs_y_m = dst_idx
+ else:
+ Y += start_m * stride_y_m
+ offs_y_m = offs_m
+
+ YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
+ mask = mask_m[:, None] & mask_n[None, :]
+ if is_out_microscaled:
+ MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
+ N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
+ tl.static_assert(EPILOGUE_FN is not None)
+ out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)  # in this path the epilogue returns both the values and their scales
+ tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
+ offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
+ mask_n_scale = offs_y_n_scale < N_MX_BLOCK
+ YActualScale += start_z.to(index_type) * stride_y_mx_z
+ if WriteBackIndx is None:
+ YActualScale += start_m * stride_y_mx_m
+ YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+ else:
+ YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n  # NOTE(review): scattered rows are rebased by NRows for the scale tensor only; the reason is not visible here -- confirm
+ tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
+ else:
+ out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
+ if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
+ tl.store(YPtrs, out, mask=mask)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
new file mode 100644
index 0000000000000000000000000000000000000000..53ea66bac907c39483929aa68421b844589c3573
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
@@ -0,0 +1,471 @@
+# isort: off
+# fmt: off
+import torch
+import triton
+import triton.language as tl
+from triton.tools.ragged_tma import load_ragged, store_ragged
+from vllm.kvprune.triton_kernels import target_info
+from vllm.kvprune.triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
+ float_to_flex,
+ load_scale,
+ nan_propagating_absmax_reduce,
+ compute_scale,
+)
+from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+
+
+@triton.constexpr_function
+def cuda_capability_geq(major, minor):  # Forwards to target_info.cuda_capability_geq(major, minor); the decorator lets kernels in this file assign the result to a tl.constexpr.
+ return target_info.cuda_capability_geq(major, minor)
+
+@triton.constexpr_function
+def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
+ if isinstance(tensor_or_desc, tl.tensor):
+ return tensor_or_desc.dtype.element_ty
+ elif isinstance(tensor_or_desc, tl.tensor_descriptor):
+ return tensor_or_desc.dtype
+ else:
+ raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
+
+@triton.jit
+def _load_tile_attrs(
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
+ # unpack and swizzle program ids
+ pid_emnk = tile_id
+ if XCD_SWIZZLE != 1:
+ pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
+ pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
+ pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
+ if SPLIT_K > 1:
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ else:
+ pid_k: tl.constexpr = 0
+ pid_mn = pid_mnk
+ pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
+
+ # unpack expert data
+ if ExptData is None:
+ tl.static_assert(M is not None)
+ expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
+ else:
+ tl.static_assert(M is None)
+ expt_data = tl.load(ExptData + pid_m)
+ expt_id = expt_data & 0x0000FFFF
+ block_id = expt_data >> 16
+ eM = tl.load(ExptHist + expt_id)
+ start_m = tl.load(ExptOffs + expt_id)
+ start_z = 0
+
+ off_m = BLOCK_M * block_id
+ off_n = BLOCK_N * pid_n
+
+ return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
+
+@triton.jit
+def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
+ mask = mask & (offs < writeback_size)
+ offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1)
+ mask = offs != -1
+ return (offs, mask)
+
+
+_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+def _p_matmul_ogs(
+ Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
+ YExpectedScale, YActualScale, YChecksumScale,
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+ XScale,
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
+ W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
+ WScale,
+ MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
+ B, stride_b_e, # Bias
+ NRows, M, N, K, # shapes
+ # expt data
+ Betas, Gammas,
+ GatherIndx,
+ ScatterSrcIndx, num_idxs,
+ WriteBackIndx, writeback_size,
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,
+ # true grid size
+ batch_size, grid_m, grid_n,
+ # Out scale
+ out_alpha,
+ # fused activation function
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
+ # epilogue transform
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
+ # MoE config
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+ # precision config
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
+ PER_BATCH_SCALE: tl.constexpr,
+ # optimization config
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+ # NYI: Must be None
+ SWIZZLE_MX_VALUE: tl.constexpr,
+ # One of ["BLACKWELL", None]
+ SWIZZLE_MX_SCALE: tl.constexpr,
+ EPILOGUE_SUBTILE: tl.constexpr,
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
+ W_CACHE_MODIFIER: tl.constexpr,
+ NUM_SMS: tl.constexpr,
+ X_TMA_MODE: tl.constexpr,
+ Y_TMA_MODE: tl.constexpr,
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
+ UPCAST_INDICES:tl.constexpr=False,
+ SWAP_XW: tl.constexpr = False,
+ IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
+ # tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
+
+ # why is this faster than using host-side tensor descriptor?!
+ if Y_TMA_MODE is not None:
+ Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)
+
+ is_microscaled_format: tl.constexpr = MxScale is not None
+ tl.static_assert(not is_microscaled_format or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
+ if is_microscaled_format:
+ w_type: tl.constexpr = get_dtype(W)
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
+ "mx_weight_ptr must be uint8")
+ tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+ tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
+
+ # We have pack 2 fp4 values in a byte
+ W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+ else:
+ W_PACK_DIVISOR: tl.constexpr = 1
+ MX_SCALE_BLOCK_K: tl.constexpr = 1
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
+
+ if ExptOffsSum is not None:
+ # Determine how much padding there is on the expert data. This allows us to
+ # know the true grid size and avoid processing padding tiles.
+ padding_m = grid_m - tl.load(ExptOffsSum)
+ else:
+ padding_m: tl.constexpr = 0
+
+ index_type: tl.constexpr = tl.int64
+
+ USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
+ HAS_SCATTER: tl.constexpr = WriteBackIndx is not None
+ HAS_GATHER: tl.constexpr = GatherIndx is not None
+ USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense"
+ USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense"
+
+ if EPILOGUE_SUBTILE is None:
+ SUBTILE_FACTOR: tl.constexpr = 1
+ else:
+ SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
+ EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
+ OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
+ yN = N // ACTIVATION_REDUCTION_N
+
+ # set masked out rows to 0
+ if HAS_SCATTER and N_EXPTS_ACT == 1:
+ # Iterate with reversed pids so that later pids will get more tiles if the number of
+ # tiles isn't evenly divisible by the number of SMs.
+ # The main loop after this iterates in the forward direction such that earlier
+ # pids get more tiles if the number of tiles isn't evenly divisible.
+ # This helps balance the work across the SMs.
+ for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
+ pid_k = pid_mnk % SPLIT_K
+ pid_mn = pid_mnk // SPLIT_K
+ pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
+
+ z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
+ offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
+ offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
+ YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+ mask_n = offs_n < yN
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
+ tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
+
+
+ k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
+ num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
+
+ # If true, do not share loop-carried variables between the prologue and the
+ # epilogue to enable better pipelining with mmav5
+ INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
+
+ # start negative; will be incremented at the top of the loop
+ if INDEPENDENT_EPILOGUE:
+ tile_id1 = tl.program_id(0) - NUM_SMS
+
+ # Keep track of local max for updating flexpoint scales.
+ THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
+ local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
+
+ DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
+
+ for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True):
+ expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M, BLOCK_N, SPLIT_K,
+ GROUP_M, XCD_SWIZZLE)
+
+ # Base pointers and offsets.
+ if X_TMA_MODE is None:
+ XBase = X + start_z.to(index_type) * stride_x_z
+ offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
+ if SPLIT_K > 1:
+ offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
+
+ if USE_GATHER_TMA:
+ offs_m = off_m + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < (M if M is not None else eM)
+ if ExptData is None:
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
+ # Bump rows to account for the Z offset.
+ offs_x_m += start_z * (stride_x_z // stride_x_m)
+ offs_x_m = tl.where(mask_m, offs_x_m, -1)
+ else:
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
+ mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
+ elif X_TMA_MODE is None:
+ tl.static_assert(HAS_GATHER)
+ offs_m = off_m + tl.arange(0, BLOCK_M)
+ if M is not None:
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
+ else:
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
+ # no needs to bounds-check here because `offs_m` wraps around M dim
+ offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
+ offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
+
+
+ acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
+ off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
+ off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
+ off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
+
+ # --- load x ---
+ if USE_GATHER_TMA:
+ x = X.gather(offs_x_m, off_k)
+ elif X_TMA_MODE == "dense":
+ x = X.load([start_z, start_m + off_m, off_k])
+ x = x.reshape(BLOCK_M, BLOCK_K)
+ elif X_TMA_MODE == "ragged":
+ x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
+ x = x.reshape(BLOCK_M, BLOCK_K)
+ else:
+ tl.static_assert(X_TMA_MODE is None)
+ XPtrs = XBase + offs_x_m + offs_x_k
+ XBase += BLOCK_K * SPLIT_K * stride_x_k
+ mask_k = tl.arange(0, BLOCK_K) < K - off_k
+ if EVEN_K:
+ if SPLIT_K > 1:
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+ else:
+ x = tl.load(XPtrs)
+ else:
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
+
+ # --- load w ---
+ if W_TRANSPOSE:
+ w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T
+ else:
+ w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])
+
+ # --- load w_scale ---
+ if is_microscaled_format:
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
+ mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+ if x_format == "fp16" or x_format == "bf16":
+ x_scales: tl.constexpr = None
+ else:
+ x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
+ flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
+ w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
+ w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
+ w_scales = unswizzle_mx_scale_bw(w_scales)
+ else:
+ w_scales = MxScale.load([expt_id, off_k_mx, off_n])
+ w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
+
+ # --- update accumulator ---
+ if is_microscaled_format:
+ if SWAP_XW:
+ acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+ else:
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
+ else:
+ if SWAP_XW:
+ acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+ else:
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
+
+ if INDEPENDENT_EPILOGUE:
+ tile_id1 += NUM_SMS
+ expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
+ tile_id1, num_tiles, grid_m, grid_n, padding_m,
+ M, ExptData, ExptHist, ExptOffs,
+ BLOCK_M, BLOCK_N, SPLIT_K,
+ GROUP_M, XCD_SWIZZLE)
+ else:
+ tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
+ off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
+
+ offs_m = off_m1 + tl.arange(0, BLOCK_M)
+ mask_m = offs_m < (M if M is not None else eM1)
+ if USE_SCATTER_TMA:
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
+ if SPLIT_K > 1:
+ # Compute the split k offset in number of rows, and add it to offs_y_m.
+ # This allows us to write to the correct slice in the output tensor while using
+ # a 2D TMA scatter.
+ tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
+ split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
+ offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
+ elif Y_TMA_MODE is None:
+ tl.static_assert(HAS_SCATTER)
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
+ else:
+ offs_y_m = start_m1 + offs_m
+ MASK_ACC = False if USE_GATHER_TMA else USE_FLEXPOINT_SCALE
+
+ # bias + scale
+ offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
+ mask_n = offs_y_n < N
+ if B is not None:
+ BPtrs = B + expt_id1 * stride_b_e + offs_y_n
+ if pid_k1 == 0:
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ else:
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
+ if Betas is not None:
+ betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
+ else:
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ if Gammas is not None:
+ gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
+ else:
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
+ x_scale = load_scale(XScale)
+ if PER_BATCH_SCALE:
+ w_scale = load_scale(WScale + expt_id1)
+ else:
+ w_scale = load_scale(WScale)
+
+ accs = (acc,)
+ biases = (bias,)
+
+ if SUBTILE_FACTOR >= 2:
+ acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
+ accs = (acc0, acc1)
+ bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
+ biases = (bias0, bias1)
+
+ if SUBTILE_FACTOR >= 4:
+ acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+ acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+ accs = (acc00, acc01, acc10, acc11)
+ bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+ bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+ biases = (bias00, bias01, bias10, bias11)
+
+ tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
+ tl.static_assert(len(accs) == SUBTILE_FACTOR)
+
+ for a_i in tl.static_range(len(accs)):
+ acc_tile = accs[a_i]
+ acc_tile *= x_scale * w_scale
+
+ if SWAP_XW:
+ acc_tile = acc_tile.T
+
+ acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
+ if out_alpha is not None:
+ acc_tile *= out_alpha
+
+ if ACTIVATION_FN is not None:
+ out = ACTIVATION_FN(acc_tile, *activation_fn_args)
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
+ else:
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
+ out = acc_tile
+
+ out *= gammas[:, None]
+
+ if MASK_ACC:
+ out = tl.where(mask_m[:, None], out, 0.0)
+ # Flexpoint
+ out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
+ local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
+ out = float_to_flex(
+ out, YExpectedScale,
+ None, # ActualScale: local absmax is tracked and updated after the loop
+ YChecksumScale,
+ None, # mask: out is manually masked to 0
+ YPtr, FLEXPOINT_SATURATE_INF
+ )
+ if EPILOGUE_FN is not None:
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
+
+ out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
+ out = out.to(YPtr.dtype.element_ty)
+ if USE_SCATTER_TMA:
+ # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
+ # there shouldn't be any other negative values.
+ offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
+ Y.scatter(out, offs_y_m, out_off_n)
+ elif Y_TMA_MODE == "dense":
+ out = tl.reshape(out, [1] + out.shape)
+ off_kz = pid_k * batch_size + start_z1
+ Y.store([off_kz, off_m1, out_off_n], out)
+ elif Y_TMA_MODE == "ragged":
+ out = tl.reshape(out, [1] + out.shape)
+ store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1)
+ else:
+ tl.static_assert(Y_TMA_MODE is None)
+ offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
+ mask_n = offs_y_n < yN
+
+ YPtrs = YPtr + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
+ mask = mask_m[:, None] & mask_n[None, :]
+ tl.store(YPtrs, out, mask=mask)
+
+
+ # Update the flexpoint scales
+ if YActualScale is not None:
+ tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")
+
+
+_per_device_alloc_fns = {}
+def get_per_device_per_stream_alloc_fn(device):
+ if device not in _per_device_alloc_fns:
+ _per_stream_tensors = {}
+ def alloc_fn(size: int, alignment: int, stream):
+ assert alignment == 128
+ if stream not in _per_stream_tensors or _per_stream_tensors[stream].numel() < size:
+ _per_stream_tensors[stream] = torch.empty(size, device=device, dtype=torch.int8)
+ _per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
+ return _per_stream_tensors[stream]
+
+ _per_device_alloc_fns[device] = alloc_fn
+ return _per_device_alloc_fns[device]
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_reduce_grouped.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_reduce_grouped.py
new file mode 100644
index 0000000000000000000000000000000000000000..125e1cf09700f4ed593a947ddbb204877b9d02db
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/_reduce_grouped.py
@@ -0,0 +1,126 @@
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
+ float_to_flex,
+ load_scale,
+)
+from vllm.kvprune.triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _reduce_grouped(
+ X,
+ stride_xb: tl.uint64,
+ stride_xm: tl.uint64,
+ stride_xn, #
+ XScale, # input scalar flex scale
+ Out,
+ stride_om: tl.uint64,
+ stride_on, # output tensor
+ OutExpectedScale,
+ OutActualScale,
+ OutChecksumScale, # output scalar flex scales
+ InIndx,
+ B,
+ N, #
+ XMxScale,
+ stride_mxb: tl.uint64,
+ stride_mxs: tl.uint64, # optional per-32-col output MXFP scales (uint8)
+ OutMxScale,
+ stride_omxs: tl.uint64, # optional per-32-col output MXFP scales (uint8)
+ # fused activation function
+ ACTIVATION_FN: tl.constexpr,
+ activation_fn_args,
+ ACTIVATION_REDUCTION_N: tl.constexpr,
+ # epilogue transform
+ EPILOGUE_FN: tl.constexpr,
+ epilogue_fn_args,
+ #
+ HAS_IN_MX_SCALE: tl.constexpr,
+ HAS_OUT_MX_SCALE: tl.constexpr,
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
+ K: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ pid_t = tl.program_id(0)
+ BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
+ # persistent along N: single program on N, iterate tiles of size BLOCK_N
+ start = pid_t * K
+ # load indices into a tuple
+ if InIndx is None:
+ indxs = (pid_t,)
+ else:
+ indxs = ()
+ for i in tl.static_range(0, K):
+ indxs = indxs + (tl.load(InIndx + start + i),)
+ # determine first valid topk row
+ fi = indxs[(K - 1)]
+ for i in tl.static_range(K - 2, -1, -1):
+ fi = tl.where(indxs[i] != -1, indxs[i], fi)
+ # record overwritten row index (may be -1 if none)
+ XPtrs = X + tl.arange(0, BLOCK_N) * stride_xn
+ OutPtrs = Out + tl.arange(0, BLOCK_N_OUT) * stride_on
+ if HAS_IN_MX_SCALE:
+ XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn
+ if HAS_OUT_MX_SCALE:
+ OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on
+ x_scale = load_scale(XScale)
+ for n_curr in tl.range(0, N, BLOCK_N, num_stages=4):
+ acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
+ x_n_mask = tl.arange(0, BLOCK_N) < N - n_curr
+ x_n_mask_scale = tl.arange(0, BLOCK_N // 32) < tl.cdiv(N - n_curr, 32)
+ # accumulate contributions for this tile
+ for i in tl.static_range(0, K):
+ curr = tl.zeros([BLOCK_N], dtype=tl.float32)
+ # iterate over split_k partial values
+ for b in tl.range(0, B):
+ is_valid = indxs[i] != -1
+ x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
+ vals = tl.load(x_row_ptr, mask=x_n_mask & is_valid, other=0.0)
+ vals = vals.to(tl.float32)
+ if HAS_IN_MX_SCALE:
+ scale_row_ptr = XScalePtrs + indxs[i] * stride_mxs + b * stride_mxb
+ scale = tl.load(
+ scale_row_ptr, mask=x_n_mask_scale & is_valid, other=0.0
+ )
+ scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
+ vals = vals.reshape([BLOCK_N // 32, 32])
+ vals = (scale[:, None] * vals).reshape([BLOCK_N])
+ curr += vals
+ # apply nonlinearity to split-k output
+ if ACTIVATION_FN is not None:
+ curr = ACTIVATION_FN(curr[None, :], *activation_fn_args)
+ curr = tl.reshape(curr, [curr.shape[-1]])
+ # update final accumulator
+ acc += curr
+ acc *= x_scale
+ # Compute per-32-col MXFP scales for this tile if requested
+ Nrem = (N - n_curr) // ACTIVATION_REDUCTION_N
+ out_n_mask = tl.arange(0, BLOCK_N_OUT) < Nrem
+ out_n_mask_scale = tl.arange(0, BLOCK_N_OUT // 32) < tl.cdiv(Nrem, 32)
+ if HAS_OUT_MX_SCALE:
+ acc, acc_scale = quantize_mxfp8_fn(acc[None, :], out_n_mask[None, :])
+ acc = tl.reshape(acc, [acc.shape[-1]])
+ acc_scale = tl.reshape(acc_scale, [acc_scale.shape[-1]])
+ # Convert to flexpoint output if configured (scalar scales)
+ acc = float_to_flex(
+ acc,
+ OutExpectedScale,
+ OutActualScale,
+ OutChecksumScale,
+ None,
+ Out,
+ FLEXPOINT_SATURATE_INF,
+ )
+ # write-back for this tile
+ out_ptr = OutPtrs + pid_t * stride_om
+ tl.store(out_ptr, acc, mask=out_n_mask)
+ if HAS_OUT_MX_SCALE:
+ out_scale_ptr = OutScalePtrs + pid_t * stride_omxs
+ tl.store(out_scale_ptr, acc_scale, mask=out_n_mask_scale)
+ XPtrs += BLOCK_N * stride_xn
+ OutPtrs += BLOCK_N_OUT * stride_on
+ if HAS_IN_MX_SCALE:
+ XScalePtrs += BLOCK_N // 32 * stride_xn
+ if HAS_OUT_MX_SCALE:
+ OutScalePtrs += BLOCK_N_OUT // 32 * stride_xn
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..b964897603406e0a706306ba936591db03e69dc3
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags.py
@@ -0,0 +1,303 @@
+# isort: off
+# fmt: off
+from dataclasses import dataclass
+import triton
+from vllm.kvprune.triton_kernels.target_info import get_cdna_version
+import torch
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+
+
+@dataclass
+class OptFlags:
+ block_m: int
+ block_n: int
+ block_k: int
+ num_warps: int
+ num_stages: int
+ group_m: int
+ xcd_swizzle: int
+ w_cache_modifier: str
+ split_k: int
+ is_persistent: bool
+ fused_scatter: bool
+ idle_sms: int
+ epilogue_subtile: int | None
+ arch: str
+ target_kernel_kwargs: dict
+
+ def __post_init__(self):
+ if self.fused_scatter and self.split_k != 1:
+ raise ValueError("Not supported")
+
+
+def make_default_opt_flags_amd(
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ precision_config,
+ m,
+ n,
+ k,
+ routing_data,
+ can_use_persistent_tma,
+ can_use_fused_scatter,
+ enforce_bitwise_invariance,
+ epilogue_effective_itemsize,
+ constraints,
+):
+ """Choose default matmul tuning flags for the AMD (HIP) backend.
+
+ `make_opt_flags` calls this when the active Triton backend is "hip".
+ Heuristics pick block sizes, split-k, warp count and cache modifier;
+ entries of `constraints` override the corresponding heuristic.
+
+ `out_dtype`, `k`, `can_use_persistent_tma`, `can_use_fused_scatter` and
+ `epilogue_effective_itemsize` are accepted but not read by this function.
+
+ Returns:
+ An `OptFlags` instance.
+
+ Raises:
+ AssertionError: if `constraints` contains a key outside
+ `constraints_supported`, or if the resulting flags differ from a
+ non-None constraint value.
+ """
+ constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+ # tokens per expert
+ if routing_data is None:
+ tokens_per_expt = m
+ elif routing_data.expected_tokens_per_expt is None:
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+ else:
+ tokens_per_expt = routing_data.expected_tokens_per_expt
+
+ is_cdna4 = get_cdna_version() == 4
+ # block_m
+ # A falsy constraint value (None or 0) falls through to the heuristics.
+ if constraints.get("block_m", None):
+ block_m = constraints["block_m"]
+ elif enforce_bitwise_invariance:
+ block_m = 256 if is_cdna4 else 128
+ elif tokens_per_expt >= 512 and n >= 2048:
+ block_m = 256 if is_cdna4 else 128
+ elif is_cdna4 and m >= 512:
+ block_m = 128
+ else:
+ block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
+
+ if routing_data is not None:
+ grid_m = routing_data.n_blocks(m, block_m)
+ else:
+ grid_m = triton.cdiv(m, block_m)
+ # group_m:
+ group_m = 4
+ # number of xcds
+ # Hard-coded here; also used as the xcd_swizzle value.
+ num_xcds = 8
+ xcd_swizzle = num_xcds
+ # block_nk:
+ block_n, block_k = opt_flags_amd.compute_block_nk(
+ n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
+ )
+ # Replace block_k if provided in constraints.
+ # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
+ if constraints.get("block_k", None) is not None:
+ block_k = constraints["block_k"]
+ if constraints.get("block_n", None) is not None:
+ block_n = constraints["block_n"]
+ is_persistent = constraints.get("is_persistent", False)
+ # split_k:
+ if constraints.get("split_k", None) is not None:
+ split_k = constraints["split_k"]
+ elif is_persistent or enforce_bitwise_invariance:
+ split_k = 1
+ else:
+ grid_size = grid_m * ((n + block_n - 1) // block_n)
+ # NOTE(review): queries device index 0, not the current device -- confirm for multi-GPU hosts.
+ n_cu = torch.cuda.get_device_properties(0).multi_processor_count
+ split_k = max(1, n_cu // grid_size)
+ # w_cache_modifier:
+ w_cache_modifier = ".cg" if block_m <= 32 else None
+ # num_warps, num_stages
+ num_warps = 2 if (m is not None and m <= 16) else 8
+ num_stages = 2
+ # AMD-specific
+ target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
+ epilogue_subtile = constraints.get('epilogue_subtile', None)
+ if epilogue_subtile is None:
+ epilogue_subtile = 1
+ ret = OptFlags(
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ group_m=group_m,
+ xcd_swizzle=xcd_swizzle,
+ w_cache_modifier=w_cache_modifier,
+ split_k=split_k,
+ is_persistent=is_persistent,
+ fused_scatter=constraints.get('fused_scatter', False),
+ idle_sms=0,
+ epilogue_subtile=epilogue_subtile,
+ arch=None,
+ target_kernel_kwargs=target_kernel_kwargs,
+ )
+ # check constraints
+ # Every non-None constraint must be reflected verbatim in the returned flags.
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+ return ret
+
+def make_default_opt_flags_nvidia(
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ precision_config,
+ m,
+ n,
+ k,
+ routing_data,
+ can_use_persistent_tma,
+ can_use_fused_scatter,
+ enforce_bitwise_invariance,
+ epilogue_effective_itemsize,
+ constraints,
+):
+ """Choose default matmul tuning flags for the NVIDIA (CUDA) backend.
+
+ `make_opt_flags` calls this when the active Triton backend is "cuda".
+ Heuristics (delegated mostly to `opt_flags_nvidia`) pick block sizes,
+ persistence, split-k, epilogue subtiling, stage count and warp count;
+ entries of `constraints` override the corresponding heuristic.
+
+ Returns:
+ An `OptFlags` instance.
+
+ Raises:
+ AssertionError: if `constraints` contains a key outside
+ `constraints_supported`, if no candidate yields at least one pipeline
+ stage, or if the resulting flags differ from a non-None constraint.
+ """
+ # Note: "block_n" is not an accepted constraint on this backend.
+ constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+ # tokens per expert
+ if routing_data is None:
+ tokens_per_expt = m
+ elif routing_data.expected_tokens_per_expt is None:
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+ else:
+ tokens_per_expt = routing_data.expected_tokens_per_expt
+ # pid swizzling
+ group_m = 8
+ xcd_swizzle = 1
+ # block_m
+ # A falsy constraint value (None or 0) falls through to the heuristics.
+ if constraints.get("block_m", None):
+ block_m = constraints["block_m"]
+ elif enforce_bitwise_invariance:
+ block_m = 128
+ else:
+ block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+ # block n
+ # arch is always None here, so the `int(arch[2:-1])` check below is never evaluated.
+ arch = None
+ block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
+ # is_persistent
+ grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
+ # NOTE(review): queries device index 0, not the current device -- confirm for multi-GPU hosts.
+ n_sms = torch.cuda.get_device_properties(0).multi_processor_count
+ tiles_per_sm = grid_size / n_sms
+ supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
+ if constraints.get("is_persistent", None) is not None:
+ is_persistent = constraints["is_persistent"]
+ else:
+ has_simple_epilogue = precision_config.max_num_imprecise_acc is None
+ is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
+ # TEMP CHANGE
+ if precision_config.act_scale is not None or precision_config.out_scale is not None:
+ is_persistent = False
+ # block k
+ if constraints.get("block_k", None) is not None:
+ block_k = constraints["block_k"]
+ else:
+ block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
+ # split_k
+ if constraints.get("split_k", None) is not None:
+ split_k = constraints["split_k"]
+ elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+ split_k = 1
+ else:
+ estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
+ split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
+ if split_k > 1:
+ # With split_k, results are written in f32. Use that for the following computations.
+ out_dtype = torch.float32
+ compute_num_stages_args = (
+ precision_config,
+ is_persistent,
+
+ block_m,
+ block_n,
+ block_k,
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ )
+
+ if constraints.get("epilogue_subtile", None) is not None:
+ subtiles_to_check = [constraints["epilogue_subtile"]]
+ else:
+ subtiles_to_check = [1, 2, 4]
+ # Keep the candidate subtile factor that yields the most stages; ties keep the earlier candidate.
+ num_stages = -1
+ for ep in subtiles_to_check:
+ ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
+ if ns > num_stages:
+ epilogue_subtile, num_stages = ep, ns
+ assert num_stages >= 1
+ # A falsy num_stages constraint (None or 0) keeps the computed value.
+ if constraints.get("num_stages", None):
+ num_stages = constraints["num_stages"]
+ # fused scatter scratchpad
+ if constraints.get("fused_scatter", None) is not None:
+ fused_scatter = constraints["fused_scatter"]
+ else:
+ fused_scatter = can_use_fused_scatter and split_k == 1
+ # Handshake with the HBM swizzling
+ num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
+ ret = OptFlags(
+ block_m=block_m,
+ block_n=block_n,
+ block_k=block_k,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ fused_scatter=fused_scatter,
+ group_m=group_m,
+ xcd_swizzle=xcd_swizzle,
+ w_cache_modifier=None,
+ split_k=split_k,
+ is_persistent=is_persistent,
+ epilogue_subtile=epilogue_subtile,
+ arch=arch,
+ target_kernel_kwargs=dict(),
+ idle_sms=constraints.get("idle_sms", 0),
+ )
+ # check constraints
+ # Every non-None constraint must be reflected verbatim in the returned flags.
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+ return ret
+
+# --------------
+# User Interface
+# --------------
+
+# Constraints registered through update_opt_flags_constraints(); read by make_opt_flags().
+_opt_flags_constraints: dict = dict()
+# Manual override installed by set_opt_flags(); when not None, make_opt_flags() returns it unchanged.
+_opt_flags: OptFlags | None = None
+
+def update_opt_flags_constraints(constraints: dict[str, int]):
+ global _opt_flags_constraints
+ _opt_flags_constraints.update(constraints)
+
+def reset_opt_flags_constraints():
+ global _opt_flags_constraints
+ _opt_flags_constraints = dict()
+
+def set_opt_flags(opt_flags: OptFlags):
+ global _opt_flags
+ assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
+ assert not _opt_flags, "opt_flags already set; please reset to None first"
+ _opt_flags = opt_flags
+
+class InapplicableConstraint(Exception):
+ pass
+
+def make_opt_flags(
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ precision_config,
+ m,
+ n,
+ k,
+ routing_data,
+ can_use_persistent_tma,
+ can_use_fused_scatter,
+ epilogue_effective_itemsize,
+):
+ if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
+ raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
+ if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
+ raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+ enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
+ if _opt_flags is not None:
+ assert not _opt_flags_constraints
+ return _opt_flags
+ args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
+ routing_data, can_use_persistent_tma, can_use_fused_scatter,
+ enforce_bitwise_invariance, epilogue_effective_itemsize,
+ _opt_flags_constraints]
+ backend = triton.runtime.driver.active.get_current_target().backend
+ if backend == "hip":
+ return make_default_opt_flags_amd(*args)
+ if backend == "cuda":
+ return make_default_opt_flags_nvidia(*args)
+ assert False
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccc84c5f555263723ce9f268ebabf48f56656b41
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
@@ -0,0 +1,37 @@
+import torch
+import triton
+from vllm.kvprune.triton_kernels.target_info import get_cdna_version
+from vllm.kvprune.triton_kernels.tensor import bitwidth
+
+
+def compute_block_nk(
+ n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
+):
+ lhs_width = bitwidth(lhs_dtype) / 8
+ rhs_width = bitwidth(rhs_dtype) / 8
+
+ # block_n:
+ n_cu = torch.cuda.get_device_properties(0).multi_processor_count
+ if n is not None:
+ if n <= 128 and (n & (n - 1)) == 0:
+ block_n = n
+ else:
+ block_n = max(
+ 32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))
+ )
+ elif block_m > 64:
+ block_n = 256
+ else:
+ block_n = 128
+
+ if get_cdna_version() == 4 and block_m == 128:
+ block_n = 512
+
+ # block_k needs to match the cacheline size (128B)
+ block_k = int(128 // min(lhs_width, rhs_width))
+
+ # TODO: block_k = 128 seems to work better for now.
+ # perhaps due to increased number of k loops to pipeline
+ if precision_config.weight_scale is not None and get_cdna_version() != 4:
+ block_k = 128
+ return block_n, block_k
diff --git a/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
new file mode 100644
index 0000000000000000000000000000000000000000..29fffd9410f987a83248e50879befa09b85df938
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
@@ -0,0 +1,119 @@
+import torch
+import triton
+from vllm.kvprune.triton_kernels import target_info
+from vllm.kvprune.triton_kernels.tensor import get_layout, bitwidth, FP4
+from vllm.kvprune.triton_kernels.tensor_details.layout import HopperMXScaleLayout
+from vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import (
+ MXFP_BLOCK_SIZE,
+)
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+ if routing_data is not None:
+ grid_m = routing_data.n_blocks(m, block_m)
+ else:
+ grid_m = triton.cdiv(m, block_m)
+ grid_n = (n + block_n - 1) // block_n
+ return grid_m * grid_n
+
+
+def compute_block_n(n: int, arch, precision_config):
+ # block_n:
+ layout = get_layout(precision_config.weight_scale)
+ if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
+ return 128
+ elif precision_config.max_num_imprecise_acc is None and n > 128:
+ return 256
+ else:
+ return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(
+ m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config
+):
+ lhs_width = bitwidth(lhs_dtype)
+ rhs_width = bitwidth(rhs_dtype)
+ # block_k needs to match the cacheline size (1024 bits)
+ block_k = int(1024 // min(lhs_width, rhs_width))
+ has_native_mxfp = target_info.cuda_capability_geq(10, 0)
+ if rhs_width == 4 and not has_native_mxfp:
+ block_k = 128
+ elif k is not None:
+ block_k = max(32, min(triton.next_power_of_2(k), block_k))
+ has_mx_weight_scale = (
+ precision_config is not None and precision_config.weight_scale is not None
+ )
+ if has_native_mxfp and is_persistent and has_mx_weight_scale:
+ block_k = min(block_k, 128)
+ return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+ device_props = torch.cuda.get_device_properties(0)
+ n_sms = device_props.multi_processor_count
+ split_k = n_sms // grid_size
+ if k is not None:
+ # avoid split_k for small k
+ num_block_k = triton.cdiv(k, block_k)
+ split_k = min(split_k, num_block_k // 4)
+ split_k = max(split_k, 1)
+ return split_k
+
+
+def compute_num_warps(block_m, block_n, precision_config):
+ layout = get_layout(precision_config.weight_scale)
+ if isinstance(layout, HopperMXScaleLayout):
+ return layout.num_warps
+ return max(block_m * block_n // 4096, 4)
+
+
+def compute_num_stages(
+ precision_config,
+ is_persistent,
+ block_m,
+ block_n,
+ block_k,
+ out_dtype,
+ lhs_dtype,
+ rhs_dtype,
+ epilogue_subtile,
+ epilogue_effective_itemsize,
+):
+ """Estimate how many pipeline stages fit in shared memory, capped at 4.
+
+ Returns 3 immediately when `precision_config.max_num_imprecise_acc` is set.
+ Otherwise a per-stage size estimate (`stage_size`) is built from the lhs and
+ rhs tile sizes plus the adjustments below, and the result is
+ `min(4, smem_capacity // stage_size)`, where `smem_capacity` starts from
+ `shared_memory_per_block_optin` of device 0.
+ """
+ if precision_config.max_num_imprecise_acc is not None:
+ return 3
+ # Bytes per weight element; fractional for sub-byte types.
+ weight_size = bitwidth(rhs_dtype) / 8
+ stage_size = (
+ block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
+ )
+ # NOTE(review): queries device index 0, not the current device -- confirm for multi-GPU hosts.
+ device_props = torch.cuda.get_device_properties(0)
+ smem_capacity = device_props.shared_memory_per_block_optin
+ has_native_mxfp = target_info.cuda_capability_geq(10, 0)
+ if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None:
+ if rhs_dtype == FP4:
+ # 4-bit e2m1 weights are padded 2x
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
+ stage_size += block_k * block_n * weight_size
+
+ if is_persistent:
+ # Per-stage wait barrier
+ stage_size += 8
+ if target_info.cuda_capability_geq(10, 0):
+ acc_size = epilogue_effective_itemsize or out_dtype.itemsize
+ else:
+ acc_size = out_dtype.itemsize
+ if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
+ acc_block_n = block_n // epilogue_subtile
+ else:
+ acc_block_n = block_n
+ # pipelined TMA store local to global, or
+ # pipelined layout conversion before store of the accumulator
+ # note: layout conversion has some padding
+ smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
+ if precision_config.weight_scale is not None:
+ # mx scales
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
+ elif has_native_mxfp:
+ # mx scales
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
+ # stage_size may be a float (weight_size is), hence the int() before the floor division.
+ num_stages = min(4, smem_capacity // int(stage_size))
+ return num_stages
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics.py b/vllm/kvprune_legacy_save/triton_kernels/numerics.py
new file mode 100644
index 0000000000000000000000000000000000000000..024d3fcf0b819646a485596070b14c7a0a2e17ed
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/numerics.py
@@ -0,0 +1,42 @@
+import torch
+from dataclasses import dataclass
+
+# Largest finite magnitudes of the 8-bit float formats used by flexpoint.
+# numerics_details/flexpoint.py wraps them as tl.constexpr values and returns them
+# for tl.float8e5, tl.float8e4nv and tl.float8e4b8 respectively.
+MAX_FINITE_FLOAT8E5 = 57344.0
+MAX_FINITE_FLOAT8E4NV = 448.0
+MAX_FINITE_FLOAT8E4B8 = 240.0
+
+
+@dataclass(frozen=True)
+class BaseFlexData:
+ dtype: torch.dtype | None = None
+
+ def view(self, x: torch.Tensor):
+ if self.dtype is None:
+ return x
+ return x.view(self.dtype)
+
+ def reinterpret(self, x):
+ if self.dtype is None or x.dtype.itemsize > 1:
+ return x
+ return x.view(self.dtype)
+
+
+@dataclass(frozen=True)
+class InFlexData(BaseFlexData):
+ scale: torch.Tensor | None = None
+
+ @property
+ def is_per_batch(self):
+ return False if self.scale is None else len(self.scale) > 1
+
+
+@dataclass(frozen=True)
+class OutFlexData(BaseFlexData):
+ expected_scale: torch.Tensor | None = None
+ actual_scale: torch.Tensor | None = None
+ checksum_scale: torch.Tensor | None = None
+
+ def __iter__(self):
+ yield self.expected_scale
+ yield self.actual_scale
+ yield self.checksum_scale
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics_details/flexpoint.py b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/flexpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..194fdb87295cd9bd028c31d216eafae43a7cec14
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/flexpoint.py
@@ -0,0 +1,204 @@
+from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+import triton
+import triton.language as tl
+from vllm.kvprune.triton_kernels.target_info import cuda_capability_geq
+
+# -------------------------------
+# Kernels stuff
+# -------------------------------
+
+TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5)
+TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV)
+TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8)
+TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750)
+TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0)
+
+TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925) # 0x1.24924Ap-16
+TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925) # 0x1.24924Ap-9
+TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889) # 0x1.111112p-8
+TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925) # 0x1.24924Ap-1
+TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008) # 0x1.004010p-16
+
+
+@triton.jit
+def max_finite(dtype):
+ if dtype == tl.constexpr(tl.float8e5):
+ return TL_MAX_FINITE_FLOAT8E5
+ elif dtype == tl.constexpr(tl.float8e4nv):
+ return TL_MAX_FINITE_FLOAT8E4NV
+ elif dtype == tl.constexpr(tl.float8e4b8):
+ return TL_MAX_FINITE_FLOAT8E4B8
+ elif dtype == tl.constexpr(tl.float8e4b15):
+ return TL_MAX_FINITE_FLOAT8E4B15
+ elif dtype == tl.constexpr(tl.float16):
+ return TL_MAX_FINITE_FLOAT16
+ else:
+ tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
+
+
+@triton.jit
+def rcp_max_finite(dtype):
+ if dtype == tl.constexpr(tl.float8e5):
+ return TL_RCP_MAX_FINITE_FLOAT8E5
+ elif dtype == tl.constexpr(tl.float8e4nv):
+ return TL_RCP_MAX_FINITE_FLOAT8E4NV
+ elif dtype == tl.constexpr(tl.float8e4b8):
+ return TL_RCP_MAX_FINITE_FLOAT8E4B8
+ elif dtype == tl.constexpr(tl.float8e4b15):
+ return TL_RCP_MAX_FINITE_FLOAT8E4B15
+ elif dtype == tl.constexpr(tl.float16):
+ return TL_RCP_MAX_FINITE_FLOAT16
+ else:
+ tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
+
+
+@triton.jit
+def sm86_min_nan_xorsign_abs_f32(a, b):
+ """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
+
+ Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
+ NaN inputs are propagated to the output.
+
+ Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
+ """
+ tl.static_assert(
+ cuda_capability_geq(8, 6),
+ "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+",
+ )
+ tl.static_assert(
+ a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+ tl.static_assert(
+ b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+
+ return tl.inline_asm_elementwise(
+ """{
+ min.NaN.xorsign.abs.f32 $0, $1, $2;
+ }""",
+ "=r,r,r",
+ [a, b],
+ dtype=tl.float32,
+ is_pure=True,
+ pack=1,
+ )
+
+
+@triton.jit
+def sm86_max_nan_xorsign_abs_f32(a, b):
+ """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
+
+ Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
+ NaN inputs are propagated to the output.
+
+ Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
+ """
+ tl.static_assert(
+ cuda_capability_geq(8, 6),
+ "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+",
+ )
+ tl.static_assert(
+ a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+ tl.static_assert(
+ b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs"
+ )
+
+ return tl.inline_asm_elementwise(
+ """{
+ max.NaN.xorsign.abs.f32 $0, $1, $2;
+ }""",
+ "=r,r,r",
+ [a, b],
+ dtype=tl.float32,
+ is_pure=True,
+ pack=1,
+ )
+
+
+@triton.jit
+def load_scale(scale_ptr):
+ return 1.0 if scale_ptr is None else tl.load(scale_ptr)
+
+
+@triton.jit
+def flex_to_float(x, scale_ptr):
+ scale = load_scale(scale_ptr)
+ return x.to(tl.float32) * scale
+
+
+@triton.jit
+def clip(x, limit):
+ res = tl.minimum(x, limit)
+ res = tl.maximum(-limit, res)
+ return res
+
+
+@triton.jit
+def nan_propagating_absmax_reduce(x, axis=None):
+ if cuda_capability_geq(8, 6):
+ # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
+ x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32)
+ # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
+ x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
+ else:
+ # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
+ masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
+ x_absmax = tl.max(masked_abs_x, axis)
+
+ return x_absmax
+
+
+@triton.jit
+def compute_scale(x, Out):
+ x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True))
+
+ # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
+ # We use integer minimum because NaNs are above +inf in integer representation.
+ x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True)
+ RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty)
+ return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30)
+
+
+@triton.jit
+def update_scale(x, scale_ptr, Out) -> None:
+ if scale_ptr is not None:
+ scale = compute_scale(x, Out)
+ tl.atomic_max(scale_ptr, scale, sem="relaxed")
+
+
+@triton.jit
+def float_to_flex(
+ x,
+ expected_scale_ptr_or_val,
+ actual_scale_ptr,
+ checksum_scale_ptr,
+ mask,
+ Out,
+ saturate_infs: tl.constexpr,
+):
+ if expected_scale_ptr_or_val is not None:
+ if expected_scale_ptr_or_val.dtype.is_ptr():
+ invscale = 1.0 / tl.load(expected_scale_ptr_or_val)
+ else:
+ invscale = 1.0 / expected_scale_ptr_or_val
+ else:
+ invscale = 1.0
+ if checksum_scale_ptr is not None:
+ x_int32 = x.to(tl.int32, bitcast=True)
+ zero = tl.cast(0.0, tl.int32)
+ if mask is not None:
+ x_int32 = tl.where(mask, x_int32, zero)
+ checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0)
+ tl.atomic_add(checksum_scale_ptr, checksum_local)
+ if mask is not None:
+ if actual_scale_ptr is not None:
+ x = tl.where(mask, x, 0.0)
+ update_scale(x, actual_scale_ptr, Out)
+ x = x * invscale
+ # if expected_scale_ptr_or_val is not None, we applied flexpoint scale. We only want to clip in this case.
+ if expected_scale_ptr_or_val is not None:
+ if saturate_infs:
+ CLIP_VALUE = max_finite(Out.dtype.element_ty)
+ x = clip(x, CLIP_VALUE)
+ return x
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp.py b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c69c83c1dd77668ae80cbee0f21bafc5767815
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp.py
@@ -0,0 +1,303 @@
+# isort: off
+# fmt: off
+from enum import Enum
+import triton
+import torch
+import torch.nn.functional as F
+from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
+from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, MXFP_BLOCK_SIZE, _quantize_mxfp8_fn
+
+# -----------------------------------------------------------------------------
+# Dequantization / Quantization Utilities
+# -----------------------------------------------------------------------------
+
+
+class DequantScaleRoundingMode(Enum):
+ ROUND_UP = 0
+ ROUND_DOWN = 1
+
+
+def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
+ """
+ Convert the src weights to mx format. The src weight is quantized along the axis dimension.
+
+ If out_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
+ Note that this means the quantized (`axis`) dimension of the output will be half of the logical dimension.
+
+ If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8 values stored
+ in their respective formats.
+ """
+ ndim = src_tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ axis = axis if axis >= 0 else axis + ndim
+ # downcast
+ src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
+ is_fp4 = out_quant_type == torch.uint8
+ is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
+ assert is_fp4 or is_fp8
+ divisor = 2 if is_fp4 else 1
+ L = src_tensor.shape[-1]
+ if is_fp4:
+ assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
+ out_shape = src_tensor.shape[:-1] + (L // divisor, )
+ out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
+
+ out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
+ out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
+
+ if src_tensor.numel() > 0:
+ kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
+ kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
+ kernel_scale = out_scale.view(-1, out_scale.shape[-1])
+
+ BLOCK_OUT_DIM = 128
+ BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+ grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
+ grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
+
+ _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
+ *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
+ *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
+ DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
+
+ out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
+ out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
+ return out_quant_tensor, out_scale
+
+
+def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
+ """
+ Upcasts an mxfp (packed) weight tensor back to float16, bfloat16 or float32.
+
+ The function assumes that the tensors were quantized along the given axis.
+ It permutes the tensor so that the quantized axis is last, reshapes to 2D,
+ launches the Triton upcast kernel, and then unpermutes back to the original order.
+ """
+ ndim = tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ axis = axis if axis >= 0 else axis + ndim
+ assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
+ f"Got {tensor.ndim=} and {scale.ndim=}")
+ # dtype checks
+ assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
+ f"Invalid tensor dtype {tensor.dtype=}"
+ assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
+ assert target_dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {target_dtype=}"
+ # upcast
+ logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
+ tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
+ scale = scale.transpose(axis, scale.ndim - 1).contiguous()
+ out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=target_dtype, device=tensor.device)
+ reshaped_out = out.view(-1, out.shape[-1])
+ reshaped_tensor = tensor.view(-1, tensor.shape[-1])
+ reshaped_scale = scale.view(-1, scale.shape[-1])
+ BLOCK_OUT_DIM = 128
+ BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+ blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
+ blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
+ _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
+ *reshaped_scale.stride(), reshaped_tensor,
+ *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
+ BLOCK_QUANT_DIM, num_warps=8)
+ out = out.transpose(axis, scale.ndim - 1).contiguous()
+ return out
+
+
+# ------------
+
+
+def right_shift_unsigned(x, shift):
+ # CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
+ return (x >> shift) & ((1 << (32 - shift)) - 1)
+
+
+def get_max_quant_val(dtype: torch.dtype):
+ d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
+ assert dtype in d
+ return d[dtype]
+
+
+def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
+ """
+ Converts the src tensor to the output format specified by out_quant_type.
+ axis: The axis along which the tensors are contiguous and quantization is applied.
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode.ROUND_UP or ROUND_DOWN (pass the enum member, not its int value).
+
+ Returns:
+ out_quant_tensor: Quantized tensor in mx format.
+ • For mxfp8, the output has the same shape as src_tensor.
+ • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
+ scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
+ Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
+ where L is the original length along that axis.
+ """
+ # This should probably be packed into its own tiny class
+ ndim = src_tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ assert src_tensor.dtype in {torch.float32, torch.bfloat16,
+ torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"
+
+ axis = axis if axis >= 0 else axis + ndim
+ is_fp4 = out_quant_type == torch.uint8
+ is_fp8 = "float8" in str(out_quant_type)
+ assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
+
+ device = src_tensor.device
+
+ # For mxfp4 conversion, we assume the contiguous axis length is even.
+ if is_fp4:
+ axis_shape = src_tensor.size(axis)
+ assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."
+
+ # Permute the tensor so that the contiguous axis becomes the last dimension.
+ src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
+ axis_shape = src.shape[-1]
+
+ # Pad the axis to be divisible by 32, in case it is not.
+ next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
+ pad_amount = next_multiple - axis_shape
+ padded_src = F.pad(src, (0, pad_amount))
+ valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
+ padded_axis_shape = padded_src.size(-1) # now divisible by 32
+
+ # --- Compute per-group maximums for scale ---
+ # Set padded entries to -1 so they don’t affect the max.
+ abs_f = torch.abs(padded_src)
+ abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
+ # Reshape the last dimension into groups of 32.
+ new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
+ abs_groups = abs_f.view(*new_shape)
+ # Compute maximum along the group dimension (of size 32).
+ max_val, _ = abs_groups.max(dim=-1, keepdim=True)
+
+ # Choose a max quantization value depending on type.
+ max_quant_val = get_max_quant_val(out_quant_type)
+ dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1)
+
+ # Convert to int to round the FP32 scale, prior to quantization!
+ ds_int = dequant_scale.view(torch.int32)
+ if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
+ ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
+ else:
+ ds_int_rounded = ds_int & 0x7F800000
+ # Reinterpret back as float32.
+ dequant_scale_rounded = ds_int_rounded.view(torch.float32)
+
+ # Compute the quantization scale.
+ quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)
+
+ # Quantize the tensor
+ orig_padded_shape = padded_src.shape
+ padded_src_groups = padded_src.view(*new_shape)
+ quant_tensor = padded_src_groups * quant_scale
+ # Reshape back to the original shape and trim padding
+ quant_tensor = quant_tensor.view(orig_padded_shape)
+ quant_tensor = quant_tensor[..., :axis_shape]
+
+ # Finally, convert the quantized tensor to the target format
+ if is_fp8:
+ # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
+ quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
+ out_weight = quant_tensor.to(out_quant_type)
+ else:
+ assert is_fp4, f"Invalid output quantization type {out_quant_type}"
+ # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
+ # First, reinterpret the quantized tensor bits.
+ q_int = quant_tensor.contiguous().view(torch.int32)
+ # Extract sign, exponent, and mantissa.
+ signs = q_int & 0x80000000
+ exponents = right_shift_unsigned(q_int, 23) & 0xFF
+ mantissas = q_int & 0x7FFFFF
+
+ E8_BIAS = 127
+ E2_BIAS = 1
+ # Adjust mantissas for subnormals.
+ mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
+ (E8_BIAS - exponents - 1), mantissas)
+ exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
+ e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
+ e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
+ e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape)
+
+ # Pack pairs of 4-bit values along the last dimension.
+ e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
+ evens = e2m1_value[..., 0]
+ odds = e2m1_value[..., 1]
+ out_weight = evens | (odds << 4) # shape: (..., axis_shape//2)
+
+ # --- Process and output the scale ---
+ dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1)
+ dq_scale = dq_scale.squeeze(-1)
+ out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
+ dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
+ return out_weight, dq_scale
+
+
+def cvt_e2m1_to_fp32(input_tensor):
+ assert input_tensor.dtype == torch.uint8
+
+ input_tensor = input_tensor.to(torch.int32)
+ evens = input_tensor & 0xF
+ odds = (input_tensor >> 4) & 0xF
+
+ vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6]
+ outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device)
+ outputs = torch.cat([outputs, -outputs])
+
+ even_floats = outputs[evens]
+ odd_floats = outputs[odds]
+ output_tensor = torch.stack([even_floats, odd_floats], dim=-1)
+ output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1)
+ return output_tensor
+
+
+def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
+ """
+ Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
+ axis: The axis along which dequantization is applied.
+
+ Returns:
+ out_weight: Tensor in the target format.
+ """
+
+ ndim = tensor.ndim
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
+ is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
+ assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"
+
+ # Permute the tensor and scale so that the quantization axis becomes the last dimension
+ axis = axis if axis >= 0 else axis + ndim
+ scale = scale.transpose(axis, scale.ndim - 1)
+ tensor = tensor.transpose(axis, tensor.ndim - 1)
+
+ dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32
+ if tensor.dtype == torch.uint8:
+ fp32_tensor = cvt_e2m1_to_fp32(tensor)
+ else:
+ fp32_tensor = tensor.to(torch.float32)
+
+ logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
+ axis_shape = fp32_tensor.size(-1)
+ padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
+ pad_size = padded_axis_shape - axis_shape
+ padded_tensor = F.pad(fp32_tensor, (0, pad_size))
+
+ new_axis_shape = padded_tensor.shape[-1]
+ new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
+ padded_tensor = padded_tensor.view(*new_shape)
+ dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1]
+ out_padded = padded_tensor * dq_scale_padded
+
+ # Flatten back and remove the padded tail
+ out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
+ out_tensor = out_padded[..., :axis_shape]
+
+ out_tensor = out_tensor.to(target_dtype).contiguous()
+ out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)
+
+ return out_tensor
+
+
+quantize_mxfp8_fn = _quantize_mxfp8_fn
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eac6467e2d8d49385106574ec073cf677c622e0
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
@@ -0,0 +1,158 @@
+import triton
+import triton.language as tl
+
+# fmt: off
+
+
+MXFP_BLOCK_SIZE = tl.constexpr(32)
+
+
+@triton.jit
+def _get_max_quant_val(dtype: tl.constexpr):
+ if dtype == tl.uint8:
+ return 6.0
+ elif dtype == tl.float8e5:
+ return 57344.0
+ elif dtype == tl.float8e4nv:
+ return 448.0
+ else:
+ tl.static_assert(False, f"Invalid {dtype=}")
+
+@triton.jit
+def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
+ DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
+ is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+ BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
+ BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
+
+ # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
+ f32_tensor = src_tensor.to(tl.float32)
+ abs_tensor = tl.abs(f32_tensor)
+ abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
+ abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+ max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
+ dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
+ if DEQUANT_SCALE_ROUNDING_MODE == 0:
+ # DequantScaleRoundingMode.ROUND_UP
+ # compute 2 ** ceil(log2(dequant_scale))
+ # Adding 0x007FFFFF increments the exponent by 1 unless the mantissa is all zeros.
+ # Corner case: an exponent of 0xFF will overflow, but such a value is already
+ # NaN, so we assume we don't care.
+ dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
+ else:
+ # DequantScaleRoundingMode.ROUND_DOWN
+ # compute 2 ** floor(log2(dequant_scale))
+ assert DEQUANT_SCALE_ROUNDING_MODE == 1
+ dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
+ dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
+ quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
+
+ f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+ quant_tensor = f32_tensor * quant_scale
+
+ # Reshape the tensors after scaling
+ quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+ # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
+ quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
+ dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
+
+ # First, we simply extract the exponent part of the scales and store the result
+ dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
+ # Now we must convert the tensors to the mx format.
+ if is_fp8:
+ out_tensor = quant_tensor.to(mx_tensor_dtype)
+ else:
+ quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
+ signs = quant_tensor & 0x80000000
+ exponents = (quant_tensor >> 23) & 0xFF
+ mantissas = (quant_tensor & 0x7FFFFF)
+
+ # 0.25 <= x < 0.75 maps to 0.5, a denormal number
+ E8_BIAS = 127
+ E2_BIAS = 1
+ # Move implicit bit 1 at the beginning to mantissa for denormals
+ adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
+ mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
+
+ # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
+
+ # Combine sign, exponent, and mantissa, while saturating
+ # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
+ e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
+ e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
+
+ e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
+ evens, odds = tl.split(e2m1_value)
+ out_tensor = evens | (odds << 4)
+
+ return out_tensor, dequant_scale_exponent
+
+@triton.jit
+def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
+ mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
+ src_ptr, stride_src_outer, stride_src_quant,
+ outer_dim, quant_dim,
+ BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
+ DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
+
+ tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
+ tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
+
+ # uint8 signifies two fp4 e2m1 values packed into a single byte
+ mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+ tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
+ f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
+
+ src_dtype: tl.constexpr = src_ptr.dtype.element_ty
+ tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
+ tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+
+ outer_block = tl.program_id(0).to(tl.int64)
+ quant_block = tl.program_id(1).to(tl.int64)
+
+ K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+ BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+ start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+ start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+ start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+ start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+ src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
+ mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
+ mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
+
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+ offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+ offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+ offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+
+ mask_src_quant = start_src_quant + offs_src_quant < quant_dim
+ mask_n = start_out + offs_outer < outer_dim
+ full_mask_src = mask_src_quant & mask_n
+
+ mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
+ full_mask_mxt = mask_mxt_quant & mask_n
+
+ scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+ full_scale_mask = scale_mask_k & mask_n
+
+ src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
+ mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
+ mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
+ src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
+
+ out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
+ DEQUANT_SCALE_ROUNDING_MODE)
+
+ tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
+ tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
+
+
+@triton.jit(repr=lambda _: "_dequantize_mxfp8")
+def _quantize_mxfp8_fn(input, mask, pid=None):
+ return _compute_quant_and_scale(input, mask, tl.float8e4nv)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5f027fa986c06f402405a4a5047b649b3e1bfe
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
@@ -0,0 +1,125 @@
+import triton
+import triton.language as tl
+
+from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
+
+
+# fmt: off
+# Kernel that expands an MX-format tensor (packed values plus one uint8 scale per
+# MXFP_BLOCK_SIZE elements along the quantized axis) into fp16 / bf16 / fp32.
+# Each program writes one [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM] output tile:
+#   program_id(0) picks the tile along the outer axis,
+#   program_id(1) picks the tile along the quantized axis.
+# out_ptr / mx_scale_ptr / mx_tensor_ptr are each addressed with their own
+# (outer, quant) stride pair; outer_dim and quant_dim bound the load/store masks.
+@triton.jit
+def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
+ stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
+ outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
+
+ tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
+ tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
+ # uint8 signifies two fp4 e2m1 values packed into a single byte
+ mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
+ dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
+ tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 or dst_dtype == tl.float32)
+ tl.static_assert(
+ mx_tensor_dtype == tl.uint8
+ or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
+ "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
+ tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+
+ # Determine if we are dealing with fp8 types.
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
+ is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
+ # Two fp4 values share one byte, so the packed tile is half as wide as the output tile.
+ K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
+ BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
+
+ # Compute starting indices for the quantized (packed) dimension and the outer dimension.
+ # Program ids are widened to int64 before being multiplied by block sizes and strides.
+ outer_block = tl.program_id(0).to(tl.int64)
+ quant_block = tl.program_id(1).to(tl.int64)
+
+ start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
+ start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
+ start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
+ start_out = outer_block * BLOCK_SIZE_OUT_DIM
+
+ # Advance the three base pointers to the first element of this program's tile.
+ mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
+ mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
+ out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
+
+ # Compute offsets and masks.
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
+ offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
+ offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
+ offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
+
+ mask_outer = start_out + offs_outer < outer_dim
+ mask_out_quant = start_out_quant + offs_out_quant < quant_dim
+ full_mask_out = mask_out_quant & mask_outer
+
+ # The packed and scale extents are quant_dim divided (rounding up) by their grouping factor.
+ mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
+ full_mask_src = mask_src_quant & mask_outer
+
+ mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+ full_scale_mask = mask_scale & mask_outer
+
+ tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
+ scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
+ out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
+
+ # Load the packed tensor and scale.
+ # No `other=` is given; lanes that are masked off here are never written by the masked store at the end.
+ tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
+ scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
+
+ # Upcast the scale to the destination type.
+ # The scale byte is shifted past the mantissa bits (7 for bf16, 23 for fp32) and the
+ # resulting bit pattern is reinterpreted as a float.
+ if dst_dtype == tl.bfloat16:
+ dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
+ else:
+ dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
+ if dst_dtype == tl.float16:
+ dst_scale = dst_scale.to(tl.float16)
+
+ # Now upcast the tensor.
+ # fp32 outputs are decoded in bf16 first and converted to fp32 by the final `.to(dst_dtype)`.
+ intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
+ if is_fp8:
+ dst_tensor = tensor.to(intermediate_dtype)
+ if tensor.dtype == tl.float8e5:
+ from_e_bits: tl.constexpr = 5
+ from_m_bits: tl.constexpr = 2
+ to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
+ to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
+
+ # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
+ # Where every source exponent bit is set, force every destination exponent bit on as well.
+ non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
+ non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
+ dst_tensor = tl.where(
+ (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
+ (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
+ dst_tensor,
+ )
+ else:
+ # NOTE(review): the static_assert above also admits mx_tensor_dtype == dst_dtype, which
+ # reaches this branch with is_fp4 == False -- confirm that case is meant to be rejected here.
+ assert is_fp4
+ dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
+ dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
+ dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
+ # e2m1
+ # Low nibble -> x0, high nibble -> x1; bits 0x08 / 0x80 are moved to bit 15 (the sign bit).
+ em0 = tensor & 0x07
+ em1 = tensor & 0x70
+ x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
+ x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
+ # Three cases:
+ # 1) x is normal and non-zero: Correct bias
+ x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
+ x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
+ # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+ x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
+ x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
+ # 3) x is zero, do nothing
+ # Interleaving restores element order: each byte yields its low-nibble value, then its high-nibble value.
+ dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)
+ dst_tensor = dst_tensor.to(dst_dtype)
+
+ # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
+ dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
+ dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
+ scale = scale.reshape(dst_scale.shape)
+
+ out_tensor = dst_tensor * dst_scale
+ # Correct any NaNs encoded via the scale.
+ # A raw scale byte of 0xFF turns the whole group it covers into NaN.
+ out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
+ out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
+ tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/proton_opts.py b/vllm/kvprune_legacy_save/triton_kernels/proton_opts.py
new file mode 100644
index 0000000000000000000000000000000000000000..a187eecc2d66659c278be3668e7865ee8a785694
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/proton_opts.py
@@ -0,0 +1,19 @@
+# proton options
+
+import os
+
+# Cached answer for launch_metadata_allow_sync(); None means "not decided yet".
+_launch_metadata_allow_sync = None
+
+
+def launch_metadata_allow_sync():
+    """Return whether launch-metadata code is allowed to synchronize.
+
+    The first call reads the PROTON_LAUNCH_METADATA_NOSYNC environment
+    variable (sync is disallowed only when it equals "1") and caches the
+    answer in the module-level flag; later calls return the cached value.
+    """
+    global _launch_metadata_allow_sync
+    cached = _launch_metadata_allow_sync
+    if cached is None:
+        cached = os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") != "1"
+        _launch_metadata_allow_sync = cached
+    return cached
+
+
+# Overwrite the module-level `_launch_metadata_allow_sync` flag with `allow_sync`.
+# Because launch_metadata_allow_sync() only consults the environment while the flag
+# is None, a value set here takes precedence over PROTON_LAUNCH_METADATA_NOSYNC.
+def set_launch_metadata_allow_sync(allow_sync: bool):
+ global _launch_metadata_allow_sync
+ _launch_metadata_allow_sync = allow_sync
diff --git a/vllm/kvprune_legacy_save/triton_kernels/reduction_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/reduction_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/reduction_details/reduce_bitmatrix.py b/vllm/kvprune_legacy_save/triton_kernels/reduction_details/reduce_bitmatrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..398482c321e119dfeb059fc420341ca58d1cceb1
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/reduction_details/reduce_bitmatrix.py
@@ -0,0 +1,133 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def vpopc(x):
+ """
+ Vertical popcount
+ Input x : uint32[..., N]
+ Output y : uint32[..., 32]
+ semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
+ credits: @apgoucher
+
+ The reduction runs in three stages that keep several counters packed in one
+ 32-bit word: 4-bit fields first, then 8-bit fields, then one counter per word.
+ """
+
+ tl.static_assert(
+ x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers"
+ )
+
+ BLOCK_N: tl.constexpr = x.shape[-1] # summation axis
+ BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches
+ # Stage 1 adds at most 8 one-bit values per field, so a 4-bit field cannot overflow.
+ if BLOCK_N >= 8:
+ sa1: tl.constexpr = 8
+ else:
+ sa1: tl.constexpr = BLOCK_N
+ # create 8-way sums in 4-bit fields:
+ y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
+ y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
+ y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4]
+ # Stage 2 adds at most 16 stage-1 counters (each <= 8), which fits an 8-bit field.
+ if BLOCK_N >= 128:
+ sa2: tl.constexpr = 16
+ else:
+ sa2: tl.constexpr = BLOCK_N // sa1
+ # create 128-way sums in 8-bit fields:
+ # NOTE(review): the reshapes below require BLOCK_N to be divisible by sa1 * sa2 --
+ # presumably BLOCK_N is always a power of two; confirm against callers.
+ y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
+ y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0F0F0F0F
+ y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
+ # Stage 3 reduces whatever is left of the summation axis.
+ sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
+ # create N-way sums in 32-bit fields:
+ y = tl.reshape(y, [BATCHES, 1, sa3, 8])
+ y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000FF
+ y = tl.sum(y, 2) # [BATCHES, 4, 8]
+ # Restore the leading (batch) dimensions of x with a trailing axis of 32 bit positions.
+ y = tl.reshape(y, x.shape[:-1] + [32])
+ return y
+
+
+# Zero-fill kernel: program `pid` stores 0 into Ret[pid * BLOCK : (pid + 1) * BLOCK].
+# The store is unmasked, so the buffer must hold a whole number of BLOCK-sized chunks
+# for the launched grid.
+@triton.jit
+def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
+ pid = tl.program_id(0)
+ offs = pid * BLOCK + tl.arange(0, BLOCK)
+ tl.store(Ret + offs, 0)
+
+
+# Column-wise bit counting over a bit-packed matrix B.
+# Program (pid_m, pid_n) reads BLOCK_MM rows of the 32-bit word column `pid_n`,
+# splits them into TILE_SIZE groups of BLOCK_M rows, and for each group counts how many
+# rows have each of the 32 bits set (via vpopc).
+#   * Ret      receives the counts summed over all groups, accumulated with atomic adds.
+#   * Partials receives the per-group counts at [group index, bit column].
+# `shape_bm` may be a plain integer or a pointer to one; `shape_pn` is not read in this body.
+@triton.jit
+def _sum_bitmatrix_rows(
+ B,
+ shape_bm,
+ stride_bm: tl.constexpr,
+ stride_bn: tl.constexpr, # input bitmatrix
+ Ret,
+ Partials,
+ stride_pm: tl.constexpr,
+ stride_pn,
+ shape_pn, # outputs
+ BLOCK_MM: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ tl.static_assert(BLOCK_MM % BLOCK_M == 0)
+ TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
+ # When the row count arrives as a pointer, dereference it once up front.
+ if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
+ shape_bm = tl.load(shape_bm)
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
+ offs_n = pid_n * 32 + tl.arange(0, 32)
+ n_rows = shape_bm
+ # Rows at or beyond n_rows load as 0 and therefore contribute nothing to the counts.
+ bits = tl.load(
+ B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0
+ )
+ bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
+ ret = vpopc(bits) # [TILE_SIZE, 32]
+
+ offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
+
+ # Both writes are unmasked: Ret and Partials must cover every (group, column) the grid touches.
+ tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
+ tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
+
+
+def clear_sums(n_cols, device, MEMSET_BLOCK=512):
+    """Allocate and zero an int32 accumulator covering at least ``n_cols`` entries.
+
+    The buffer length is ``n_cols`` rounded up to a multiple of
+    ``MEMSET_BLOCK`` so that the unmasked memset kernel, launched with one
+    program per block, never writes out of bounds.
+
+    Returns the zeroed 1-D tensor on ``device``.
+    """
+    n_blocks = triton.cdiv(n_cols, MEMSET_BLOCK)
+    padded_len = n_blocks * MEMSET_BLOCK
+    sums = torch.empty((padded_len,), device=device, dtype=torch.int32)
+    _sum_bitmatrix_memset[(n_blocks,)](sums, MEMSET_BLOCK)
+    return sums
+
+
+def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
+    """Launch `_sum_bitmatrix_rows` over the bitmatrix ``x``.
+
+    Args:
+        x: bitmatrix object exposing ``shape``, ``shape_max``, ``stride`` and
+            ``storage.data`` (project type).
+        out_ret: int32 tensor of shape ``(n_cols,)`` that receives the
+            column totals through atomic adds; it is not cleared here.
+        partials_block_size: number of rows summarised by each row of the
+            partial-sum output. Required despite the ``None`` default.
+
+    Returns:
+        ``(out_ret, out_partials)`` where ``out_partials`` holds one row per
+        ``partials_block_size`` rows of ``x`` (sized from ``x.shape_max[0]``).
+    """
+    assert partials_block_size is not None
+    n_rows, n_cols = x.shape
+    n_rows_max = x.shape_max[0]
+    assert out_ret.shape == (n_cols,)
+
+    block_m = partials_block_size
+    # Each program covers enough row groups to span roughly 128 rows.
+    tile_size = max(1, 128 // block_m)
+    block_mm = block_m * tile_size
+
+    grid_m = triton.cdiv(n_rows_max, block_mm)
+    grid_n = triton.cdiv(n_cols, 32)
+
+    # Allocated column-major and then viewed transposed, i.e. (row groups, columns).
+    partials_storage = torch.empty(
+        (grid_n * 32, grid_m * tile_size),
+        device=out_ret.device,
+        dtype=torch.int32,
+    )
+    out_partials = partials_storage.transpose(0, 1)
+
+    launch_grid = (grid_m, grid_n)
+    _sum_bitmatrix_rows[launch_grid](
+        x.storage.data,
+        n_rows,
+        x.stride(0),
+        x.stride(1),
+        out_ret,
+        out_partials,
+        out_partials.stride(0),
+        out_partials.stride(1),
+        out_partials.shape[1],
+        BLOCK_M=block_m,
+        BLOCK_MM=block_mm,
+        num_warps=8,
+    )
+
+    # Drop the row groups that exist only because of grid rounding.
+    n_groups = triton.cdiv(n_rows_max, block_m)
+    return out_ret, out_partials[:n_groups, :]
diff --git a/vllm/kvprune_legacy_save/triton_kernels/routing.py b/vllm/kvprune_legacy_save/triton_kernels/routing.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd736f6f0867b95c67a3c857b4f0bcc80c79fc0
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/routing.py
@@ -0,0 +1,521 @@
+import torch
+import triton
+from dataclasses import dataclass, field
+from .routing_details._routing_compute import _combined_routing_compute
+from .routing_details._routing_compute import _combined_routing_memset
+from .routing_details._routing_compute import _routing_clear_bitmatrix
+from .routing_details._expt_data import _expt_data_memset
+from .routing_details._expt_data import _expt_data_compute
+from .target_info import is_hip
+
+
+@dataclass
+class GatherIndx:
+ """
+ Indices for an operation that performs:
+ Y = X[src_indx, :]
+ """
+
+ # the two arrays are inverse permutations: `dst_indx[src_indx] = arange(0, N)`
+ src_indx: torch.Tensor
+ dst_indx: torch.Tensor
+
+
+@dataclass
+class ScatterIndx:
+ """
+ Indices for an operation that performs:
+ Y[dst_indx, :] = X
+ """
+
+ # the two arrays are inverse permutations: `dst_indx[src_indx] = arange(0, N)`
+ src_indx: torch.Tensor
+ dst_indx: torch.Tensor
+
+
+@dataclass
+class ExptData:
+    """Per-expert metadata used to drive the expert-sorted matmul.
+
+    Every field may be ``None``; any tensor that is present must be int32,
+    which is checked on construction.
+    """
+
+    # hist[i]: number of tokens routed to expert i
+    hist: torch.Tensor
+    # token_offs_raw[i]: offset of the first token routed to expert i
+    # in an expert-sorted array
+    token_offs_raw: torch.Tensor
+    # token_offs_pad[block][i]: same offset, but computed as if each histogram
+    # entry were rounded up to the next multiple of `block`
+    token_offs_pad: dict[int, torch.Tensor]
+    # block_pid_map[block]: one value per `pid` launched by the matmul kernel
+    # when it runs with BLOCK_M=block:
+    #   - -1 when that `pid` has no work to do
+    #   - otherwise two int16 values packed into an int32: (1) the expert that
+    #     owns the tokens handled by this pid and (2) the block index within
+    #     that expert (think `pid_m` in a regular matmul)
+    # see `test_routing.py` for a reference implementation and more details
+    block_pid_map: dict[int, torch.Tensor]
+
+    def __post_init__(self):
+        # Plain tensor fields first, in declaration order.
+        for tensor in (self.hist, self.token_offs_raw):
+            if tensor is not None:
+                assert tensor.dtype == torch.int32
+        # Then every tensor stored in the per-block dictionaries.
+        for table in (self.token_offs_pad, self.block_pid_map):
+            if table is None:
+                continue
+            for tensor in table.values():
+                assert tensor.dtype == torch.int32
+
+
+@dataclass
+class RoutingData:
+    """Routing result handed to the expert matmul."""
+
+    gate_scal: torch.Tensor = field()
+    expt_hist: torch.Tensor = field()
+    n_expts_tot: int = field()
+    n_expts_act: int = field()
+    expt_data: ExptData = None
+
+    # Only used to keep perf annotations tidy: with expert sharding the real
+    # number of local tokens per expert changes from input to input, so this
+    # carries the "expected" figure instead.
+    expected_tokens_per_expt: int = field(default=None)
+
+    def n_blocks(self, n_rows, block_m):
+        """Return the block count used for ``n_rows`` rows tiled by ``block_m``.
+
+        With no more rows than experts the answer is ``n_rows``. Otherwise
+        it is ``cdiv(n_rows - n_expts_tot + 1, block_m) + n_expts_tot - 1``.
+        """
+        n_expts = self.n_expts_tot
+        if n_rows <= n_expts:
+            return n_rows
+        surplus_rows = max(n_rows - n_expts + 1, 0)
+        return triton.cdiv(surplus_rows, block_m) + n_expts - 1
+
+
+# --------------------------
+# sort tokens by expert
+# --------------------------
+
+
+# Autograd wrapper around the two fused routing kernels
+# (`_combined_routing_memset`, then `_combined_routing_compute`).
+# forward() returns (hist, topk_indx, gate_indx, gate_scal, token_offs_raw,
+# token_offs_pad, block_pid_map). backward() only produces a gradient for
+# `expt_scal`, derived from the incoming gradient of `gate_scal`.
+class SortTokens(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
+ # Block sizes handed to the kernels below.
+ HIST_BLOCK_M = 32
+ INDX_OFFS_BLOCK_M = 512
+ MEMSET_BLOCK = 1024
+ cdiv = triton.cdiv
+
+ device = expt_scal.device
+ dtype = expt_scal.dtype
+ # The "raw" token count comes from the bitmatrix, the "pad" count from expt_scal's first dimension.
+ n_tokens_raw, _ = bitmatrix.shape
+ n_tokens_pad, n_expts_act = expt_scal.shape
+ n_gates_pad = n_tokens_pad * n_expts_act
+
+ # Column totals plus per-HIST_BLOCK_M partial counts; only the first n_expts_tot totals are kept.
+ hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
+ hist = hist[:n_expts_tot]
+ assert hist.dtype == torch.int32
+ # scratchpad
+ expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
+ combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
+ # output
+ # topk_indx and gate_indx are the two halves of combined_indx, so they share one allocation.
+ topk_indx = combined_indx[:n_gates_pad]
+ gate_indx = combined_indx[n_gates_pad:]
+ gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
+
+ # Allocates the expert-data buffers and returns the grid sizes / constants for both launches.
+ (
+ token_offs_combined,
+ token_offs_raw,
+ token_offs_pad,
+ block_pid_map,
+ blocks1a,
+ blocks2a,
+ MEMSET_BLOCK_A,
+ HIST2_BLOCK_M,
+ block_m_log2_start,
+ block_m_num,
+ ) = _compute_expt_data_internal(hist, n_expts_tot, n_gates_pad)
+
+ # Each launch is a single grid: the "a" programs (expert data) plus the "b" programs (routing).
+ blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
+ blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
+
+ # First launch. The -1 passed after the buffer length is presumably the fill value
+ # for combined_indx -- confirm against `_combined_routing_memset`.
+ _combined_routing_memset[(blocks1a + blocks1b,)](
+ combined_indx,
+ n_gates_pad * 2,
+ -1,
+ MEMSET_BLOCK,
+ hist, #
+ expt_offs,
+ hist.shape[0],
+ n_expts_tot,
+ partial_hist, # inputs
+ partial_hist.shape[0],
+ partial_hist.stride(0),
+ partial_hist.stride(1), # outputs
+ token_offs_combined,
+ token_offs_combined.stride(0), #
+ blocks1a,
+ block_pid_map, #
+ block_m_log2_start,
+ SIZES=block_m_num,
+ BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
+ BLOCK_N=512,
+ BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
+ )
+
+ # The same buffer is handed to the second launch under a new name.
+ indx_offs = partial_hist
+
+ _combined_routing_compute[(blocks2a + blocks2b,)](
+ topk_indx,
+ gate_indx,
+ gate_scal, # outputs
+ expt_scal,
+ expt_indx,
+ indx_offs,
+ indx_offs.stride(0),
+ indx_offs.stride(1), # inputs
+ expt_offs,
+ n_tokens_raw, # input shape
+ HIST_BLOCK_M,
+ n_expts_act, # constants
+ hist,
+ token_offs_pad,
+ token_offs_pad.stride(0),
+ block_pid_map,
+ block_pid_map.stride(0), # outputs
+ block_m_log2_start,
+ block_m_num,
+ HIST2_BLOCK_M,
+ blocks2a, # etc.
+ )
+
+ # Stash what backward() needs: the shape of expt_scal and the gate permutation.
+ ctx.n_tokens_raw = n_tokens_raw
+ ctx.n_tokens_pad = n_tokens_pad
+ ctx.n_expts_act = n_expts_act
+ ctx.save_for_backward(gate_indx)
+ return (
+ hist,
+ topk_indx,
+ gate_indx,
+ gate_scal,
+ token_offs_raw,
+ token_offs_pad,
+ block_pid_map,
+ )
+
+ @staticmethod
+ def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
+ # Only the gradient of gate_scal (4th forward output) is used: it is indexed by
+ # gate_indx and reshaped to expt_scal's (n_tokens_pad, n_expts_act) layout.
+ # The other three forward inputs receive no gradient.
+ (gate_indx,) = ctx.saved_tensors
+ dgate_scal = dgate_scal[gate_indx]
+ dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
+ return dgate_scal, None, None, None
+
+
+# Functional entry point for SortTokens: returns the 7-tuple produced by
+# SortTokens.forward (hist, topk_indx, gate_indx, gate_scal, token_offs_raw,
+# token_offs_pad, block_pid_map).
+def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
+ return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
+
+
+# --------------------------
+# prune routing
+# --------------------------
+
+
+# Restricts a routing result to the first `n_expts_tot // simulated_ep` experts.
+# The bitmatrix storage is modified in place by `_routing_clear_bitmatrix`, then
+# `compaction` rebuilds expt_scal / expt_indx from the updated bitmatrix.
+# NOTE(review): no backward() is defined on this Function -- confirm gradients are
+# never expected to flow through prune_routing.
+class PruneRouting(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+ from .compaction import compaction
+
+ n_tokens_pad = expt_scal.shape[0]
+ # The expert count must split evenly across the simulated shards.
+ assert n_expts_tot % simulated_ep == 0
+ # One program per (padded) token row.
+ _routing_clear_bitmatrix[(n_tokens_pad,)](
+ bitmatrix.storage.data,
+ bitmatrix.storage.data.stride(0),
+ bitmatrix.storage.data.stride(1),
+ bitmatrix.storage.data.shape[1],
+ n_expts_tot // simulated_ep,
+ BLOCK_N=512,
+ )
+ # perform compaction to update expt_scal / expt_indx
+ expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
+ n_expts_tot = n_expts_tot // simulated_ep
+ # The caller's bitmatrix object is mutated: its last logical dimension shrinks to the kept expert count.
+ bitmatrix.shape[-1] = n_expts_tot
+ return expt_scal, expt_indx, bitmatrix
+
+
+# Functional entry point for PruneRouting: returns (expt_scal, expt_indx, bitmatrix)
+# as produced by PruneRouting.forward. `n_expts_tot` must be divisible by `simulated_ep`.
+def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
+ return PruneRouting.apply(
+ expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep
+ )
+
+
+# --------------------------
+# expt_data
+# --------------------------
+
+
+def log2_power_of_two(x):
+    """Return ``log2(x)`` for a positive integer ``x`` that is a power of two.
+
+    Fails an assertion (message "x must be a power of two") for any other value.
+    """
+    is_power_of_two = x > 0 and (x & (x - 1)) == 0
+    assert is_power_of_two, "x must be a power of two"
+    # For a power of two, the exponent is one less than the bit length.
+    return x.bit_length() - 1
+
+
+# log2 of the smallest block_m for which expert data is generated (2**4 == 16).
+block_m_log2_start = 4
+
+
+def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
+    """Allocate the expert-data buffers and work out the launch parameters.
+
+    No kernel is launched here; the returned tensors are uninitialised.
+
+    Returns a 10-tuple:
+        token_offs_combined: (block_m_num + 1, padded) int32 backing buffer
+        token_offs_raw:      row 0 of that buffer, first n_expts_tot + 1 entries
+        token_offs_pad:      rows 1.. of that buffer, first n_expts_tot + 1 columns
+        block_pid_map:       (block_m_num, max_n_tiles) int32 view
+        blocks1:             grid size for the memset kernel
+        blocks2:             grid size for the compute kernel
+        MEMSET_BLOCK, HIST2_BLOCK_M: block sizes for those kernels
+        block_m_log2_start, block_m_num: first block_m exponent and how many follow
+    """
+    MEMSET_BLOCK = 512
+    HIST2_BLOCK_M = 512
+    int32 = torch.int32
+    device = expt_hist.device
+
+    def round_up_to_memset_block(n):
+        return triton.cdiv(n, MEMSET_BLOCK) * MEMSET_BLOCK
+
+    # block_m takes every power of two from 16 up to 128, or up to 256 on HIP.
+    block_m_log2_end = 9 if is_hip() else 8
+    block_m_num = block_m_log2_end - block_m_log2_start
+
+    max_n_tiles = n_gates
+    if n_gates > n_expts_tot:
+        max_n_tiles = (
+            n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
+        )
+
+    # Row 0 holds the raw offsets, rows 1.. hold one padded-offset row per block_m.
+    token_offs_combined = torch.empty(
+        (block_m_num + 1, round_up_to_memset_block(n_expts_tot + 1)),
+        dtype=int32,
+        device=device,
+    )
+    token_offs_raw = token_offs_combined[0][: n_expts_tot + 1]
+    token_offs_pad = token_offs_combined[1:, : n_expts_tot + 1]
+
+    block_pid_map_storage = torch.empty(
+        (block_m_num, round_up_to_memset_block(max_n_tiles)),
+        dtype=int32,
+        device=device,
+    )
+    # The storage width is a multiple of MEMSET_BLOCK, so this division is exact.
+    memset_grid = torch.numel(block_pid_map_storage) // MEMSET_BLOCK
+    block_pid_map = block_pid_map_storage[:, :max_n_tiles]
+
+    blocks1 = memset_grid + block_m_num + 1
+    blocks2 = n_expts_tot * block_m_num
+    return (
+        token_offs_combined,
+        token_offs_raw,
+        token_offs_pad,
+        block_pid_map,
+        blocks1,
+        blocks2,
+        MEMSET_BLOCK,
+        HIST2_BLOCK_M,
+        block_m_log2_start,
+        block_m_num,
+    )
+
+
+def _unpack_into_dict(x):
+    """Split a stacked 2-D tensor into a ``{block_m: row}`` dictionary.
+
+    Row ``i`` of ``x`` is stored under the key ``2 ** (block_m_log2_start + i)``.
+    """
+    n_sizes = x.shape[0]
+    log2_sizes = range(block_m_log2_start, block_m_log2_start + n_sizes)
+    return {1 << log2_m: x[row, :] for row, log2_m in enumerate(log2_sizes)}
+
+
+# Build an ExptData from an expert histogram using the two Triton kernels
+# `_expt_data_memset` and `_expt_data_compute`.
+# Returns an all-None ExptData when `expt_hist` is None.
+def compute_expt_data(expt_hist, n_expts_tot, n_gates):
+ if expt_hist is None:
+ return ExptData(None, None, None, None)
+
+ # this just computes the kernel arguments:
+ # (the unpacked `block_m_log2_start` is a local that shadows the module-level constant)
+ (
+ token_offs_combined,
+ token_offs_raw,
+ token_offs_pad,
+ block_pid_map,
+ blocks1,
+ blocks2,
+ MEMSET_BLOCK,
+ HIST2_BLOCK_M,
+ block_m_log2_start,
+ block_m_num,
+ ) = _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates)
+
+ # Pass 1: fill the offset rows of token_offs_combined and reset block_pid_map.
+ _expt_data_memset[(blocks1,)](
+ expt_hist,
+ n_expts_tot, #
+ token_offs_combined,
+ token_offs_combined.stride(0), #
+ block_pid_map, #
+ block_m_log2_start,
+ SIZES=block_m_num,
+ BLOCK=MEMSET_BLOCK, # optimization parameters
+ num_warps=4,
+ )
+ # Pass 2: one program per (expert, block size) pair writes that expert's block_pid_map entries.
+ _expt_data_compute[(blocks2,)](
+ expt_hist,
+ token_offs_pad,
+ token_offs_pad.stride(0),
+ block_pid_map,
+ block_pid_map.stride(0), # outputs
+ block_m_log2_start,
+ SIZES=block_m_num,
+ BLOCK=HIST2_BLOCK_M, # optimization parameters
+ num_warps=4,
+ )
+
+ # Convert the stacked per-block-size rows into {block_m: tensor} dictionaries.
+ token_offs_pad = _unpack_into_dict(token_offs_pad)
+ block_pid_map = _unpack_into_dict(block_pid_map)
+ return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
+
+
+# --------------------------
+# routing
+# --------------------------
+
+
+def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
+    """Sort the gates by expert and package the result for the matmul.
+
+    Returns ``(RoutingData, GatherIndx, ScatterIndx)``. The gather and scatter
+    descriptors hold the same two index tensors with source and destination
+    swapped.
+    """
+    sorted_out = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix)
+    (
+        hist,
+        topk_indx,
+        gate_indx,
+        gate_scal,
+        token_offs_raw,
+        offs_pad_stacked,
+        pid_map_stacked,
+    ) = sorted_out
+
+    # The kernels emit one row per block size; expose them keyed by block size.
+    expt_data = ExptData(
+        hist,
+        token_offs_raw,
+        _unpack_into_dict(offs_pad_stacked),
+        _unpack_into_dict(pid_map_stacked),
+    )
+
+    gather = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
+    scatter = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
+    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
+    return routing_data, gather, scatter
+
+
+def routing(
+    logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None
+):
+    """Select the top ``n_expts_act`` experts per row and build the routing data.
+
+    Args:
+        logits: router scores; the last dimension is the expert axis.
+        n_expts_act: number of experts selected per row.
+        sm_first: when True, softmax runs over all experts before the top-k;
+            otherwise the top-k is asked to apply softmax itself.
+        expt_indx: optional indices forwarded to ``topk`` as ``y_indx``.
+        simulated_ep: when > 1, routing is pruned to the first
+            ``logits.shape[-1] // simulated_ep`` experts.
+        n_rows: forwarded to ``topk`` unchanged.
+
+    Returns:
+        ``(RoutingData, GatherIndx, ScatterIndx)`` from `routing_from_bitmatrix`.
+    """
+    from .topk import topk
+
+    n_expts_all = logits.shape[-1]
+    scores = torch.softmax(logits, dim=-1) if sm_first else logits
+    expt_scal, expt_indx, bitmatrix = topk(
+        scores,
+        n_expts_act,
+        apply_softmax=not sm_first,
+        y_indx=expt_indx,
+        n_rows=n_rows,
+    )
+
+    # Pruning mutates the bitmatrix and compacts the per-row selections.
+    if simulated_ep > 1:
+        expt_scal, expt_indx, bitmatrix = prune_routing(
+            expt_scal, expt_indx, bitmatrix, n_expts_all, simulated_ep
+        )
+
+    n_expts_tot = n_expts_all // simulated_ep
+    return routing_from_bitmatrix(
+        bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+    )
+
+
+# --------------------------
+# torch reference
+# --------------------------
+
+
+def compute_expt_data_torch(hist, n_expts_tot, n_gates):
+    """Pure-PyTorch reference that builds an ExptData from ``hist``.
+
+    For each block size (16..128, plus 256 on HIP) it computes the per-expert
+    tile offsets and a ``block_pid_map`` whose entry for tile ``b`` of expert
+    ``e`` is ``(b << 16) + e``; unused entries stay at -1.
+    """
+    device = hist.device
+    leading_zero = torch.zeros(1, device=device)
+
+    def offsets_from_counts(counts):
+        # Exclusive prefix sum with the grand total appended, as int32.
+        return torch.cat((leading_zero, torch.cumsum(counts, dim=0))).int()
+
+    # offset of each expert's first token
+    token_offs_raw = offsets_from_counts(hist)
+
+    block_ms = [16, 32, 64, 128]
+    if is_hip():
+        block_ms.append(256)
+
+    # largest tile count over all block sizes considered
+    if n_gates <= n_expts_tot:
+        max_n_tiles = n_gates
+    else:
+        # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1,
+        # written with ceil_div(x, y) == -(-x // y)
+        max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
+
+    tile_col = torch.arange(max_n_tiles, device=device)
+    expt_row = torch.arange(n_expts_tot, device=device)
+
+    token_offs_pad = {}
+    block_pid_map = {}
+    for block_m in block_ms:
+        # number of matmul blocks each expert needs at this block size
+        n_tiles = (hist + block_m - 1) // block_m
+        tile_starts = offsets_from_counts(n_tiles)
+        token_offs_pad[block_m] = tile_starts
+
+        # Vectorised form of:
+        #   for e in range(n_expts_tot):
+        #       for b in range(n_tiles[e]):
+        #           pid_map[tile_starts[e] + b] = (b << 16) + e
+        pid_map = -torch.ones(max_n_tiles, dtype=torch.int32, device=device)
+        packed = expt_row[:, None] + (tile_col << 16)[None, :]
+        slots = tile_starts[:-1, None] + tile_col[None, :]
+        in_use = tile_col[None, :] < n_tiles[:, None]
+        pid_map.index_put_((slots[in_use],), packed.int()[in_use])
+        block_pid_map[block_m] = pid_map
+
+    return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
+
+def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):
+    """Reference top-k over dim 1 of ``vals``.
+
+    When ``has_user_provided_indx`` is True, ``expt_indx`` is used as the
+    selection and ``k`` is ignored; otherwise the ``k`` largest entries per
+    row are taken, with ties broken by a stable descending sort.
+
+    Returns ``(values, indices)`` with the indices as int32.
+    """
+    if has_user_provided_indx:
+        selected = expt_indx
+    else:
+        descending_order = torch.argsort(-vals, dim=1, stable=True)
+        selected = descending_order[:, :k]
+    # The gather is done with int64 indices; the returned indices are int32.
+    selected = selected.long()
+    picked_vals = torch.take_along_dim(vals, selected, dim=1)
+    return picked_vals, selected.int()
+
+
+def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
+    """Pure-PyTorch reference for `routing`.
+
+    Returns ``(RoutingData, GatherIndx, ScatterIndx)``. ``n_gates_pad`` is
+    taken from the row count of ``logits`` *before* it is truncated to
+    ``n_rows``.
+    """
+    user_supplied_indx = expt_indx is not None
+    n_gates_pad = logits.shape[0] * n_expts_act
+
+    if n_rows is not None:
+        logits = logits[:n_rows, :]
+    _, n_expts_tot = logits.shape
+
+    # Softmax runs either over all experts (before top-k) or over the selection (after).
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)
+    expt_scal, expt_indx = topk_torch(
+        logits, n_expts_act, expt_indx, has_user_provided_indx=user_supplied_indx
+    )
+    if not sm_first:
+        expt_scal = torch.softmax(expt_scal, dim=-1)
+
+    # Within each row, order the selections by expert id (skipped for user-supplied indices).
+    if not user_supplied_indx:
+        expt_indx, row_order = torch.sort(expt_indx, dim=1)
+        expt_scal = torch.gather(expt_scal, 1, row_order)
+
+    flat_scal = expt_scal.reshape(-1)
+    flat_indx = expt_indx.reshape(-1).to(torch.int32)
+
+    # Stable sort by expert id makes each expert's gates contiguous for the matmul;
+    # gate_indx is the inverse permutation of topk_indx.
+    topk_indx = torch.argsort(flat_indx, stable=True)
+    gate_indx = torch.argsort(topk_indx, stable=True)
+    gate_scal = flat_scal[topk_indx]
+
+    # histogram of gates over experts
+    hist = torch.histc(flat_indx, bins=n_expts_tot, max=n_expts_tot - 1).int()
+
+    gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
+    scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
+    expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
+    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
+    return routing_data, gather_indx, scatter_indx
diff --git a/vllm/kvprune_legacy_save/triton_kernels/routing_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/routing_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/routing_details/_expt_data.py b/vllm/kvprune_legacy_save/triton_kernels/routing_details/_expt_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd625868fb668d1a317e193ec4d5ec24a4da6206
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/routing_details/_expt_data.py
@@ -0,0 +1,75 @@
+import triton
+import triton.language as tl
+
+
+# Ceiling division of `n` by 2**log2_k, done with an add and a right shift.
+@triton.jit
+def _cdiv_pow2(n, log2_k):
+ return (n + ((1 << log2_k) - 1)) >> log2_k
+
+
+# Initialisation kernel for the expert-data buffers. The grid is split in two:
+#   * programs 0..SIZES each fill one row of MDStarts with the exclusive prefix sum of
+#     the per-expert tile counts (row 0 uses a tile size of 1, i.e. raw token counts;
+#     row pid > 0 uses tile size 2**(first_tile_dim_log2 + pid - 1));
+#   * every later program fills one BLOCK-sized chunk of MDTileInfo with 0xFFFFFFFF.
+@triton.jit
+def _expt_data_memset(
+ Hist,
+ n_expts_tot,
+ MDStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ first_tile_dim_log2,
+ SIZES: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ if pid <= SIZES:
+ MDStarts += pid * tile_starts_stridem
+ # x_tile carries the running total across BLOCK-sized chunks of the histogram.
+ x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
+ Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
+ tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
+
+ # The range ends at n_expts_tot + 1 so the entry just past the last expert is written too;
+ # its histogram value loads as 0, so it receives the grand total.
+ for i in range(0, n_expts_tot + 1, BLOCK):
+ offs_n = tl.arange(0, BLOCK) + i
+ mask_n0 = offs_n < n_expts_tot
+ hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
+ hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
+
+ tile_starts = tl.cumsum(hist_tile, 0) + x_tile
+ x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
+ # Inclusive sum minus the element itself = exclusive sum. The store is unmasked, so each
+ # row must have room for a whole number of BLOCK-sized chunks.
+ tl.store(Tile_ptrs, tile_starts - hist_tile)
+ Tile_ptrs += BLOCK
+
+ else:
+ pid -= SIZES + 1
+ TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
+ # All bits set; reads back as -1 from an int32 buffer.
+ tl.store(TileInfoOut, 0xFFFFFFFF)
+
+
+# Fills MDTileInfo for one (expert, tile size) pair per program.
+# pid is decomposed as expt_id = pid // SIZES and buff_id = pid % SIZES; the tile size is
+# 2**(first_tile_dim_log2 + buff_id). For every tile `b` that the expert needs, the value
+# (b << 16) + expt_id is written at position MDTileStarts[buff_id][expt_id] + b of row buff_id.
+@triton.jit
+def _expt_data_compute(
+ Hist,
+ MDTileStarts,
+ tile_starts_stridem,
+ MDTileInfo,
+ tile_info_stridem,
+ first_tile_dim_log2,
+ SIZES: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ expt_id = pid // SIZES
+ buff_id = pid % SIZES
+
+ # Select the row that belongs to this tile size in both buffers.
+ MDTileStarts += buff_id * tile_starts_stridem
+ MDTileInfo += buff_id * tile_info_stridem
+
+ n_tokens = tl.load(Hist + expt_id)
+ tile_dim_log2 = first_tile_dim_log2 + buff_id
+ n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
+
+ # Position of this expert's first tile within the row.
+ tile_off = tl.load(MDTileStarts + expt_id)
+ MDTileInfo += tile_off
+
+ # Write BLOCK entries per iteration; the mask trims the final partial chunk.
+ for block_off in range(0, n_blocks, BLOCK):
+ block_offs = block_off + tl.arange(0, BLOCK)
+ data = (block_offs << 16) + expt_id
+ tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/routing_details/_routing_compute.py b/vllm/kvprune_legacy_save/triton_kernels/routing_details/_routing_compute.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b097cc1cc8c1117363f031cfc9a785b94a7d5ed
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/routing_details/_routing_compute.py
@@ -0,0 +1,241 @@
+import triton
+import triton.language as tl
+
+from ._expt_data import _expt_data_compute, _expt_data_memset
+
+
+@triton.jit
+def _routing_compute_expt_offs(  # exclusive prefix sum: FinalExpertOffs[i] = sum(ExpertHist[:i])
+    ExpertHist,  # pointer to the counts being scanned
+    FinalExpertOffs,  # output pointer; `hist_size` entries are written
+    hist_size,  # number of entries to scan
+    BLOCK_N: tl.constexpr,  # entries processed per loop iteration
+):
+    loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
+    x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)  # running total of all earlier chunks
+    for i in range(loop_iterations):
+        offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
+        mask_n = offs_n < hist_size
+        # `other=0` keeps out-of-range lanes from feeding unspecified values into the scan and the sum.
+        hist2 = tl.load(ExpertHist + offs_n, mask=mask_n, other=0)
+        tok_starts = tl.cumsum(hist2, 0) - hist2 + x  # inclusive scan minus self = exclusive scan
+        x += tl.sum(hist2, 0)
+        tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
+
+
+@triton.jit
+def _routing_compute_indx_offs(  # in-place exclusive prefix sum down one column of PartialHist
+    PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id
+):
+    offs_m = tl.arange(0, BLOCK_M)
+    # walk the `shape_pm` rows of column `expt_id`, BLOCK_M rows at a time
+    curr_sum = 0
+    for _ in range(0, shape_pm, BLOCK_M):
+        offs = offs_m * stride_pm + expt_id * stride_pn
+        curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)  # NOTE(review): no `other=`; masked lanes are unspecified
+        out = tl.cumsum(curr, 0) + curr_sum  # inclusive scan plus the total of earlier chunks
+        curr_sum += tl.sum(curr, 0)
+        tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)  # subtract self -> exclusive scan
+        offs_m += BLOCK_M  # advance to the next chunk of rows
+
+
+@triton.jit
+def _keyed_add(x, y):  # combine step for a scan: values accumulate only while their keys match
+    # we keep the key in the upper 16 bits of a uint32:
+    key_mask: tl.constexpr = 0xFFFF0000
+
+    kx = x & key_mask
+    ky = y & key_mask
+    z = tl.where(kx == ky, x + y - kx, y)  # same key: add (key kept once); different key: restart from y
+    return z
+
+
+@triton.jit
+def _routing_compute_indx(  # writes the scatter/gather indices and reordered gate scales for one block
+    pid_m,  # block index; selects which BLOCK_M * N_EXPTS_ACT entries are handled
+    GatherIndx,  # output: GatherIndx[dest] = source offset
+    ScatterIndx,  # output: ScatterIndx[source offset] = dest
+    GateScal,  # output: ExptScal values stored at their destination slots
+    ExptScal,  # input values, indexed by source offset
+    ExptIndx,  # input expert ids, indexed by source offset
+    PartialOffs,  # per-(block, expert) start offsets, read at [pid_m, expert]
+    stride_pm,
+    stride_pn,
+    TokensStart,  # per-expert start offsets, read at [expert]
+    n_tokens,  # either a value or a pointer to it (see the check below)
+    BLOCK_M: tl.constexpr,
+    N_EXPTS_ACT: tl.constexpr,
+):
+    if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
+        n_tokens = tl.load(n_tokens)  # pointer form: fetch the count from memory
+    n_gates = n_tokens * N_EXPTS_ACT
+
+    tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)  # local offsets must fit the 16-bit field used below
+
+    local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
+    expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)  # out-of-range lanes get -1
+
+    # stable-sort by expert ID:
+    kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)  # key in the high half, original position in the low half
+    kv_pairs = tl.sort(kv_pairs, 0)
+    expert = kv_pairs >> 16
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xFFFF)  # source offsets, now in expert-sorted order
+    mask = expert != 0xFFFF  # 0xFFFF is what the -1 filler becomes after the shift; those lanes are skipped
+    gate_scal = tl.load(ExptScal + offs, mask=mask)
+
+    # compute run lengths in expert-sorted order:
+    x = kv_pairs & 0xFFFF0000 | 0x00000001  # (key, 1) per lane
+    expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
+    exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF  # drop the key, make the count exclusive
+
+    gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
+    gates += tl.load(TokensStart + expert, mask=mask)
+    gates += exclusive_run_lengths  # dest = block/expert offset + expert start + position within the run
+
+    tl.store(ScatterIndx + offs, gates, mask=mask)
+    tl.store(GatherIndx + gates, offs, mask=mask)
+    tl.store(GateScal + gates, gate_scal, mask=mask)
+
+
+@triton.jit
+def _combined_routing_compute(  # one launch, two jobs: the program id decides which helper runs
+    GatherIndx,  # the arguments down to N_EXPTS_ACT are forwarded to _routing_compute_indx
+    ScatterIndx,
+    GateScal,
+    ExptScal,
+    ExptIndx,
+    PartialOffs,
+    stride_pm,
+    stride_pn,
+    TokensStart,
+    n_tokens,
+    BLOCK_M: tl.constexpr,
+    N_EXPTS_ACT: tl.constexpr,
+    Hist,  # the arguments down to BLOCK are forwarded to _expt_data_compute
+    MDTileStarts,
+    tile_starts_stridem,
+    MDTileInfo,
+    tile_info_stridem,
+    first_tile_dim_log2,
+    SIZES: tl.constexpr,
+    BLOCK: tl.constexpr,
+    blocks2a,  # number of leading program ids assigned to _expt_data_compute
+):
+    pid = tl.program_id(0)
+    if pid < blocks2a:
+        _expt_data_compute(  # this helper reads tl.program_id(0) itself, so no id is passed
+            Hist,
+            MDTileStarts,
+            tile_starts_stridem,
+            MDTileInfo,
+            tile_info_stridem,
+            first_tile_dim_log2,
+            SIZES,
+            BLOCK,
+        )
+    else:
+        pid -= blocks2a  # re-base so the routing helper sees block ids starting at 0
+        _routing_compute_indx(
+            pid,
+            GatherIndx,
+            ScatterIndx,
+            GateScal,
+            ExptScal,
+            ExptIndx,
+            PartialOffs,
+            stride_pm,
+            stride_pn,
+            TokensStart,
+            n_tokens,
+            BLOCK_M,
+            N_EXPTS_ACT,
+        )
+
+
+@triton.jit
+def _routing_clear_bitmatrix(  # one row per program: zero every bit at position >= `cutoff`
+    Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr
+):
+    pid_m = tl.program_id(0)
+    cutoff_word = cutoff // 32  # bit positions are split into 32-bit words
+    cutoff_bit = cutoff % 32
+    cutoff_mask = (1 << (cutoff_bit)) - 1  # ones strictly below `cutoff_bit`
+    for start_n in range(0, shape_bn, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
+        values = tl.load(
+            Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn
+        )
+        values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)  # the word containing the cutoff
+        values = tl.where(offs_n > cutoff_word, 0, values)  # words entirely past the cutoff
+        tl.store(
+            Bitmatrix + pid_m * stride_bm + offs_n * stride_bn,
+            values,
+            mask=offs_n < shape_bn,
+        )
+
+
+@triton.jit
+def _combined_routing_memset(
+    Indx,  # buffer filled with `sentinel` by the trailing programs
+    size,  # number of Indx entries to fill
+    sentinel,
+    BLOCK: tl.constexpr,  # Indx entries filled per program
+    ExpertHist,
+    FinalExpertOffs,
+    hist_size,
+    n_expts_tot,
+    PartialHist,
+    shape_pm,
+    stride_pm,
+    stride_pn,
+    MDStarts,
+    tile_starts_stridem,
+    blocks1a,  # number of leading program ids handed to _expt_data_memset
+    MDTileInfo,
+    first_tile_dim_log2,
+    SIZES: tl.constexpr,
+    BLOCK_A: tl.constexpr,  # block size passed to _expt_data_memset
+    BLOCK_N: tl.constexpr,  # block size passed to _routing_compute_expt_offs
+    BLOCK_M: tl.constexpr,  # block size passed to _routing_compute_indx_offs
+):
+    """
+    This kernel essentially combines 6 different pieces of functionality,
+    branching on the value of tl.program_id(0) to decide which
+    codepath to take.
+
+        pid == 0:                                   create the token cumsum
+        1 <= pid <= SIZES:                          create a tile cumsum
+        SIZES < pid < blocks1a:                     initialise MDTileInfo to 0xffffffff
+        blocks1a <= pid < blocks1a + n_expts_tot:   compute_indx_offs
+        pid == blocks1a + n_expts_tot:              compute_expt_offs
+        pid > blocks1a + n_expts_tot:               initialise Indx to sentinel
+
+    As each of these is a relatively trivial workload, launching them from
+    this single trampoline is beneficial as they can execute on different
+    streaming multiprocessors in parallel.
+    """
+
+    pid = tl.program_id(0)
+
+    if pid < blocks1a:  # covers the first three cases of the table above
+        _expt_data_memset(
+            ExpertHist,
+            n_expts_tot,
+            MDStarts,
+            tile_starts_stridem,
+            MDTileInfo,
+            first_tile_dim_log2,
+            SIZES,
+            BLOCK_A,
+        )
+    elif pid == n_expts_tot + blocks1a:
+        _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+    elif pid < n_expts_tot + blocks1a:
+        _routing_compute_indx_offs(  # one program per expert column; expert id = pid - blocks1a
+            PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a
+        )
+    else:
+        offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)  # re-based so the first fill block starts at 0
+        mask = offs < size
+        tl.store(Indx + offs, sentinel, mask=mask)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/specialize.py b/vllm/kvprune_legacy_save/triton_kernels/specialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcf44d70cb47664e6a837ec4cf0d28f04fbb1c16
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/specialize.py
@@ -0,0 +1,143 @@
+import inspect
+import re
+import textwrap
+import types
+import triton
+
+
+def cacheable(f):
+    """
+    Decorator that lets you write something of the form:
+
+        @cacheable
+        def my_kernel(): return (expression dynamically defining a kernel)
+
+    so that the result interacts gracefully with triton cache and preload.
+
+    `f` is invoked once, immediately; the object it returns (and its `.fn`)
+    is relabelled with `f`'s name, module and qualname, then returned.
+    """
+    kernel = f()
+    identity = ("__name__", "__module__", "__qualname__")
+    for target in (kernel.fn, kernel):
+        for attr in identity:
+            setattr(target, attr, getattr(f, attr))
+    kernel._fn_name = f"{f.__module__}.{f.__qualname__}"
+    return kernel
+
+
+def define_kernel(src, module, attrs=None, **extra_globals):
+    """
+    Dynamically create a Triton function or kernel from a src string,
+    linking any symbols in the kernel to objects specified by extra_globals.
+    """
+
+    # create template function
+    def _empty_fn():
+        pass
+
+    gdict = dict(**(_empty_fn.__globals__))  # copy of this module's globals ...
+    gdict.update(extra_globals)  # ... extended (or overridden) by the caller's symbols
+    f = types.FunctionType(_empty_fn.__code__, gdict)
+    f.__module__ = module.__name__
+
+    src = textwrap.dedent(src)
+    src = src[src.find("def ") :]  # drop everything before the first "def " (e.g. decorators)
+
+    stored_functions = []
+    function_name = src[4:].split("(")[0].strip()  # text between "def " and the first "("
+
+    exec_globals = gdict  # same dict object: names defined by the exec below land in f's globals too
+    exec_globals.update({"stored_functions": stored_functions})
+    exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)  # NOTE: runs `src` as Python; pass trusted source only
+
+    f.__signature__ = inspect.signature(stored_functions[0])  # make the placeholder look like the real function
+    f.__name__ = function_name
+    f.__doc__ = stored_functions[0].__doc__
+
+    if attrs is None:
+        attrs = dict()
+    f = triton.JITFunction(f, **attrs)
+    f._unsafe_update_src(src)  # point the JITFunction at the generated source text
+    return f
+
+
+def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
+    """
+    Build a new Triton kernel from `fn` with some parameters baked in.
+
+    Parameters named in `constants` leave the signature and are bound as
+    `tl.constexpr` at the top of the body; parameters named in `tuples` are
+    replaced by the listed names and rebuilt as a tuple. JITFunction constants
+    become globals of the new kernel and their `repr` is appended to its `repr`.
+    Raises ValueError when the parameter list of `fn` cannot be parsed.
+    """
+    assert isinstance(fn, triton.runtime.jit.JITFunction)
+    if name is None:
+        name = f"{fn.__name__}"
+    def code_of(line):
+        # Text of `line` before any `#`, stripped: comments never count as code.
+        return line.split("#", 1)[0].strip()
+
+    lines = textwrap.dedent(inspect.getsource(fn.fn)).split("\n")
+    # Skip decorators: the header starts at the first line beginning with `def`.
+    def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
+    # The header ends where the code part ends with ":", even if a comment follows it.
+    header_end = def_idx
+    while header_end < len(lines) - 1 and not code_of(lines[header_end]).endswith(":"):
+        header_end += 1
+    body_lines = lines[header_end + 1 :]
+    header_clean = [code for code in map(code_of, lines[def_idx : header_end + 1]) if code]
+    # decompose arguments
+    m = re.search(r"\((.*)\)\s*:", " ".join(header_clean))
+    if not m:
+        raise ValueError("Could not parse function header")
+    args = [arg.strip() for arg in m.group(1).split(",") if arg.strip()]
+    non_specialized_args = []
+    for arg in args:
+        arg_key = arg.split(":")[0].split("=")[0].strip()
+        if arg_key not in constants:
+            non_specialized_args += tuples.get(arg_key, [arg])
+    # add global symbols
+    spec_fns = {
+        v.__name__: v
+        for v in constants.values()
+        if isinstance(v, triton.runtime.jit.JITFunction)
+    }
+    kernel_globals = spec_fns | fn.get_capture_scope()
+    # build new source code and define kernel dynamically
+    new_signature = f"def {name}({', '.join(non_specialized_args)}):"
+    constexpr_lines = [
+        f"    {key}: tl.constexpr = {value.__name__ if callable(value) else value}"
+        for key, value in constants.items()
+    ]
+    tuple_lines = [
+        f"    {key} = {'(' + ','.join(value) + (',' if len(value) >= 1 else '') + ')'}"
+        for key, value in tuples.items()
+    ]
+    new_src = "\n".join(
+        ["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines
+    )
+    # Carry over every JITFunction constructor option that `fn` exposes as an attribute.
+    sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
+    params = list(sig.parameters.values())[2:]
+    attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
+
+    # make a new repr which appends the repr of the specialized functions.
+    base_repr = attrs["repr"]
+
+    def new_repr(specialization):
+        ret = base_repr(specialization)
+        for spec_fn in spec_fns.values():
+            spec_repr = spec_fn.repr(None)
+            if spec_repr:
+                spec_repr = spec_repr.strip("_")
+            if spec_repr:
+                ret += f"_{spec_repr}"
+        return ret
+
+    attrs["repr"] = new_repr
+
+    if do_not_specialize:
+        attrs["do_not_specialize"] = do_not_specialize
+    return define_kernel(new_src, module, attrs, **kernel_globals)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/swiglu.py b/vllm/kvprune_legacy_save/triton_kernels/swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e1af2c4a0191ee2897d8d37fbd915cc77f07ec
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/swiglu.py
@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+from vllm.kvprune.triton_kernels.numerics import InFlexData, OutFlexData
+import torch
+import triton
+from .swiglu_details._swiglu import _swiglu, _swiglu_fn
+from vllm.kvprune.triton_kernels import target_info
+
+
+@dataclass(frozen=True)
+class FlexCtx:  # immutable bundle of flexpoint settings read by SwiGLU.forward
+    out_data: OutFlexData = OutFlexData()  # output side: supplies reinterpret() and the expected/actual/checksum scales
+    inp_data: InFlexData = InFlexData()  # input side: supplies reinterpret() and the input scale
+    saturate_inf: bool = False  # forwarded to the kernel as `flexpoint_saturate_inf`
+
+
+@dataclass(frozen=True)
+class PrecisionConfig:  # numeric options accepted by swiglu() and swiglu_torch()
+    limit: float  # clamp bound; swiglu_torch skips clamping when this is None
+    flex_ctx: FlexCtx = FlexCtx()  # flexpoint settings; defaults to an all-default FlexCtx
+
+
+swiglu_fn = _swiglu_fn
+
+
+class SwiGLU(torch.autograd.Function):  # forward-only wrapper that launches the `_swiglu` Triton kernel
+    @staticmethod
+    def forward(ctx, a, alpha, precision_config, routing_data):
+        N = a.shape[-1]  # width of the interleaved input; the output is half as wide
+        M = a.numel() // N  # all leading dimensions flattened into rows
+        assert a.stride()[-1] == 1
+        assert a.shape[-1] % 2 == 0
+        out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
+        flex_ctx = precision_config.flex_ctx
+        # optimization hyperparameters
+        BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
+        num_warps = 4
+        kwargs = {"maxnreg": 64} if not target_info.is_hip() else {}  # register cap is only passed on non-HIP targets
+        # launch semi-persistent kernel
+        N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
+        num_sms = target_info.num_sms()
+        if routing_data is not None:
+            waves_per_sm = 32 if target_info.is_hip() else 128  # grid sized from the hardware: the kernel reloads M on device
+            num_pid = num_sms * (waves_per_sm // num_warps)
+            M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
+            grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
+        else:
+            M_BLOCKS = triton.cdiv(M, BLOCK_M)
+            if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
+                grid = (8 * num_sms,)  # many tiles: cap the grid and let each program loop
+            else:
+                grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
+        n_tokens = None
+        if routing_data is not None:
+            n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]  # presumably the total routed-token count - confirm
+        _swiglu[grid](
+            flex_ctx.out_data.reinterpret(out),
+            flex_ctx.out_data.expected_scale,
+            flex_ctx.out_data.actual_scale,
+            flex_ctx.out_data.checksum_scale,
+            flex_ctx.inp_data.reinterpret(a),
+            flex_ctx.inp_data.scale,
+            alpha,
+            M,
+            N // 2,
+            a.shape[-1],  # stride_am; NOTE(review): assumes rows of `a` are densely packed, only the last stride is asserted
+            1,
+            out.shape[-1],  # stride_outm (`out` was just allocated contiguous)
+            1,
+            precision_config.limit,
+            n_tokens,
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N,
+            EVEN_N=(N // 2) % BLOCK_N == 0,
+            M_BLOCKS=M_BLOCKS,
+            N_BLOCKS=N_BLOCKS,
+            flexpoint_saturate_inf=flex_ctx.saturate_inf,
+            num_warps=num_warps,
+            **kwargs,
+        )
+        out = out.view(a.shape[:-1] + out.shape[-1:])  # restore the leading dimensions of `a`
+        return out
+
+
+def swiglu(a, alpha, precision_config, routing_data=None):  # functional entry point: runs SwiGLU.forward via autograd's apply
+    return SwiGLU.apply(a, alpha, precision_config, routing_data)
+
+
+def swiglu_torch(a, alpha, precision_config):
+    """Eager PyTorch SwiGLU over the last axis of `a`.
+
+    Even-indexed entries of the last axis are the gate, odd-indexed ones the
+    linear half; returns `gate * sigmoid(alpha * gate) * (linear + 1)`.
+    """
+    gate, lin = a[..., ::2], a[..., 1::2]
+    bound = precision_config.limit
+    if bound is not None:
+        # The gate is clamped from above only; the linear half on both sides.
+        gate, lin = gate.clamp(max=bound), lin.clamp(min=-bound, max=bound)
+    return gate * torch.sigmoid(alpha * gate) * (lin + 1)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/swiglu_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/swiglu_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/swiglu_details/_swiglu.py b/vllm/kvprune_legacy_save/triton_kernels/swiglu_details/_swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e5546330498d2a29241068ab3f3fef7021250fa
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/swiglu_details/_swiglu.py
@@ -0,0 +1,141 @@
+from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
+ load_scale,
+ float_to_flex,
+ update_scale,
+)
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def clip(x, limit, clip_lower: tl.constexpr):  # cap x at `limit`; with clip_lower also floor it at `-limit`
+    res = tl.minimum(x, limit)
+    if clip_lower:
+        res = tl.maximum(-limit, res)
+    return res
+
+
+@triton.jit
+def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
+    return tl.max(  # max(|x|) over each of NUM_THREADS groups of BLOCK_SIZE // NUM_THREADS elements
+        tl.reshape(
+            tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True
+        ),  # can_reorder=True: which elements share a group is left to the compiler
+        axis=1,
+    )
+
+
+def swiglu_repr(specialization):
+    """Build the kernel name `_swiglu_<Out>x<A>_<BLOCK_M>x<BLOCK_N>`; dtypes containing "u8" read as "mxfp4"."""
+    sig, consts = specialization.signature, specialization.constants
+    # The first character of each signature entry is dropped before it is used.
+    names = [f"{sig[arg][1:]}" for arg in ("Out", "A")]
+    dtypes = "x".join("mxfp4" if "u8" in name else name for name in names)
+    return f"_swiglu_{dtypes}_{consts['BLOCK_M']}x{consts['BLOCK_N']}"
+
+
+def swiglu_launch_metadata(grid, kernel, args):
+    """Return launch metadata: a display `name` plus the `bytes` held by `Out` and `A`."""
+    m, n = args["M"], args["N"]
+    label = f"{kernel.name} [M = {m}, N = {n}]"
+    a, out = args["A"], args["Out"]
+    sizes = [t.numel() * t.element_size() for t in (out, a)]
+    return {"name": label, "bytes": sizes[0] + sizes[1]}
+
+
+@triton.jit
+def compute_swiglu(gelu, linear, scale, alpha, limit):  # returns s * (linear + 1) with s = gelu * sigmoid(alpha * gelu)
+    gelu = gelu.to(tl.float32) * scale  # both halves are computed in float32 after applying `scale`
+    if limit is not None:
+        gelu = clip(gelu, limit, clip_lower=False)  # gate half: upper bound only
+    linear = linear.to(tl.float32) * scale
+    if limit is not None:
+        linear = clip(linear, limit, clip_lower=True)  # linear half: bounded on both sides
+    s = gelu / (1 + tl.exp(-alpha * gelu))  # same as gelu * sigmoid(alpha * gelu)
+    return tl.fma(s, linear, s)  # (s * (linear + 1))
+
+
+@triton.jit(repr=lambda _: "_swiglu")
+def _swiglu_fn(input, alpha, limit):  # SwiGLU on a 2-D block whose last axis interleaves (gelu, linear) pairs; scale fixed at 1.0
+    gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
+    return compute_swiglu(gelu, linear, 1.0, alpha, limit)
+
+
+@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
+def _swiglu(  # each program loops over tiles pid, pid + num_programs, ... of the (M_BLOCKS x N_BLOCKS) grid
+    Out,  # output pointer, logically M x N
+    OutExpectedScale,
+    OutActualScale,  # when not None, the running abs-max of the outputs is tracked for it
+    OutChecksumScale,
+    A,  # input pointer, logically M x 2N with (gelu, linear) interleaved along the last axis
+    AScale,
+    alpha,
+    M,
+    N,  # number of OUTPUT columns (half the input width)
+    stride_am,
+    stride_an,
+    stride_outm,
+    stride_outn,
+    limit: tl.constexpr,
+    NTokens,  # optional pointer; when given, M is re-read from it on device
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    EVEN_N: tl.constexpr,  # True when N is a multiple of BLOCK_N, so column masking can be skipped
+    M_BLOCKS,
+    N_BLOCKS,
+    flexpoint_saturate_inf: tl.constexpr,
+):
+    if NTokens is not None:
+        M = tl.load(NTokens)
+        M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M  # recompute the row-block count from the device-side M
+
+    local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)  # one running abs-max slot per thread
+
+    a_scale = load_scale(AScale)
+    out_expected_scale = load_scale(OutExpectedScale)
+
+    for pid in tl.range(
+        tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2
+    ):
+        pid_m = pid // N_BLOCKS
+        pid_n = pid % N_BLOCKS
+        off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+        mask_m = off_m < M
+        mask_n = off_n < N
+        packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2  # output column of each interleaved input element
+        packed_mask_n = packed_off_n < N
+        packed_mask_n = tl.max_constancy(packed_mask_n, [16])  # compiler hint: the mask is constant in runs of 16
+        # load a
+        packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)  # reused name: now the input column offsets
+        packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
+        if EVEN_N:
+            a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
+        else:
+            if pid_n * BLOCK_N + BLOCK_N <= N:  # tile lies fully inside the columns: row mask is enough
+                a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
+            else:
+                packed_mask = mask_m[:, None] & packed_mask_n[None, :]
+                a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.0)
+        a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))  # de-interleave the pairs
+        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
+        # update flexpoint stats and divide by scale
+        # we don't need masking because of the `other` when loading `A`
+        if OutActualScale is not None:
+            absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
+            local_max = tl.maximum(local_max, absmax)
+        out = float_to_flex(
+            out,
+            out_expected_scale,
+            None,  # ActualScale: local absmax is tracked and updated after the loop
+            OutChecksumScale,
+            None,
+            Out,
+            flexpoint_saturate_inf,
+        )
+        mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
+        tl.store(
+            Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask
+        )
+
+    update_scale(local_max, OutActualScale, Out)  # presumably publishes the tracked abs-max - helper not visible here
diff --git a/vllm/kvprune_legacy_save/triton_kernels/target_info.py b/vllm/kvprune_legacy_save/triton_kernels/target_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac308ef9781e4e6005d56d0398a1b397a015e93
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/target_info.py
@@ -0,0 +1,116 @@
+import torch
+import triton
+
+# ``constexpr_function`` moved across Triton versions; ROCm/vendor wheels often
+# only expose ``triton.constexpr_function`` (not ``triton.runtime.jit``).
+def _resolve_constexpr_function():  # returns Triton's `constexpr_function`, trying three locations in order
+    fn = getattr(triton, "constexpr_function", None)  # 1) top-level `triton.constexpr_function`
+    if fn is not None:
+        return fn
+    try:
+        from triton.runtime.jit import constexpr_function as _fn  # 2) `triton.runtime.jit`
+
+        return _fn
+    except ImportError:
+        pass  # fall through to the last location
+    _jit = getattr(triton, "jit", None)
+    if _jit is not None:
+        fn = getattr(_jit, "constexpr_function", None)  # 3) attribute on the `triton.jit` object
+        if fn is not None:
+            return fn
+    raise ImportError(  # nothing found: surfaces at import time, since the module calls this at top level
+        "Cannot resolve Triton constexpr_function (try: pip install -U triton)"
+    )
+
+
+constexpr_function = _resolve_constexpr_function()
+
+__all__ = [
+ "cuda_capability_geq",
+ "get_cdna_version",
+ "has_tma_gather",
+ "has_native_mxfp",
+ "is_cuda",
+ "is_hip",
+ "is_hip_cdna3",
+ "is_hip_cdna4",
+ "num_sms",
+]
+
+try:
+    from triton.language.target_info import (  # preferred: use Triton's own helpers when they exist
+        cuda_capability_geq,
+        current_target,
+        is_cuda,
+        is_hip,
+        is_hip_cdna3,
+        is_hip_cdna4,
+    )
+except ImportError:
+    # Some ROCm / vendor Triton wheels omit ``triton.language.target_info``.
+    # Mirror upstream Triton (see triton/language/target_info.py) via runtime.
+    from triton.runtime import driver
+
+    def current_target():  # active driver's current target, or None when `driver.active` raises RuntimeError
+        try:
+            active_driver = driver.active
+        except RuntimeError:
+            return None
+        return active_driver.get_current_target()
+
+    @constexpr_function
+    def is_cuda():  # True only when a target exists and its backend is "cuda"
+        target = current_target()
+        return target is not None and target.backend == "cuda"
+
+    @constexpr_function
+    def is_hip():  # True only when a target exists and its backend is "hip"
+        target = current_target()
+        return target is not None and target.backend == "hip"
+
+    @constexpr_function
+    def cuda_capability_geq(major, minor=0):  # False off-CUDA; otherwise compares arch with major * 10 + minor
+        target = current_target()
+        if target is None or target.backend != "cuda":
+            return False
+        assert isinstance(target.arch, int)
+        return target.arch >= major * 10 + minor
+
+    @constexpr_function
+    def is_hip_cdna3():  # matches on the arch string alone ("gfx942"); the backend is not checked
+        target = current_target()
+        return target is not None and target.arch == "gfx942"
+
+    @constexpr_function
+    def is_hip_cdna4():  # matches on the arch string alone ("gfx950"); the backend is not checked
+        target = current_target()
+        return target is not None and target.arch == "gfx950"
+
+
+@constexpr_function
+def get_cdna_version():
+    """
+    Report the AMD CDNA generation of the current Triton target.
+
+    Returns 3 for arch "gfx942", 4 for arch "gfx950", and -1 when there is no
+    active target, the backend is not "hip", or the arch is anything else.
+    """
+    target = current_target()
+    on_hip = target is not None and target.backend == "hip"
+    if on_hip and target.arch == "gfx942":
+        return 3
+    return 4 if on_hip and target.arch == "gfx950" else -1
+
+
+@constexpr_function
+def has_tma_gather():  # gated on a CUDA target with compute capability >= 10.0
+    return cuda_capability_geq(10, 0)
+
+
+@constexpr_function
+def has_native_mxfp():  # gated on a CUDA target with compute capability >= 10.0
+    return cuda_capability_geq(10, 0)
+
+
+def num_sms():  # multiprocessor count of the currently selected CUDA device (not always device 0)
+    return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor.py b/vllm/kvprune_legacy_save/triton_kernels/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6992e942365b2cf52701be8d013f174dd4458784
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor.py
@@ -0,0 +1,227 @@
+from dataclasses import dataclass, fields
+from typing import Type
+
+import torch
+from triton.tools.tensor_descriptor import TensorDescriptor
+from triton.tools.ragged_tma import create_ragged_descriptor
+
+from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
+from .target_info import cuda_capability_geq
+from .tensor_details.layout import Layout, StridedLayout
+
+
+@dataclass
+class Storage:
+    """A torch tensor plus the `Layout` describing how its data is arranged (StridedLayout by default)."""
+
+    data: torch.Tensor
+    layout: Layout = None
+
+    def __post_init__(self):
+        assert isinstance(self.data, torch.Tensor)
+        if self.layout is None:
+            self.layout = StridedLayout(self.data.shape)
+
+    @property
+    def device(self):
+        return self.data.device
+
+    def is_tma_compliant(self):
+        # TMAs didn't exist until Hopper
+        if not cuda_capability_geq(9, 0):
+            return False
+        # TMAs only exist for 2D, 3D, 5D inputs
+        if len(self.data.shape) not in [2, 3, 5]:
+            return False
+        # One dimension (the first with stride 1, if any) is exempt; every other
+        # stride must be a multiple of 128 bits, i.e. 16 bytes.
+        strides = list(self.data.stride())
+        major_dim = strides.index(1) if 1 in strides else -1
+        # uint8 storage is counted as 4 bits per element (presumably packed 4-bit values)
+        bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
+        return all(
+            stride * bitwidth % 128 == 0
+            for i, stride in enumerate(strides)
+            if i != major_dim
+        )
+
+    def make_dense_tma(self, block_shape, transpose=False):  # dense TensorDescriptor; caller's block_shape is left untouched
+        block_shape = list(block_shape)  # private copy: the mxfp4 branch below writes into it
+        shape, strides = list(self.data.shape), list(self.data.stride())
+        transpose = strides[-1] != 1  # NOTE: overrides the `transpose` argument
+        if transpose:
+            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
+            shape = shape[:-2] + [shape[-1], shape[-2]]
+            strides = strides[:-2] + [strides[-1], strides[-2]]
+        if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
+            indx = strides.index(1)
+            block_shape[indx] = block_shape[indx] // 2  # halve the block along the stride-1 dimension
+            if shape[-1] % 128 != 0:
+                raise ValueError(
+                    "inner shape need to be multiple of 128 for "
+                    "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
+                )
+        block_shape = self.layout.swizzle_block_shape(block_shape)
+        return TensorDescriptor(self.data, shape, strides, block_shape)
+
+    def make_tma(self, block_shape, mode, transpose=False):  # descriptor for `mode`: dense / gather / scatter / ragged
+        if mode in ["dense", "gather", "scatter"]:
+            return self.make_dense_tma(block_shape, transpose)
+        assert mode == "ragged"
+        ragged_dim = len(self.data.shape) - 2  # the second-to-last dimension is the ragged one
+        return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
+
+
+@dataclass
+class IntegerType:  # integer element type described only by its width
+    bitwidth: int  # width in bits
+
+
+@dataclass
+class FloatType:
+    bitwidth_exponent: int  # number of exponent bits
+    bitwidth_mantissa: int  # number of mantissa bits
+    is_signed: bool  # whether a sign bit is present
+
+    def __post_init__(self):
+        # Total width = exponent bits + mantissa bits + one bit when signed.
+        # Stored as a plain attribute, so it is not one of the dataclass fields.
+        self.bitwidth = self.bitwidth_exponent + self.bitwidth_mantissa + int(self.is_signed)
+
+
+BIT = IntegerType(1)
+FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
+
+
+def bitwidth(type: IntegerType | FloatType | torch.dtype):
+    # torch dtypes report their size in bytes via `itemsize`; the custom types carry `bitwidth` directly
+    is_torch_dtype = isinstance(type, torch.dtype)
+    return type.itemsize * 8 if is_torch_dtype else type.bitwidth
+
+
+@dataclass
+class Tensor:
+    """A `Storage` plus a logical dtype and shape, with a few torch-like accessors."""
+
+    storage: Storage | torch.Tensor
+    dtype: IntegerType | FloatType | torch.dtype = None
+    shape: list[int] | None = None
+    shape_max: list[int] | None = None
+
+    def __post_init__(self):
+        # set storage
+        if isinstance(self.storage, torch.Tensor):
+            self.storage = Storage(self.storage)
+        # initialize dtype
+        if self.dtype is None:
+            self.dtype = self.storage.data.dtype
+        if bitwidth(self.dtype) < 8 and self.shape is None:
+            raise ValueError("shape must be provided for sub-byte types")
+        # initialize shape
+        if self.shape is None:
+            self.shape = list(self.storage.data.shape)
+        # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
+        is_int = lambda s: isinstance(s, int)
+        is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
+        assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
+        # initialize shape_max on a private copy, so the caller's list is never written to
+        given_max = [None] * len(self.shape) if self.shape_max is None else self.shape_max
+        self.shape_max = list(given_max)
+        for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
+            if smax is not None and not is_int(smax):
+                raise ValueError(
+                    f"shape_max[{i}] must be `int` or `None`; got {type(smax)}"
+                )
+            if smax is None:
+                self.shape_max[i] = s  # missing bound: fall back to the actual extent
+        # validate shape_max: all elements must be `int`
+        assert all(map(is_int, self.shape_max))
+
+    # torch compatibility layer
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    @property
+    def device(self):
+        return self.storage.device
+
+    def stride(self, i=None):  # strides of the underlying storage tensor
+        return self.storage.data.stride() if i is None else self.storage.data.stride(i)
+
+    def data_ptr(self):
+        return self.storage.data.data_ptr()
+
+    def numel(self):  # element count of the storage tensor, not of the logical shape
+        return self.storage.data.numel()
+
+    def element_size(self):  # whole bytes per element; 0 for sub-byte dtypes
+        return bitwidth(self.dtype) // 8
+
+    @property
+    def data(self):
+        t = self.storage
+        return t.data if isinstance(t, Storage) else t
+
+    def dim(self):
+        return self.ndim
+
+    def size(self, i=None):  # the whole logical shape, or one extent of it
+        return self.shape if i is None else self.shape[i]
+
+
+@dataclass
+class Bitmatrix(Tensor):
+    """
+    Represents a boolean matrix in a packed format where each element occupies
+    a single bit of memory.
+
+    scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
+    with the actual bitmatrix to avoid having to launch a separate memset
+    kernel when we call Bitmatrix.sum().
+    """
+
+    scratchpad: torch.Tensor = None
+
+    def __init__(self, storage, shape, shape_max=None, scratchpad=None):
+        super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)  # dtype is always the 1-bit BIT type
+        self.scratchpad = scratchpad
+
+    def sum(self, partials_block_size):  # hands the scratchpad to sum_bitmatrix_rows as its output buffer
+        _, n_cols = self.shape
+        dev = self.device
+        if self.scratchpad is None:
+            self.scratchpad = clear_sums(n_cols, dev)  # no buffer supplied: obtain one now
+        out_ret = self.scratchpad[:n_cols]
+        self.scratchpad = None  # consumed: a later sum() takes the branch above and gets a new buffer
+        return sum_bitmatrix_rows(self, out_ret, partials_block_size)
+
+
+def get_layout(tensor: torch.Tensor | Tensor | None):
+    """Layout of `tensor`: None for None, the stored layout for a Tensor, else the `StridedLayout` class."""
+    if isinstance(tensor, Tensor):
+        return tensor.storage.layout
+    # NOTE: for any other non-None input this returns the StridedLayout class itself, not an instance.
+    return None if tensor is None else StridedLayout
+
+
+def wrap_torch_tensor(torch_tensor, dtype=None):  # wrap a torch.Tensor in a Tensor, optionally viewed as a narrower `dtype`
+    if dtype is None:
+        dtype = torch_tensor.dtype
+    shape = list(torch_tensor.shape)
+    shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(  # stride-1 dim grows by the packing ratio
+        dtype  # note: .index(1) above raises ValueError when no stride equals 1
+    )
+    return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
+
+
+def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):  # returns a new Tensor stored in `layout_cls`
+    assert isinstance(tensor, Tensor)
+    old_storage = tensor.storage
+    old_data = old_storage.layout.unswizzle_data(old_storage.data)  # undo the current layout first
+    new_layout = layout_cls(old_data.shape, **layout_kwargs)  # layouts are constructed from the unswizzled shape
+    new_data = new_layout.swizzle_data(old_data)
+    attrs = {  # every dataclass field except `storage` is carried over unchanged
+        k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"
+    }
+    return Tensor(Storage(new_data, new_layout), **attrs)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..98122f3517a593b1bc479c43d8d64fb64191a7af
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout.py
@@ -0,0 +1,40 @@
+from .layout_details.base import Layout
+from .layout_details.blackwell_scale import BlackwellMXScaleLayout
+from .layout_details.blackwell_value import BlackwellMXValueLayout
+from .layout_details.hopper_scale import HopperMXScaleLayout
+from .layout_details.hopper_value import HopperMXValueLayout
+from .layout_details.cdna4_scale import CDNA4MXScaleLayout
+from .layout_details.strided import StridedLayout
+from ..target_info import cuda_capability_geq, is_hip_cdna4
+
+__all__ = [
+ "Layout",
+ "BlackwellMXValueLayout",
+ "BlackwellMXScaleLayout",
+ "HopperMXScaleLayout",
+ "HopperMXValueLayout",
+ "CDNA4MXScaleLayout",
+ "StridedLayout",
+]
+
+
+def make_default_matmul_mxfp4_w_layout(mx_axis: int):
+    """Pick the default mxfp4 weight layout as a `(layout class, constructor kwargs)` pair."""
+    if cuda_capability_geq(10):
+        return BlackwellMXValueLayout, {}
+    if cuda_capability_geq(9):
+        # Only this layout receives `mx_axis`.
+        return HopperMXValueLayout, {"mx_axis": mx_axis}
+    return StridedLayout, {}
+
+
+def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
+    """Pick the default mxfp4 scale layout as a `(layout class, constructor kwargs)` pair."""
+    # The CDNA4 check comes first; the CUDA capability checks only run when it fails.
+    if is_hip_cdna4():
+        return CDNA4MXScaleLayout, {}
+    if cuda_capability_geq(10):
+        return BlackwellMXScaleLayout, {}
+    if cuda_capability_geq(9):
+        return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+    return StridedLayout, {}
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/base.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d23dab8f42abd1d87bf77c08c3b64c1efe4d3e3
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/base.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+
+
+class Layout(ABC):  # abstract interface every tensor layout implements
+    def __init__(self, shape) -> None:
+        self.initial_shape = shape  # the shape handed to the constructor, kept as-is
+
+    @abstractmethod
+    def swizzle_data(self, data):  # subclasses: return `data` rearranged into this layout
+        pass
+
+    @abstractmethod
+    def unswizzle_data(self, data):  # subclasses: inverse of swizzle_data
+        pass
+
+    @abstractmethod
+    def swizzle_block_shape(self, block_shape):  # subclasses: map a block shape to its form in this layout
+        pass
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/blackwell_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54a300cfdd906dec1a78aaf4f48259529659cdf
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/blackwell_scale.py
@@ -0,0 +1,81 @@
+import math
+import triton
+import triton.language as tl
+import torch
+from .base import Layout
+
+SWIZZLE_ALIGN_INNER = 8
+SWIZZLE_SIZE_INNER = 4
+SWIZZLE_SIZE_OUTER = 128
+
+
+class BlackwellMXScaleLayout(Layout):
+ """Layout that swizzles scale tensors into the `BLACKWELL_SCALE` format.
+
+ The last two entries of `shape` are read as `(K, N)`; any entries before
+ them are treated as leading (batch) dimensions.
+ """
+
+ name: str = "BLACKWELL_SCALE"
+
+ def __init__(self, shape) -> None:
+ super().__init__(shape)
+ (
+ *self.leading_shape,
+ self.K,
+ self.N,
+ ) = shape
+ # Product of all leading dims; math.prod of an empty list is 1.
+ self.B = math.prod(self.leading_shape)
+ self.ALIGN_K = 8
+ self.ALIGN_N = 128
+ self.SWIZZLE_K = 4
+ # K and N rounded up to the next multiple of ALIGN_K / ALIGN_N.
+ self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
+ self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
+
+ def swizzle_data(self, data):
+ """Pad `data` to `(K_pad, N_pad)` and rearrange it into the swizzled form.
+
+ Returns a tensor of shape `(1, B * N_pad // 128, K_pad // 4, 2, 256)`.
+ """
+ # F.pad lists the last dimension first: N is padded on the right by
+ # N_pad - N, then K by K_pad - K.
+ data = torch.nn.functional.pad(
+ data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)
+ )
+ # Swap the last two dims so that N comes before K.
+ data = data.transpose(-1, -2).contiguous()
+ # Split N_pad into (N_pad // 128, 4, 32) and K_pad into (K_pad // 4, 4).
+ data = data.reshape(
+ self.B,
+ self.N_pad // self.ALIGN_N,
+ self.ALIGN_N // 32,
+ 32,
+ self.K_pad // self.SWIZZLE_K,
+ self.SWIZZLE_K,
+ )
+ # Exchange the ALIGN_N // 32 axis with the K_pad // SWIZZLE_K axis.
+ data = data.transpose(2, 4).contiguous()
+ data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
+ return data
+
+ def unswizzle_data(self, data):
+ """Undo `swizzle_data` and trim the padding back to `(..., K, N)`."""
+ # Same six-way split as the tensor had right after transpose(2, 4)
+ # in swizzle_data.
+ data = data.reshape(
+ self.B,
+ self.N_pad // self.ALIGN_N,
+ self.K_pad // self.SWIZZLE_K,
+ 32,
+ self.ALIGN_N // 32,
+ self.SWIZZLE_K,
+ )
+ data = data.transpose(2, 4)
+ data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
+ data = data.transpose(-1, -2)
+ # Drop the rows/columns that swizzle_data added as padding.
+ return data[..., : self.K, : self.N]
+
+ def swizzle_block_shape(self, block_shape):
+ """Map a 2-entry `block_shape` to the 5-entry swizzled block shape."""
+ # NOTE(review): 32 is presumably the number of values sharing one
+ # scale, and 128 / 4 mirror ALIGN_N / SWIZZLE_K — confirm with callers.
+ MX_PACK_DIVISOR = 32
+ MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
+ return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
+
+
+@triton.jit
+def unswizzle_mx_scale_bw(
+ x,
+ SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
+ SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
+ ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER,
+):
+ # Regroups a 2-D block `x` of shape (shape_0, shape_1) into shape
+ # (shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER).
+ # NOTE(review): presumably the in-kernel inverse of
+ # BlackwellMXScaleLayout.swizzle_data — confirm against the calling kernel.
+ shape_0: tl.constexpr = x.shape[0]
+ shape_1: tl.constexpr = x.shape[1]
+ # The second dimension must hold whole SIZE_OUTER groups, at most
+ # ALIGN_INNER of them.
+ tl.static_assert(shape_1 % SIZE_OUTER == 0)
+ tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
+ # Split shape_1 into (groups // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER).
+ x = x.reshape(
+ shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER
+ )
+ # Swap axes 1 and 3, then collapse back to two dimensions.
+ x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
+ return x
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/blackwell_value.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/blackwell_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..622744888b91eb0c99ba6d9c7fb150acb2d89702
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/blackwell_value.py
@@ -0,0 +1,37 @@
+import torch
+from .base import Layout
+
+
+class BlackwellMXValueLayout(Layout):
+    """Value layout (``BLACKWELL_VALUE``) that pads the unit-stride dimension.
+
+    The data keeps its element order; only the dimension whose stride is 1
+    is padded up to a multiple of 128.
+    """
+
+    name: str = "BLACKWELL_VALUE"
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+        self.shape = shape
+
+    def swizzle_data(self, data):
+        """Return ``data`` with its stride-1 dimension padded to a multiple of 128."""
+        ndim = data.ndim
+        # Dimension order that makes ``data`` row major: largest stride first,
+        # ties broken by dimension index.
+        row_major_order = sorted(range(ndim), key=lambda d: (data.stride(d), d))
+        row_major_order.reverse()
+        # Inverse permutation, used to restore the caller's dimension order.
+        inverse_order = [0] * ndim
+        for position, dim in enumerate(row_major_order):
+            inverse_order[dim] = position
+        # The first dimension with stride 1 is the one that gets aligned.
+        unit_stride_dim = data.stride().index(1)
+        extent = data.shape[unit_stride_dim]
+        aligned_extent = (extent + 127) // 128 * 128
+        trailing_pad = aligned_extent - extent
+        row_major = data.permute(row_major_order)
+        padded = torch.nn.functional.pad(row_major, (0, trailing_pad))
+        return padded.permute(inverse_order)
+
+    def unswizzle_data(self, data: torch.Tensor):
+        """Slice ``data`` back down to the shape recorded at construction."""
+        assert data.ndim == len(self.shape), (
+            "Rank mismatch between data and recorded shape"
+        )
+        # Never slice past what ``data`` actually holds.
+        window = tuple(
+            slice(0, min(have, want)) for have, want in zip(data.shape, self.shape)
+        )
+        return data[window]
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are unaffected by this layout."""
+        return block_shape
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/cdna4_scale.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/cdna4_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..beecaee3e12d93294df0365010966e15d625635e
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/cdna4_scale.py
@@ -0,0 +1,50 @@
+import triton
+import triton.language as tl
+from .base import Layout
+
+NON_K_PRESHUFFLE_BLOCK_SIZE = 32
+
+
+class CDNA4MXScaleLayout(Layout):
+ """Layout that pre-shuffles scale tensors into the `CDNA4_SCALE` format.
+
+ Only the forward swizzle is implemented; `unswizzle_data` raises.
+ """
+
+ name: str = "CDNA4_SCALE"
+
+ def __init__(self, shape) -> None:
+ super().__init__(shape)
+
+ def swizzle_data(self, data):
+ """Rearrange a 2-D `(SCALE_K, N)` or 3-D `(E, SCALE_K, N)` tensor.
+
+ The result has the last two dims `(SCALE_K * 32, N // 32)`, returned as
+ a transposed view of the reshaped data.
+ """
+ block_shape = data.shape
+ SCALE_K = block_shape[-2]
+ N = block_shape[-1]
+ data = data.transpose(-1, -2)
+ # NOTE(review): `.view` is applied to the transposed tensor, so this
+ # relies on the input strides being view-compatible — confirm callers.
+ # The split implies N is a multiple of 32 and SCALE_K a multiple of 8.
+ data = data.view(
+ -1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1
+ )
+ data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
+ if len(block_shape) == 3:
+ E = block_shape[0]
+ data = data.reshape(E, N // 32, SCALE_K * 32)
+ else:
+ assert len(block_shape) == 2
+ data = data.reshape(N // 32, SCALE_K * 32)
+ return data.transpose(-1, -2)
+
+ def unswizzle_data(self, data):
+ """Not implemented for this layout."""
+ raise NotImplementedError()
+
+ def swizzle_block_shape(self, block_shape):
+ """Return `block_shape` with its last two entries replaced.
+
+ `block_shape` must be a list: the slice is concatenated with a list.
+ """
+ SCALE_K = block_shape[-2]
+ N = block_shape[-1]
+ return block_shape[:-2] + [N // 32, SCALE_K * 32]
+
+
+@triton.jit
+def unswizzle_mx_scale_cdna4(
+ x,
+ BLOCK_N: tl.constexpr,
+ MX_SCALE_BLOCK_K: tl.constexpr,
+ N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
+):
+ # Regroups a pre-shuffled scale block into shape (BLOCK_N, MX_SCALE_BLOCK_K).
+ # NOTE(review): presumably the in-kernel inverse of
+ # CDNA4MXScaleLayout.swizzle_data — confirm against the calling kernel.
+ # With the default N_PRESHUFFLE_FACTOR of 32 the 7-D split holds
+ # BLOCK_N * MX_SCALE_BLOCK_K elements.
+ x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
+ x = x.permute(0, 5, 3, 1, 4, 2, 6)
+ x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+ return x
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_scale.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef61e889b2c4c38bad4832bd160734a4b492b26
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_scale.py
@@ -0,0 +1,91 @@
+import torch
+import triton
+import triton.language as tl
+from .base import Layout
+
+
+class HopperMXScaleLayout(Layout):
+    """Swizzled storage layout for scale tensors (``HOPPER_SCALE``).
+
+    Args:
+        shape: Shape of the scale tensor; needs at least two entries. All
+            entries before the last two are treated as leading dims.
+        mx_axis: When this equals the number of leading dims (i.e. it names
+            the second-to-last axis), the last two dims are transposed
+            before and after the swizzle.
+        num_warps: Positive power of two. The second-to-last dim of the
+            (possibly transposed) data is padded to a multiple of
+            ``32 * num_warps``.
+    """
+
+    name: str = "HOPPER_SCALE"
+
+    def __init__(self, shape, mx_axis, num_warps=8) -> None:
+        # `n & (n - 1) == 0` is also true for 0, which would only fail later
+        # with a modulo-by-zero in swizzle_data, so reject it here.
+        assert num_warps > 0 and num_warps & (num_warps - 1) == 0, (
+            "num_warps must be a power of 2"
+        )
+        super().__init__(shape)
+        self.mx_axis = mx_axis
+        self.num_warps = num_warps
+        *self.leading_shape, _, _ = shape
+
+    def _maybe_mT(self, data):
+        """Transpose the last two dims when ``mx_axis`` is the second-to-last axis."""
+        if self.mx_axis == len(self.leading_shape):
+            return data.contiguous().mT
+        return data
+
+    def swizzle_data(self, data):
+        """Pad ``data`` and rearrange it into the swizzled scale layout.
+
+        With ``(M, K)`` the padded last two dims of the (possibly transposed)
+        input, the result has last two dims ``(M // 32, K * 32)``, transposed
+        back when ``mx_axis`` asked for it.
+        """
+        data = self._maybe_mT(data).contiguous()
+        *batch, M, K = data.shape
+        align_m = 2 * self.num_warps * 2 * 8
+        align_k = 2
+        # Right-pad both dims up to the next multiple of their alignment.
+        pad_m = (-M) % align_m
+        pad_k = (-K) % align_k
+        data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
+        *batch, M, K = data.shape
+        assert data.is_contiguous()
+        assert M % align_m == 0 and K % align_k == 0, (
+            f"Input tensor must have a subtile of shape (..., {align_m}, 2)"
+        )
+        b = len(batch)
+        # Split M into (M // align_m, 2, num_warps, 2, 8) and K into (K // 2, 2).
+        data = data.reshape(
+            *batch,
+            M // align_m,
+            2,
+            self.num_warps,
+            2,
+            8,
+            K // 2,
+            2,
+        )
+        perm = [0, 2, 5, 1, 4, 6, 3]
+        # Batch dims stay in place; only the seven trailing axes are permuted.
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.flatten(-5, -1)
+        data = data.flatten(-3, -2)
+        assert data.shape[-2] == M // 32
+        assert data.shape[-1] == K * 32
+        data = self._maybe_mT(data)
+        return data
+
+    def unswizzle_data(self, data):
+        """Invert the rearrangement done by ``swizzle_data``.
+
+        NOTE(review): padding added by ``swizzle_data`` is not trimmed here,
+        so the result keeps the padded extents — confirm callers expect that.
+        """
+        data = self._maybe_mT(data)
+        *batch, M, K = data.shape
+        b = len(batch)
+        # Same seven-way split the tensor had right after the permute in
+        # swizzle_data.
+        data = data.reshape(
+            *batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2
+        )
+        perm = [0, 3, 1, 6, 4, 2, 5]
+        perm = list(range(b)) + [b + p for p in perm]
+        data = data.permute(*perm)
+        data = data.reshape(*batch, M * 32, K // 32)
+        data = self._maybe_mT(data)
+        return data
+
+    def swizzle_block_shape(self, block_shape):
+        """Block shapes are passed through unchanged."""
+        return block_shape
+
+
+@triton.jit
+def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
+ """
+ Triton inverse of swizzle_mxfp4_scale_hopper
+
+ Applies the same reshape / permutation / reshape sequence as
+ HopperMXScaleLayout.unswizzle_data to a 2-D block, turning shape (M, K)
+ into (M * 32, K // 32) (transposed around the operation when mx_axis == 0).
+ """
+ tl.static_assert(len(x.shape) == 2, "NYI")
+ # implementation assumes mxfp data is packed along the last dimension
+ x = x.trans() if mx_axis == 0 else x
+ M: tl.constexpr = x.shape[0]
+ K: tl.constexpr = x.shape[1]
+ tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
+ tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
+ # Split M into (M // num_warps, num_warps) and K into (K // 64, 2, 8, 2, 2).
+ x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
+ x = x.trans(0, 3, 1, 6, 4, 2, 5)
+ x = x.reshape(M * 32, K // 32)
+ # implementation assumed mxfp data is packed along the last dimension
+ x = x.trans() if mx_axis == 0 else x
+ return x
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_value.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ddfadf09427f519bc9867094c7855d9d12eac7
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/hopper_value.py
@@ -0,0 +1,362 @@
+import torch
+import triton
+import triton.language as tl
+from .base import Layout
+
+
+def right_shift_unsigned(x, shift):
+    """Shift ``x`` right by ``shift`` bits as if it were an unsigned 32-bit value.
+
+    After the shift, only the low ``32 - shift`` bits are kept, which clears
+    any high bits that ``>>`` filled in from a negative input.
+    """
+    keep_mask = (1 << (32 - shift)) - 1
+    return (x >> shift) & keep_mask
+
+
+# -----------------------------------------------------------------------
+# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
+# 1000000111000000 (first fp4)
+# 1000000111000000 (second fp4)
+# 1000000111000000 (third fp4)
+# 0110110000000000 (fourth fp4)
+# This is done so that dequantization can be done in 14 SASS instructions
+# -----------------------------------------------------------------------
+
+
+def _compress_fp4(x):
+    """Spread the low nibble of ``x`` into 16-bit positions (result is int32).
+
+    Bit 3 moves to bit 15 and bits 0-2 move to bits 6-8.
+    """
+    wide = x.to(torch.int32)
+    top_bit = (wide & 0x8) << 12
+    low_bits = (wide & 0x7) << 6
+    return top_bit | low_bits
+
+
+def _compress_fourth(x):
+    """Place the low nibble of ``x`` into the slots used by the fourth value.
+
+    Bit 3 moves to bit 14, bits 1-2 move to bits 10-11 and bit 0 moves to
+    bit 13. The result is int32.
+    """
+    wide = x.to(torch.int32)
+    top_bit = (wide & 0x8) << 11
+    middle_bits = (wide & 0x6) << 9
+    lowest_bit = (wide & 0x1) << 13
+    return top_bit | middle_bits | lowest_bit
+
+
+def _pack_bits(x: torch.Tensor, mx_axis: int):
+    """Interleave every group of four bytes of ``x`` into one 32-bit word.
+
+    The last dimension is split into groups of four bytes; both nibbles of
+    each byte are spread out with ``_compress_fp4`` / ``_compress_fourth``
+    and the four results are OR-ed into a single int32, which is returned
+    re-viewed as ``uint8``. ``mx_axis`` is accepted but not used here.
+    """
+    x = x.contiguous()
+    last_dim = x.shape[-1]
+    assert last_dim % 4 == 0, (
+        "Input tensor must have a last dimension divisible by 4"
+    )
+    quads = x.reshape(x.shape[:-1] + (last_dim // 4, 4))
+    byte0, byte1, byte2, byte3 = (quads[..., i] for i in range(4))
+    # Low nibble lands in the lower 16 bits, high nibble in the upper 16 bits.
+    first = _compress_fp4(byte0) | (_compress_fp4(byte0 >> 4) << 16)
+    second = _compress_fp4(byte1) | (_compress_fp4(byte1 >> 4) << 16)
+    third = _compress_fp4(byte2) | (_compress_fp4(byte2 >> 4) << 16)
+    fourth = _compress_fourth(byte3) | (_compress_fourth(byte3 >> 4) << 16)
+    packed = first | right_shift_unsigned(second, 3)
+    packed = packed | right_shift_unsigned(third, 6)
+    packed = packed | fourth
+    assert packed.is_contiguous()
+    return packed.view(torch.uint8)
+
+
+# -----------------------------------------------------------------------
+# inverse operation of _pack_bits
+# -----------------------------------------------------------------------
+
+
+def _bf16_to_fp4e2m1(x):
+    """Extract a 4-bit value from each int16 of ``x``.
+
+    Bit layout: 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8).
+    """
+    assert x.dtype == torch.int16
+    sign_bit = (right_shift_unsigned(x, 15) & 0x1) << 3
+    exp_mantissa = right_shift_unsigned(x, 6) & 0x7
+    nibble = sign_bit | exp_mantissa
+    return nibble.to(torch.uint8)
+
+
+def _bf16x2_to_fp4e2m1x2(x):
+    """Convert each int32 (two 16-bit halves) of ``x`` into one packed byte.
+
+    Bit layout:
+    0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8).
+    """
+    assert x.dtype == torch.int32
+    low_half = (x & 0xFFFF).to(torch.int16)
+    high_half = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
+    low_nibble = _bf16_to_fp4e2m1(low_half)
+    high_nibble = _bf16_to_fp4e2m1(high_half)
+    # The upper 16-bit half supplies the upper nibble of the byte.
+    return low_nibble | (high_nibble << 4)
+
+
+def _unpack_bits(x, mx_axis: int):
+    """Split each 32-bit word of ``x`` into four lanes and repack them as bytes.
+
+    Written as the inverse of ``_pack_bits``. ``mx_axis`` is accepted but not
+    used here.
+    """
+    words = x.view(torch.int32)
+    lane_mask = 0b10000001110000001000000111000000
+    # Lanes one to three are recovered by shifting left by 0, 3 and 6 bits.
+    lane0 = words & lane_mask
+    lane1 = (words << 3) & lane_mask
+    lane2 = (words << 6) & lane_mask
+    # The fourth lane is stored in three separate pieces.
+    piece_a = (words << 1) & 0b10000000000000001000000000000000
+    piece_b = right_shift_unsigned(words, 3) & 0b00000001100000000000000110000000
+    piece_c = right_shift_unsigned(words, 7) & 0b00000000010000000000000001000000
+    lane3 = (piece_a | piece_b) | piece_c
+    lanes = torch.stack([lane0, lane1, lane2, lane3], dim=-1)
+    lanes = lanes.flatten(-2, -1)
+    return _bf16x2_to_fp4e2m1x2(lanes)
+
+
+# -----------------------------------------------------------------------
+
+
+class HopperMXValueLayout(Layout):
+ """Layout that tiles and bit-interleaves value tensors (`HOPPER_VALUE`).
+
+ The last two entries of `shape` are read as `(K, N)`; earlier entries are
+ leading (batch) dimensions. `mx_axis` must be a valid axis index of `shape`.
+ """
+
+ name: str = "HOPPER_VALUE"
+
+ def __init__(self, shape, mx_axis, mma_version=3):
+ super().__init__(shape)
+ assert mx_axis in range(len(shape))
+ self.mx_axis = mx_axis
+ self.mma_version = mma_version
+ (
+ *self.leading_shape,
+ self.K,
+ self.N,
+ ) = shape
+
+ def _maybe_mT(self, data):
+ # Transpose the last two dims when mx_axis is the second-to-last axis.
+ if self.mx_axis == len(self.leading_shape):
+ return data.mT
+ return data
+
+ def swizzle_data(self, data):
+ """
+ Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
+ (*, M // 4, K * 4) such that:
+
+ 1) Groups contiguously all the elements owned by the same thread of 4
+ mma tiles along the K axis. The following animation shows a similar
+ grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
+ as done here:
+ https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
+
+ 2) Moves the elements belonging to thread 4-7 to be contiguous with those
+ from thread 0-3. This is done to get a full cache line when loading them
+ from HBM.
+
+ mx_axis selects the lhs or rhs of the matmul.
+
+ WARNING: Assumes that the matmul will be done in bf16 or fp16!
+ Implementing it for fp8 is as easy as making the tile size (8, 8)
+ """
+ batch = data.ndim - 2
+ assert batch >= 0
+ assert self.mma_version in (2, 3)
+ data = self._maybe_mT(data)
+ # NOTE: this value is overwritten after padding below.
+ init_shape = data.shape
+
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
+ # Every u8 represents 2 mxfp4 elements
+ u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
+
+ # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
+ # Each pair below is (factor for the M dim, factor for the K dim).
+ contig = (1, u8_kwidth)
+ scott_trick = (2, 1)
+ threads = (4, 4)
+ warp_tile = (2, 2)
+ k_tile = (1, 4 // u8_kwidth)
+
+ sizes = list(data.shape[:-2])
+ pads = []
+ # [rest, K, tile, threads] per dimension
+ for i, (a, b, c, s, d) in enumerate(
+ zip(k_tile, warp_tile, threads, scott_trick, contig)
+ ):
+ # `pack` is the tile extent along this dim; pad the dim up to a
+ # multiple of it.
+ pack = a * b * c * s * d
+ size = data.shape[batch + i]
+ pad = (pack - size % pack) % pack
+ pads += [(0, pad)]
+ sizes.append((size + pad) // pack)
+ sizes += [a, b, c, s, d]
+
+ # F.pad expects the last dimension first, hence the reversal.
+ pads = tuple(x for t in pads[::-1] for x in t)
+ data = torch.nn.functional.pad(data, pads)
+ init_shape = data.shape
+ # 0: rest[0]
+ # 1: k_tile[0]
+ # 2: warp_tile[0]
+ # 3: threads[0]
+ # 4: scott_trick[0]
+ # 5: contig[0]
+ # 6: rest[1]
+ # 7: k_tile[1]
+ # 8: warp_tile[1]
+ # 9: threads[1]
+ # 10: scott_trick[1]
+ # 11: contig[1]
+ data = data.view(*sizes)
+ # Resulting order, per the index table above: [rest[0], threads[0], rest[1], scott_trick[1], scott_trick[0], threads[1], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0], contig[0], contig[1]]
+ perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
+ perm = list(range(batch)) + [batch + p for p in perm]
+ data = data.permute(*perm).contiguous()
+ # These are views
+ data = data.flatten(-10, -1)
+ data = data.flatten(-3, -2)
+ assert data.is_contiguous()
+ assert data.shape[-2] == init_shape[-2] // 4
+ assert data.shape[-1] == init_shape[-1] * 4
+ # twiddle the bits
+ data = _pack_bits(data, self.mx_axis)
+ data = self._maybe_mT(data)
+ return data
+
+ def unswizzle_data(self, data):
+ """Undo `swizzle_data` and trim the result to the recorded `(..., K, N)`."""
+ data = self._maybe_mT(data)
+ data = _unpack_bits(data, self.mx_axis)
+ *batch, M, K = data.shape
+ # We have two times the elements if we already upcasted to bfloat16
+ mult = 2 if data.dtype == torch.bfloat16 else 1
+ assert M % 4 == 0, "M must be divisible by 4"
+ assert K % (4 * 8 * 2 * 2 * mult) == 0, (
+ f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
+ )
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
+ # Every u8 represents 2 mxfp4 elements
+ u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
+ # Split M into (M // 4, 4) and K into six tile factors.
+ data = data.reshape(
+ *batch,
+ M // 4,
+ 4,
+ K // (4 * 8 * 2 * 2 * mult),
+ 2,
+ 4,
+ 8 // u8_kwidth,
+ 2,
+ u8_kwidth * mult,
+ )
+ b = len(batch)
+ perm = [0, 6, 1, 3, 2, 5, 4, 7]
+ # Batch dims stay in place; only the eight trailing axes are permuted.
+ perm = list(range(b)) + [b + p for p in perm]
+ data = data.permute(*perm)
+ data = data.reshape(*batch, M * 4, K // 4)
+ data = self._maybe_mT(data)
+ # Drop whatever padding swizzle_data added.
+ return data[..., : self.K, : self.N]
+
+ def swizzle_block_shape(self, block_shape):
+ # Block shapes are passed through unchanged.
+ return block_shape
+
+
+@triton.jit
+def _unshuffle_triton(x, mma_version: tl.constexpr):
+ """
+ Triton inverse of swizzle_mxfp4_value_hopper
+
+ Applies the same reshape / permutation / reshape sequence as
+ HopperMXValueLayout.unswizzle_data to a 2-D block, turning shape (M, K)
+ into (M * 4, K // 4).
+ """
+ tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
+ # if mx_axis == 0:
+ # x = x.trans()
+
+ # We have two times the elements if we already upcasted to bfloat16
+ mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
+ M: tl.constexpr = x.shape[0]
+ K: tl.constexpr = x.shape[1]
+ tl.static_assert(M % 4 == 0, "M must be divisible by 4")
+ tl.static_assert(
+ K % (4 * 8 * 2 * 2 * mult) == 0,
+ f"K must be divisible by {4 * 8 * 2 * 2 * mult}",
+ )
+
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
+ # Every u8 represents 2 mxfp4 elements
+ u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
+ # Split M into (M // 4, 4) and K into six tile factors.
+ x = x.reshape(
+ M // 4,
+ 4,
+ K // (4 * 8 * 2 * 2 * mult),
+ 2,
+ 4,
+ 8 // u8_kwidth,
+ 2,
+ u8_kwidth * mult,
+ )
+ x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
+ x = x.reshape(M * 4, K // 4)
+ # if mx_axis == 0:
+ # x = x.trans()
+ return x
+
+
+@triton.jit
+def _unpack_fp4_to_bf16_triton(x):
+ # Unpacks bit-interleaved fp4 data into bfloat16 using inline PTX.
+ # The PTX masks and shifts mirror the host-side _unpack_bits helper, and
+ # each extracted lane is multiplied by the `bias` constant.
+ # NOTE(review): assumes `x` is a 2-D uint8 block whose bytes are handed to
+ # the PTX four at a time (pack=4) as one 32-bit register — confirm.
+ # For now we implement just H100 support (mul.bf16x2)
+ # A100 support is possible via fma
+ r0, r1 = tl.inline_asm_elementwise(
+ r"""
+ {
+ .reg .b32 b, c, d<7>, scale;
+ .reg .b32 bias;
+ mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
+ // We add the missing bias to the scale directly
+ and.b32 $0, $4, 0b10000001110000001000000111000000;
+ mul.bf16x2 $0, $0, bias;
+ shl.b32 b, $4, 3;
+ and.b32 $1, b, 0b10000001110000001000000111000000;
+ mul.bf16x2 $1, $1, bias;
+ shl.b32 c, $4, 6;
+ and.b32 $2, c, 0b10000001110000001000000111000000;
+ mul.bf16x2 $2, $2, bias;
+ // Unpack last two elements
+ shl.b32 d0, $4, 1;
+ and.b32 d1, d0, 0b10000000000000001000000000000000;
+ shr.b32 d2, $4, 3;
+ and.b32 d3, d2, 0b00000001100000000000000110000000;
+ or.b32 d4, d1, d3;
+ shr.b32 d5, $4, 7;
+ and.b32 d6, d5, 0b00000000010000000000000001000000;
+ or.b32 $3, d4, d6;
+ mul.bf16x2 $3, $3, bias;
+ }
+ """,
+ constraints="=r,=r,=r,=r,r",
+ args=[x],
+ dtype=(tl.bfloat16, tl.bfloat16),
+ is_pure=True,
+ pack=4,
+ )
+ # Concat each pack of 4
+ x = tl.join(r0, r1)
+ # tl.join adds a trailing axis; move it in front of each group of four
+ # before flattening back to two dimensions.
+ x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
+ x = x.trans(0, 1, 3, 2)
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
+ return x
+
+
+@triton.jit
+def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
+ """
+ Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
+ (x << 0) & 0b1000000111000000
+ (x << 3) & 0b1000000111000000
+ (x << 6) & 0b1000000111000000
+ ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
+
+ `x` must be a 2-D uint8 block whose second dimension is a multiple of 4,
+ and `mx_axis` must be 0 or 1 (all enforced by the static asserts below).
+ Returns the unpacked, unshuffled values multiplied by the broadcast scale.
+ """
+ # upcast values to bfloat16
+ tl.static_assert(len(x.shape) == 2)
+ tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
+ tl.static_assert(x.shape[1] % 4 == 0)
+ tl.static_assert(x.dtype == tl.uint8)
+ # The helpers work along the last dimension, so transpose around them when
+ # mx_axis is 0.
+ if mx_axis == 0:
+ x = x.trans()
+ x = _unpack_fp4_to_bf16_triton(x)
+ x = _unshuffle_triton(x, mma_version=3)
+ if mx_axis == 0:
+ x = x.trans()
+
+ # upcast scale to bfloat16
+ # Add bias missing from the bf16 upcasting sequence
+ # triton / LLVM generates terrible code for this sequence
+ # scale = scale.to(tl.uint16)
+ # scale = scale << 7
+ # scale = scale.to(tl.bfloat16, bitcast=True)
+ # NOTE(review): the PTX below is presumably equivalent to the commented-out
+ # sequence above (widen each byte to 16 bits, shift left by 7) — confirm.
+ scale = tl.inline_asm_elementwise(
+ r"""
+ {
+ prmt.b32 $0, $2, 0, 0x5140;
+ shl.b32 $0, $0, 7;
+ prmt.b32 $1, $2, 0, 0x7362;
+ shl.b32 $1, $1, 7;
+ }
+ """,
+ constraints="=r,=r,r",
+ args=[scale],
+ dtype=tl.bfloat16,
+ is_pure=True,
+ pack=4,
+ )
+ # Broadcast scale
+ # Insert an axis after mx_axis, repeat each scale 32 times along it, then
+ # fold that axis away so the scale has the same shape as x.
+ scale = scale.expand_dims(mx_axis + 1)
+ scale = scale.broadcast_to(
+ scale.shape[: mx_axis + 1] + [32] + scale.shape[mx_axis + 2 :]
+ )
+ scale = scale.reshape(x.shape)
+
+ # Combine scale and x
+ x = x * scale
+ return x
diff --git a/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/strided.py b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/strided.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbfd9248fca219eb94dae358cafd7fac6e082cd1
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/tensor_details/layout_details/strided.py
@@ -0,0 +1,17 @@
+from .base import Layout
+
+
+class StridedLayout(Layout):
+    """Identity layout: data and block shapes are left exactly as given."""
+
+    name: str = None
+
+    # The constructor is inherited from ``Layout`` unchanged; it only
+    # records the shape.
+
+    def swizzle_data(self, data):
+        """Return ``data`` untouched."""
+        return data
+
+    def unswizzle_data(self, data):
+        """Return ``data`` untouched."""
+        return data
+
+    def swizzle_block_shape(self, block_shape):
+        """Return ``block_shape`` untouched."""
+        return block_shape
diff --git a/vllm/kvprune_legacy_save/triton_kernels/testing.py b/vllm/kvprune_legacy_save/triton_kernels/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..179623c2242f08a316a18bb08afefc4f4baab0e1
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/testing.py
@@ -0,0 +1,215 @@
+import enum
+import functools
+import os
+import subprocess
+import sys
+import torch
+from vllm.kvprune.triton_kernels.numerics import (
+ MAX_FINITE_FLOAT8E4B8,
+ MAX_FINITE_FLOAT8E4NV,
+ MAX_FINITE_FLOAT8E5,
+)
+
+
+def assert_equal(ref, tri):
+    """Assert that ``ref`` and ``tri`` are exactly equal.
+
+    When ``ref`` is a tensor, every element must compare equal; any other
+    value is compared with ``==``.
+    """
+    if not isinstance(ref, torch.Tensor):
+        assert ref == tri
+        return
+    assert torch.all(ref == tri)
+
+
+def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
+    """Compare reference values against obtained values.
+
+    Args:
+        ref: Reference tensor.
+        tri: Tensor under test; must have the same shape as ``ref``.
+        maxtol: Largest allowed per-element relative error (default 2e-2).
+        rmstol: Largest allowed RMS of the relative error (default 4e-3).
+        description: Prefix used in the printed report.
+        verbose: When true, print both error figures even on success.
+
+    Raises:
+        AssertionError: If the shapes differ, the infinite elements are not
+            in the same positions, or either tolerance is exceeded.
+
+    When ``tri`` has a 1-byte dtype, ``ref`` is cast to that dtype first; if
+    both already share the dtype, the comparison is exact. Empty inputs are
+    accepted without further checks.
+    """
+    if tri.dtype.itemsize == 1:
+        ref_as_type = ref.to(tri.dtype)
+        if ref.dtype == tri.dtype:
+            assert torch.all(ref_as_type == tri)
+            return
+        ref = ref_as_type
+
+    if ref.numel() == 0:
+        return
+
+    if maxtol is None:
+        maxtol = 2e-2
+    if rmstol is None:
+        rmstol = 4e-3
+
+    # cast to float32:
+    ref = ref.to(torch.float32).detach()
+    tri = tri.to(torch.float32).detach()
+    assert ref.shape == tri.shape, (
+        f"Tensors must have same size {ref.shape=} {tri.shape=}"
+    )
+
+    # deal with infinite elements:
+    inf_mask_ref = torch.isinf(ref)
+    inf_mask_tri = torch.isinf(tri)
+    assert torch.equal(inf_mask_ref, inf_mask_tri), (
+        "Tensor must have same infinite elements"
+    )
+    refn = torch.where(inf_mask_ref, 0, ref)
+    trin = torch.where(inf_mask_tri, 0, tri)
+
+    # normalise so that RMS calculation doesn't overflow:
+    eps = 1.0e-30
+    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
+    refn *= multiplier
+    trin *= multiplier
+
+    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
+
+    # Error relative to the element itself, but never to less than the RMS
+    # of the reference, so tiny reference values do not blow up the ratio.
+    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
+    max_err = torch.max(rel_err).item()
+    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
+
+    if verbose:
+        print(
+            "%s maximum relative error = %s (threshold = %s)"
+            % (description, max_err, maxtol)
+        )
+        print(
+            "%s RMS relative error = %s (threshold = %s)"
+            % (description, rms_err, rmstol)
+        )
+
+    if max_err > maxtol:
+        bad_idxs = torch.nonzero(rel_err > maxtol)
+        num_nonzero = bad_idxs.size(0)
+        # Cap the report at the first 1000 offending coordinates.
+        bad_idxs = bad_idxs[:1000]
+        print(
+            "%d / %d mismatched elements (shape = %s) at coords %s"
+            % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
+        )
+
+        bad_idxs = bad_idxs.unbind(-1)
+        print("ref values: ", ref[tuple(bad_idxs)].cpu())
+        print("tri values: ", tri[tuple(bad_idxs)].cpu())
+
+    assert max_err <= maxtol, (
+        f"{description} maximum relative error {max_err} exceeds {maxtol}"
+    )
+    assert rms_err <= rmstol, (
+        f"{description} RMS relative error {rms_err} exceeds {rmstol}"
+    )
+
+
+class ComputeSanitizerTool(enum.Enum):
+ """Tools that `compute_sanitizer` can run.
+
+ Each member's value is the string passed to the `--tool=` option of the
+ `compute-sanitizer` command line.
+ """
+
+ MEMCHECK = "memcheck"
+ RACECHECK = "racecheck"
+ SYNCCHECK = "synccheck"
+ INITCHECK = "initcheck"
+
+
+def compute_sanitizer(**target_kwargs):
+    """
+    Decorator to run a test with compute sanitizer enabled and pytorch caching
+    allocator disabled, to expose potential memory access errors.
+
+    The decorated test is re-launched through ``pytest`` in a subprocess under
+    ``compute-sanitizer``, once per requested tool. Running under the
+    sanitizer is slow, so use sparingly.
+
+    Decorator options (removed from ``target_kwargs`` before matching):
+        clear_torch_cache: When true, call ``torch.cuda.empty_cache()`` first.
+        tools_to_check: List of ``ComputeSanitizerTool`` members to run;
+            defaults to ``[ComputeSanitizerTool.MEMCHECK]``.
+
+    The remaining ``target_kwargs`` select which invocations are sanitized:
+    the sanitizer runs only when all of them appear, with equal values, in
+    the test's keyword arguments. The test is run directly instead when
+
+    * the environment variable ``SKIP_COMPUTE_SANITIZER`` is ``"1"``,
+    * the test receives ``run_sanitizer=False``,
+    * ``target_kwargs`` does not match, or
+    * the parent process is already ``compute-sanitizer``.
+
+    A sanitized test must receive a ``request_fixture`` keyword argument; its
+    ``node.callspec.id`` is used to build the pytest node id.
+
+    Raises:
+        AssertionError: If the sanitizer reports errors or the subprocess
+            exits with a non-zero status.
+    """
+
+    def decorator(test_fn):
+        @functools.wraps(test_fn)
+        def wrapper(*args, **kwargs):
+            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
+                test_fn(*args, **kwargs)
+                return
+
+            import psutil
+
+            # Work on a per-call copy. Popping from the shared decorator
+            # kwargs would drop `tools_to_check` / `clear_torch_cache` after
+            # the first call of a parametrized test.
+            options = dict(target_kwargs)
+            if options.pop("clear_torch_cache", False):
+                # This key must be removed, or the subset comparison with
+                # the test kwargs below could never succeed.
+                torch.cuda.empty_cache()
+            tools_to_check = options.pop(
+                "tools_to_check", [ComputeSanitizerTool.MEMCHECK]
+            )
+            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
+            unknown_tools = [
+                tool for tool in tools_to_check if tool not in ComputeSanitizerTool
+            ]
+            assert not unknown_tools, f"{unknown_tools=}"
+
+            ppid_name = psutil.Process(os.getppid()).exe()
+            run_compute_sanitizer = options.items() <= kwargs.items()
+            if "run_sanitizer" in kwargs:
+                run_compute_sanitizer &= kwargs["run_sanitizer"]
+            if not run_compute_sanitizer or "compute-sanitizer" in ppid_name:
+                # Either not selected, or this already is the child process
+                # launched under the sanitizer: just run the test.
+                test_fn(*args, **kwargs)
+                return
+
+            for tool in tools_to_check:
+                # Path of the file that defines the test.
+                path = os.path.realpath(test_fn.__globals__["__file__"])
+                env = {
+                    "PATH": os.environ["PATH"],
+                    "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
+                    "TORCH_SHOW_CPP_STACKTRACES": "1",
+                    "CUDA_LAUNCH_BLOCKING": "1",
+                }
+                if "CUDA_VISIBLE_DEVICES" in os.environ:
+                    env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
+                assert "request_fixture" in kwargs, (
+                    "memcheck'ed test must have a (possibly unused) `request` fixture"
+                )
+                test_id = kwargs["request_fixture"].node.callspec.id
+                node_id = f"{path}::{test_fn.__name__}[{test_id}]"
+                cmd = [
+                    "compute-sanitizer",
+                    "--target-processes=application-only",
+                    "--destroy-on-device-error=context",
+                    f"--tool={tool.value}",
+                    sys.executable,
+                    "-m",
+                    "pytest",
+                    "-vsx",
+                    node_id,
+                ]
+                # Forward the checksum options the outer pytest run received.
+                for opt in ["--update_checksum", "--ignore_checksum_error"]:
+                    if opt in sys.argv:
+                        cmd.append(opt)
+                out = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.STDOUT,
+                    env=env,
+                )
+                test_output = out.stdout
+                if isinstance(test_output, bytes):
+                    # Tool output is only scanned and printed, so tolerate
+                    # bytes that are not valid UTF-8.
+                    test_output = test_output.decode(errors="replace")
+                sanitizer_ok = (
+                    "ERROR SUMMARY: 0 errors" in test_output
+                    or "RACECHECK SUMMARY: 0 hazards displayed" in test_output
+                )
+
+                fail = False
+                if not sanitizer_ok:
+                    print("compute-sanitizer returned an error")
+                    fail = True
+                elif out.returncode != 0:
+                    print(
+                        "The test failed due to some other reason: consider running without compute-sanitizer to verify."
+                    )
+                    print(f"{out.returncode=}")
+                    fail = True
+
+                if fail:
+                    print("*****************************************************")
+                    print("******************** TEST OUTPUT ********************")
+                    print("*****************************************************")
+                    print(test_output)
+                    print("*****************************************************")
+                    print("****************** TEST OUTPUT END ******************")
+                    print("*****************************************************")
+                    # Raised explicitly: an `assert` would be stripped by -O.
+                    raise AssertionError(
+                        f"compute-sanitizer ({tool.value}) run failed for {node_id}"
+                    )
+
+        return wrapper
+
+    return decorator
+
+
+def compute_actual_scale(x, dtype):
+    """Return ``max(|x|)`` divided by the largest finite value of ``dtype``.
+
+    ``dtype`` must be one of ``torch.float8_e5m2``, ``torch.float8_e4m3fn``
+    or ``torch.float8_e4m3fnuz``; anything else raises ``KeyError``.
+    """
+    finite_limits = {
+        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
+        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
+        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
+    }
+    limit = finite_limits[dtype]
+    peak = x.abs().max()
+    return peak / limit
diff --git a/vllm/kvprune_legacy_save/triton_kernels/topk.py b/vllm/kvprune_legacy_save/triton_kernels/topk.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0615cb588e37b4a1badde9a49def98676634a81
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/topk.py
@@ -0,0 +1,157 @@
+import torch
+import triton
+from vllm.kvprune.triton_kernels.topk_details._topk_forward import _topk_forward
+from vllm.kvprune.triton_kernels.topk_details import _topk_backward
+from vllm.kvprune.triton_kernels.tensor import Tensor, Bitmatrix
+from typing import Optional, Union
+
+
+def topk_forward(
+ x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None
+):
+ """Launch the `_topk_forward` kernel on a 2-D input.
+
+ A plain tensor `x` is wrapped in `Tensor`, with `n_rows` (when given)
+ replacing its first shape entry. Only `dim == 1` and
+ `return_bitmatrix=True` are supported, and the last dimension must be
+ smaller than 32768 (all asserted below).
+
+ Returns `(y_vals, y_indx, bitmatrix)`: the value buffer of shape
+ `(n_rows_max, k)`, the index buffer (the caller's `y_indx` when one was
+ passed, otherwise a new int16 tensor) and a `Bitmatrix` wrapping the
+ uint32 mask and its scratchpad.
+ """
+ if not isinstance(x, Tensor):
+ x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
+ x_shape_max = [x.shape[0], x.shape[1]]
+ x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
+ cdiv = lambda a, b: (a + b - 1) // b
+ BLOCK_M = 32
+ BLOCK_N = 32
+ BLOCK_S = 128
+ assert len(x.shape) == 2
+ assert x.shape_max[-1] < 32768
+ assert dim == 1
+ assert return_bitmatrix
+ n_rows, n_cols = x.shape
+ n_rows_max, _ = x.shape_max
+ dev = x.device
+ # scratchpad tensors
+ # NOTE: these are not returned
+ # NOTE(review): y_vals and y_indx are in fact part of the return value
+ # below; the note above looks stale.
+ y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
+ if y_indx is not None:
+ use_provided_indx = True
+ else:
+ y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
+ use_provided_indx = False
+ # create bitmatrix in transposed memory layout:
+ n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
+ n_cols_words = n_cols_pad // 32
+ # Allocated as (words, rows padded to a multiple of 32), then transposed
+ # and cut to n_rows_max rows.
+ bitmatrix = torch.empty(
+ (n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev
+ )
+ bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
+ s_blocks = cdiv(n_cols, BLOCK_S)
+ s_cols = s_blocks * BLOCK_S
+ scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev)
+ # The grid must cover both the row blocks and the scratchpad blocks.
+ pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
+ _topk_forward[(pids,)](
+ x,
+ x.stride(0), # inputs
+ y_vals,
+ y_indx,
+ y_vals.stride(0),
+ use_provided_indx, # output [topk]
+ bitmatrix,
+ bitmatrix.stride(0),
+ bitmatrix.stride(1), # output [bitmatrix]
+ n_rows,
+ n_cols, # shapes
+ scratchpad,
+ BLOCK_S,
+ s_blocks, # thing to memset to zero
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N, # tunable parameter
+ APPLY_SOFTMAX=apply_softmax,
+ N_EXPTS_PAD=n_cols_pad,
+ N_EXPTS_ACT=k, # constants
+ )
+ bitmatrix_shape = [n_rows, n_cols_words * 32]
+ bitmatrix_shape_max = [n_rows_max, None]
+ bitmatrix = Bitmatrix(
+ bitmatrix,
+ shape=bitmatrix_shape,
+ shape_max=bitmatrix_shape_max,
+ scratchpad=scratchpad,
+ )
+ return y_vals, y_indx, bitmatrix
+
+
+def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
+    """Compute the gradient of the top-k operation with respect to ``x``.
+
+    Args:
+        x: Input that was given to the forward pass.
+        y_indx: Indices selected by the forward pass.
+        dy_vals: Gradient of the selected values; its last dim must equal ``k``.
+        k: Number of selected entries per row.
+        n_rows: Passed to the kernel as ``NRows``. When it is not ``None``
+            the kernel loads the row count from it, so it must be something
+            the kernel can ``tl.load`` from.
+        apply_softmax: Whether the forward pass applied softmax.
+
+    Returns:
+        A new tensor shaped like ``x`` holding the input gradient.
+    """
+    import inspect
+
+    # `from ...topk_details import _topk_backward` binds the *submodule* when
+    # the package does not re-export the kernel, and a module cannot be
+    # launched with `[grid]`. Resolve the jitted kernel in either case.
+    kernel = _topk_backward
+    if inspect.ismodule(kernel):
+        kernel = kernel._topk_backward
+
+    assert dy_vals.shape[-1] == k
+    n_expts_pad = triton.next_power_of_2(x.shape[-1])
+    dx = torch.empty_like(x)
+    # One program per row of `dy_vals`.
+    kernel[(dy_vals.shape[0],)](
+        y_indx,
+        y_indx.stride(0),
+        dy_vals,
+        dy_vals.stride(0),
+        x,
+        x.stride(0),  # inputs
+        dx,  # outputs
+        dx.stride(0),
+        x.shape[0],
+        n_rows,
+        x.shape[-1],
+        APPLY_SOFTMAX=apply_softmax,
+        N_EXPTS_ACT=k,
+        N_EXPTS_PAD=n_expts_pad,
+    )
+    return dx
+
+
+class TopK(torch.autograd.Function):
+    """Autograd wrapper pairing ``topk_forward`` with ``topk_backward``."""
+
+    @staticmethod
+    def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
+        """Run the forward kernel and stash what the backward pass needs."""
+        outputs = topk_forward(
+            x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows
+        )
+        y_vals, y_indx, bitmatrix = outputs
+        # The backward pass needs the input and the selected indices.
+        ctx.save_for_backward(x, y_indx)
+        ctx.apply_softmax = apply_softmax
+        ctx.k = k
+        ctx.n_rows = n_rows
+        return y_vals, y_indx, bitmatrix
+
+    @staticmethod
+    def backward(ctx, dy_vals, _0, _1):
+        """Return the gradient for ``x``; the other forward inputs get ``None``."""
+        x, y_indx = ctx.saved_tensors
+        dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
+        # One slot per forward() argument after ``ctx``: x first, then the
+        # six non-differentiable arguments.
+        return (dx,) + (None,) * 6
+
+
+def topk(
+    x: Union[Tensor, torch.Tensor],
+    k: int,
+    apply_softmax: bool = True,
+    dim: int = 1,
+    return_bitmatrix: bool = True,
+    y_indx: Optional[torch.Tensor] = None,
+    n_rows: Optional[int] = None,
+):
+    """
+    Differentiable top-k over the columns of a 2-D input.
+
+    The input may be a `Tensor` or a `torch.Tensor`; the values and indices
+    that come back are always `torch.Tensor` objects.
+
+    Parameters
+    ----------
+    x : Union[triton_kernels.Tensor, torch.Tensor]
+        Input of shape (n_tokens, n_expts).
+    k : int
+        How many entries to select per row.
+    apply_softmax : bool, default True
+        Whether softmax is applied as part of the operation.
+    dim : int, default 1
+        Dimension to select along. The forward pass asserts `dim == 1`.
+    return_bitmatrix : bool, default True
+        The forward pass asserts this is true. The returned bitmatrix has one
+        bit per (token, expert) pair marking whether that expert was selected.
+    y_indx : torch.Tensor, optional
+        Pre-allocated index tensor of shape (n_tokens, k). When given, it is
+        used instead of computing the indices.
+    n_rows : int, optional
+        Number of rows to process; all rows of `x` when None.
+        NOTE(review): the backward kernel loads the row count from this
+        argument when it is not None — confirm the expected type.
+
+    Returns
+    -------
+    (expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
+    """
+    # autograd.Function.apply only accepts positional arguments.
+    return TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/topk_details/__init__.py b/vllm/kvprune_legacy_save/triton_kernels/topk_details/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_backward.py b/vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..eebe481771543a05cfab5741bf1a0c875248f70d
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_backward.py
@@ -0,0 +1,51 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _topk_backward(
+    Yi,  # top-k indices chosen by the forward pass
+    stride_ym,  # row stride of Yi
+    DY,  # gradient w.r.t. the top-k output values
+    stride_dym,  # row stride of DY
+    X,  # forward input values
+    stride_xm,  # row stride of X
+    DX,  # gradient w.r.t. the input (written by this kernel)
+    stride_dxm,  # row stride of DX
+    n_rows,  # row count, used as-is when NRows is None
+    NRows,  # optional pointer; when given, the row count is loaded from it
+    n_expts_tot,  # number of valid columns in X / DX
+    APPLY_SOFTMAX: tl.constexpr,
+    N_EXPTS_ACT: tl.constexpr,  # entries per row in Yi / DY
+    N_EXPTS_PAD: tl.constexpr,  # width of the column range; masked down to n_expts_tot
+):
+    pid_m = tl.program_id(0)  # one program handles one row
+    if NRows is not None:
+        n_rows = tl.load(NRows)
+    if pid_m >= n_rows:
+        return
+    Yi += pid_m * stride_ym
+    DY += pid_m * stride_dym
+    X += pid_m * stride_xm
+    DX += pid_m * stride_dxm
+    # --
+    offs_xn = tl.arange(0, N_EXPTS_PAD)
+    offs_yn = tl.arange(0, N_EXPTS_ACT)
+    mask_xn = offs_xn < n_expts_tot
+    # recompute softmax over the selected entries only (gathered through Yi)
+    y_indx = tl.load(Yi + offs_yn)
+    x = tl.load(X + y_indx)
+    x = x.to(tl.float32)
+    y = tl.softmax(x)
+    # compute input-gradient
+    dy = tl.load(DY + offs_yn)
+    dy = dy.to(tl.float32)
+    s = tl.sum(y * dy, 0)
+    # write-back input gradient: zero the whole row, then scatter the selected entries
+    tl.store(DX + offs_xn, 0, mask=mask_xn)
+    tl.debug_barrier()  # presumably orders the zero-fill before the scatter below
+    if APPLY_SOFTMAX:
+        dx = y * (dy - s)  # softmax backward: y * (dy - sum(y * dy))
+    else:
+        dx = dy
+    tl.store(DX + y_indx, dx)
diff --git a/vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_forward.py b/vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf27ba999cca1a2b8fe63f1c386680c77ea4cec9
--- /dev/null
+++ b/vllm/kvprune_legacy_save/triton_kernels/topk_details/_topk_forward.py
@@ -0,0 +1,183 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def get_topmask_and_fullmask(x):
+    tl.static_assert(
+        x.dtype.is_int_unsigned(), "floating-point value must be passed as bits"
+    )
+    tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)  # only the top bit set
+    fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1  # every bit set
+    tm_arr = tl.full(x.shape, tm, dtype=x.dtype)  # broadcast both masks to x's shape
+    fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
+    return tm_arr, fm_arr
+
+
+@triton.jit
+def fpval_to_key(x):  # x: float bit patterns held as unsigned ints (asserted in the helper)
+    tm, fm = get_topmask_and_fullmask(x)
+    return x ^ tl.where((x & tm) != 0, fm, tm)  # top bit set: flip all bits; else set top bit
+
+
+@triton.jit
+def key_to_fpval(x):  # undoes fpval_to_key on the same unsigned representation
+    tm, fm = get_topmask_and_fullmask(x)
+    return x ^ tl.where((x & tm) == 0, fm, tm)  # top bit clear: flip all bits; else clear it
+
+
+# stable top-k tie-breaks to value with smaller index (a smaller index gives a larger key)
+@triton.jit
+def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
+    return N_EXPTS_PAD - indx
+
+
+@triton.jit
+def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):  # NOTE: `indx` holds a key here
+    return N_EXPTS_PAD - indx  # the same subtraction undoes indx_to_key
+
+
+@triton.jit
+def streaming_topk(
+    X,  # pointer to the input values
+    stride_xm,  # row stride of X
+    n_expts_tot,  # number of valid columns; columns at or beyond it are masked
+    offs_m,  # row offsets handled by this program
+    mask_m,  # row validity mask
+    N_EXPTS_PAD: tl.constexpr,  # padded column count, walked in BLOCK_N steps
+    N_EXPTS_ACT: tl.constexpr,  # k: number of entries kept per row
+    BLOCK_N: tl.constexpr,  # columns processed per iteration
+):
+    x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
+    x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
+    if x_nbits < 16:
+        # this ensures that we leave at least 16 bits for expert index
+        # even if the input dtype is smaller than 16 bits:
+        y_nbits: tl.constexpr = 32
+    else:
+        y_nbits: tl.constexpr = x_nbits * 2
+    x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")  # wide type holding value key + index key
+    x_dtype: tl.constexpr = X.dtype.element_ty
+
+    # subtract 1 from loop iterations because we peel the first (masked) iteration:
+    loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
+    offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)  # last column block first
+    mask_n = offs_x_n[None, :] < n_expts_tot
+
+    # first iteration (the only one that needs the column mask):
+    X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
+    x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
+    x = fpval_to_key(x.to(x_utype, bitcast=True))
+    x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]  # value key | index key
+    acc = tl.topk(x, N_EXPTS_ACT, dim=1)
+
+    # subsequent iterations walk backwards over the remaining column blocks:
+    for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
+        acc = tl.bitonic_merge(acc)  # ensure sorted ascending for the merge
+        X_ptrs -= BLOCK_N
+        offs_x_n -= BLOCK_N
+        x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
+        x = fpval_to_key(x.to(x_utype, bitcast=True))
+        x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
+        acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))  # merge running and block top-k
+
+    # rotate expert index into upper 16 bits:
+    # 0000vvvvvvvviiii --> iiii0000vvvvvvvv
+    acc = (acc << (y_nbits - 16)) | (acc >> 16)
+    # sort in ascending order of expert (descending order of key)
+    acc = tl.sort(acc, dim=1, descending=True)
+    # iiii0000vvvvvvvv --> 0000iiii:
+    y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
+    y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
+    # iiii0000vvvvvvvv --> vvvvvvvv:
+    y_values_raw = acc.to(x_utype)
+    y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
+
+    return y_values, y_indices
+
+
+@triton.jit
+def _topk_forward(
+    X,  # input values
+    stride_xm,  # row stride of X
+    Yv,  # output: selected values
+    Yi,  # selected indices: written here, or read when USE_PROVIDED_INDX
+    stride_ym,  # row stride used for both Yv and Yi
+    USE_PROVIDED_INDX: tl.constexpr,
+    Bits,  # output bitmatrix
+    stride_rm: tl.constexpr,
+    stride_rn: tl.constexpr,  # bitmatrix strides (row, 32-bit word)
+    n_rows,  # row count, or a pointer to it (dereferenced below)
+    n_expts_tot,  # number of valid columns in X
+    S,  # int32 buffer zeroed by the first s_blocks programs
+    BLOCK_S: tl.constexpr,
+    s_blocks,  # number of BLOCK_S-sized chunks of S to zero
+    APPLY_SOFTMAX: tl.constexpr,  # softmax over the selected values only
+    BLOCK_M: tl.constexpr,
+    N_EXPTS_PAD: tl.constexpr,
+    N_EXPTS_ACT: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
+        n_rows = tl.load(n_rows)
+
+    # memset of S happens before the early exit, so it runs even for row-less programs
+    if pid < s_blocks:
+        tl.store(
+            S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)
+        )
+
+    if pid * BLOCK_M >= n_rows:
+        # early exit:
+        return
+
+    tl.static_assert(BLOCK_N % 32 == 0)
+    tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
+    x_dtype: tl.constexpr = X.dtype.element_ty
+
+    # load logits
+    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_y_n = tl.arange(0, N_EXPTS_ACT)
+    mask_m = offs_m[:, None] < n_rows
+    if USE_PROVIDED_INDX:
+        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+        y_indices = tl.load(Yi_ptrs, mask=mask_m)
+        Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
+        y_values = tl.load(Xv_ptrs, mask=mask_m)  # gather values at the provided indices
+    else:
+        y_values, y_indices = streaming_topk(
+            X,
+            stride_xm,
+            n_expts_tot,
+            offs_m,
+            mask_m,
+            N_EXPTS_PAD,
+            N_EXPTS_ACT,
+            BLOCK_N,
+        )
+
+    # normalize selected values
+    if APPLY_SOFTMAX:
+        y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(
+            x_dtype
+        )
+
+    # write back
+    Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+    tl.store(Yv_ptrs, y_values, mask=mask_m)
+    if not USE_PROVIDED_INDX:
+        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
+        tl.store(Yi_ptrs, y_indices, mask=mask_m)
+
+    # pack into bitmatrix: bit (index % 32) of word (index // 32) is set per selected index
+    y_div = y_indices // 32
+    y_rem = y_indices % 32
+    loop_iterations = N_EXPTS_PAD // BLOCK_N
+    for i in range(loop_iterations):
+        offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)  # words in this block
+        y2 = tl.where(
+            y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0
+        )
+        r = tl.reduce_or(y2, axis=1)  # OR the selected entries' bits into each word
+        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
+        tl.store(BitsPtrs, r, mask=mask_m)
diff --git a/vllm/kvprune_legacy_save/utils/__init__.py b/vllm/kvprune_legacy_save/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f61e74eab6605369645345370ef951e9141fef14
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/__init__.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Shared helpers: Triton compat, layout bridge, context, sequences."""
+
+from vllm.kvprune.utils.layout_bridge import (
+ block_table_to_global_page_table,
+ build_batch_mapping,
+ build_page_table_head_major,
+ flatten_kv_cache_head_major,
+ flatten_kv_cache_plane,
+ write_head_major_flat_to_interleaved,
+)
+from vllm.kvprune.utils.triton_compat import (
+ autotune as triton_autotune,
+ cuda_capability_geq,
+ maybe_set_allocator,
+)
+
+__all__ = [
+ "block_table_to_global_page_table",
+ "build_batch_mapping",
+ "build_page_table_head_major",
+ "cuda_capability_geq",
+ "flatten_kv_cache_head_major",
+ "flatten_kv_cache_plane",
+ "write_head_major_flat_to_interleaved",
+ "maybe_set_allocator",
+ "triton_autotune",
+]
diff --git a/vllm/kvprune_legacy_save/utils/arguments.py b/vllm/kvprune_legacy_save/utils/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ab122cc6dfdc30b4cdf38c8d2bc15b770d4d106
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/arguments.py
@@ -0,0 +1,445 @@
+import itertools
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+from vllm.kvprune.compression import CompressionMethod
+from vllm.kvprune.compression.compression_config import BatchCompressionParams
+from vllm.kvprune.config.engine_config import LLMConfig
+from vllm.kvprune.utils.sequence import Sequence
+from vllm.kvprune.utils.kv_dist import broadcast_from_tp_rank0
+from vllm.kvprune.utils.tp_utils import kv_heads_shard_divisor
+
+
+@dataclass
+class PrefillBatchArguments:
+    B: int  # number of sequences in the batch
+    N: int  # total prompt tokens across the batch
+    do_compression: bool
+    compression_method: CompressionMethod
+    compression_chunk_size: int  # -1 when chunked compression is disabled
+
+    seq_ids: torch.Tensor
+
+    input_ids: torch.Tensor
+    positions: torch.Tensor
+    cu_seqlens_q: torch.Tensor  # cumulative prompt lengths (B + 1 entries)
+    cu_seqlens_k: torch.Tensor  # filled from the same cumulative lengths as cu_seqlens_q
+    max_seqlen_q: int
+    max_seqlen_k: int
+
+    batch_tokens_to_retain: Optional[torch.Tensor]  # per-sequence retain budget
+    max_tokens_to_retain: Optional[int]
+    protected_first: Optional[List[int]]  # per-sequence protected leading token counts
+    protected_last: Optional[List[int]]  # per-sequence protected trailing token counts
+
+    PHI: Optional[torch.Tensor]  # sketch matrix of shape (head_dim, sketch_dim)
+
+    # args needed for memory reservation
+    context_lens: torch.Tensor
+    max_new_tokens: torch.Tensor
+
+    # Matches kvpress ``CompactorPress``'s blending default (compression_ratio is used when unset)
+    compression_ratio: float = 1.0
+
+class PackedTensorArguments:
+    def __init__(
+        self,
+        rank: int,
+        max_batched_tokens: int,
+        config: LLMConfig,
+        seed: int = 42,
+        *,
+        device: torch.device | None = None,
+        use_tp_group_for_collectives: bool = False,
+    ) -> None:
+        """Allocate the packed int64/int32 broadcast buffers and the sketch matrix.
+
+        ``rank`` picks the default device (``cuda:{rank}``) when ``device`` is None;
+        ``seed`` seeds the generator used to draw ``PHI``.
+        """
+        hf_config = config.hf_config
+        self.rank = rank
+        self.device = device if device is not None else torch.device(f"cuda:{rank}")
+        self._use_tp_group = use_tp_group_for_collectives
+        self.max_num_batches = config.max_num_seqs
+        self.max_batched_tokens = max_batched_tokens
+        _ws = kv_heads_shard_divisor()
+        self.num_kv_heads = hf_config.num_key_value_heads // _ws
+        self.world_size = config.tensor_parallel_size
+        self.page_size = int(config.kvcache_page_size)
+        # Not every HF config defines ``head_dim``; fall back to the conventional
+        # hidden_size // num_attention_heads so PHI is never built with a None dim.
+        self.head_dim = getattr(hf_config, "head_dim", None) or (
+            hf_config.hidden_size // hf_config.num_attention_heads
+        )
+        self.sketch_dim = config.leverage_sketch_size
+        self.model_dtype = hf_config.torch_dtype
+
+        # i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
+        self.i64_len_max = 2 * self.max_num_batches + 2 * self.max_batched_tokens
+        self.packed_context_i64 = torch.empty(
+            self.i64_len_max, dtype=torch.int64, device=self.device
+        )
+
+        # i32 pack = [header (6)] || [cu_q (BMAX+1)] || [cu_k (BMAX+1)] || [retain (BMAX)]
+        #   || [context_lens (BMAX)] || [protected_first (BMAX)] || [protected_last (BMAX)]
+        self.i32_len_max = 6 + 2 * (self.max_num_batches + 1) + 4 * self.max_num_batches
+        self.packed_context_i32 = torch.empty(
+            self.i32_len_max, dtype=torch.int32, device=self.device
+        )
+
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
+        # Gaussian sketch matrix, cast to the model dtype and scaled by 1/sqrt(sketch_dim).
+        self.PHI = torch.randn(
+            (self.head_dim, self.sketch_dim),
+            device=self.packed_context_i32.device,
+            generator=self.generator,
+        ).to(self.model_dtype) * (1 / math.sqrt(self.sketch_dim))
+
+    def _master_build_prefill(
+        self, seqs: List[Sequence], batch_compression_params: BatchCompressionParams
+    ) -> PrefillBatchArguments:
+        """Pack a prefill batch into the shared buffers on rank 0 and broadcast it.
+
+        Returns views into the packed buffers; raises ValueError if ``seqs`` is empty.
+        """
+        B = len(seqs)
+        if B == 0:
+            raise ValueError(
+                "prefill batch is empty (scheduler should not call build_prefill with "
+                "no sequences)"
+            )
+        Ls = [x.prompt_len for x in seqs]
+
+        N = sum(Ls)
+        assert N <= self.max_batched_tokens
+        do_compression = any(x.compression_params.compression_ratio < 1.0 for x in seqs)
+        do_compression = (
+            do_compression
+            and batch_compression_params.compression_method != CompressionMethod.NONE
+        )
+        pack_slices_64 = self.packed_i64_slices(B, N)
+        pack_slices_32 = self.packed_i32_slices(B)
+
+        # Per-sequence retain budget: round(ratio * unprotected_len * num_kv_heads), at least 1.
+        protected_first_list = [
+            x.compression_params.protected_first_tokens for x in seqs
+        ]
+        protected_last_list = [x.compression_params.protected_last_tokens for x in seqs]
+        retain = [
+            max(
+                int(
+                    round(
+                        x.compression_params.compression_ratio
+                        * (L - s - e)
+                        * self.num_kv_heads
+                    )
+                ),
+                1,
+            )
+            for s, e, L, x in zip(protected_first_list, protected_last_list, Ls, seqs)
+        ]
+        retain = torch.tensor(retain, dtype=torch.int32, device="cpu", pin_memory=True)
+        protected_first = torch.tensor(
+            protected_first_list, dtype=torch.int32, device="cpu", pin_memory=True
+        )
+        protected_last = torch.tensor(
+            protected_last_list, dtype=torch.int32, device="cpu", pin_memory=True
+        )
+        self.packed_context_i32[pack_slices_32["protected_first"]].copy_(
+            protected_first, non_blocking=True
+        )
+        self.packed_context_i32[pack_slices_32["protected_last"]].copy_(
+            protected_last, non_blocking=True
+        )
+        compression_chunk_size = (
+            batch_compression_params.chunk_size
+            if batch_compression_params.do_chunked_compression
+            else -1
+        )
+        min_compression_ratio = min(x.compression_params.compression_ratio for x in seqs)
+        cr_scaled = int(round(float(min_compression_ratio) * 1_000_000.0))  # fixed-point for the i32 header
+        cr_scaled = max(min(cr_scaled, 2_000_000_000), -2_000_000_000)  # keep within int32 range
+        header_host = torch.tensor(
+            [
+                B,
+                N,
+                1 if do_compression else 0,
+                batch_compression_params.compression_method.value,
+                compression_chunk_size,
+                cr_scaled,
+            ],
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=True,
+        )
+
+        self.packed_context_i32[pack_slices_32["retain"]].copy_(
+            retain, non_blocking=True
+        )
+        self.packed_context_i32[pack_slices_32["header"]].copy_(
+            header_host, non_blocking=True
+        )
+        max_seq_qk = max(Ls)
+
+        cu = torch.tensor(
+            list(itertools.accumulate(Ls, initial=0)),
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=True,
+        )
+        self.packed_context_i32[pack_slices_32["cu_q"]].copy_(cu, non_blocking=True)
+        self.packed_context_i32[pack_slices_32["cu_k"]].copy_(cu, non_blocking=True)
+        self.packed_context_i32[pack_slices_32["context_lens"]].copy_(
+            cu.diff(), non_blocking=True
+        )
+
+        seq_ids = torch.tensor(
+            [x.seq_id for x in seqs], dtype=torch.int64, device="cpu", pin_memory=True
+        )
+        input_ids = torch.tensor(
+            [tid for x in seqs for tid in x.prompt_token_ids],
+            dtype=torch.int64,
+            device="cpu",
+            pin_memory=True,
+        )
+        self.packed_context_i64[pack_slices_64["seq_ids"]].copy_(
+            seq_ids, non_blocking=True
+        )
+        self.packed_context_i64[pack_slices_64["input_ids"]].copy_(
+            input_ids, non_blocking=True
+        )
+
+        positions = torch.cat(
+            [
+                torch.arange(L, dtype=torch.int64, device="cpu", pin_memory=True)
+                for L in Ls
+            ]
+        )
+        self.packed_context_i64[pack_slices_64["positions"]].copy_(
+            positions, non_blocking=True
+        )
+
+        max_new_tokens = torch.tensor(
+            [seq.sampling_params.max_new_tokens for seq in seqs],
+            dtype=torch.int64,
+            device="cpu",
+            pin_memory=True,
+        )
+        self.packed_context_i64[pack_slices_64["max_new_tokens"]].copy_(
+            max_new_tokens, non_blocking=True
+        )
+        # Budget for the top-k ranking: ranking all max_seq_len * HKV entries would
+        # turn ``torch.topk`` into a full sort, so only request the largest keep
+        # budget in the batch plus a conservative window of page-padding candidates
+        # (``prefill_store_topk_kv(..., PAD_TO_PAGE_SIZE=True)`` may scan past top-k).
+        max_seq_len = max_seq_qk  # == max(context_lens); reuses the host value, no device read-back
+        full_budget = max_seq_len * self.num_kv_heads
+        keep_budget = int(retain.max().item())
+        pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
+        max_retain = min(full_budget, keep_budget + pad_search_budget)
+        # Non-blocking H2D copies above must finish before NCCL broadcast, or peers can
+        # receive stale/garbage packed buffers → wrong prefill → garbage tokens on TP>1.
+        if self.packed_context_i64.is_cuda:
+            torch.cuda.synchronize()
+        # PHI: rank 0's sketch matrix is broadcast so all TP ranks share one PHI for
+        # leverage / compactor scores (same order as packed_context: i64, i32, PHI).
+        broadcast_from_tp_rank0(
+            self.packed_context_i64, use_tp_group=self._use_tp_group
+        )
+        broadcast_from_tp_rank0(
+            self.packed_context_i32, use_tp_group=self._use_tp_group
+        )
+        if self.world_size > 1:
+            broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)
+        prefill_args = PrefillBatchArguments(
+            B=B,
+            N=N,
+            do_compression=do_compression,
+            compression_method=batch_compression_params.compression_method,
+            compression_chunk_size=compression_chunk_size,
+            seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
+            input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
+            positions=self.packed_context_i64[pack_slices_64["positions"]],
+            cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
+            cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
+            max_seqlen_q=max_seq_qk,
+            max_seqlen_k=max_seq_qk,
+            batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
+            max_tokens_to_retain=max_retain,
+            PHI=self.PHI,
+            context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
+            max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
+            protected_first=protected_first_list,
+            protected_last=protected_last_list,
+            compression_ratio=min_compression_ratio,
+        )
+        return prefill_args
+
+    def _peer_receive_prefill(self) -> PrefillBatchArguments:
+        broadcast_from_tp_rank0(
+            self.packed_context_i64, use_tp_group=self._use_tp_group
+        )
+        broadcast_from_tp_rank0(
+            self.packed_context_i32, use_tp_group=self._use_tp_group
+        )
+        if self.world_size > 1:
+            broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)
+        # Header is 6 fields (B, N, do_compression, method, chunk_size, cr_scaled); must match
+        # packed_i32_slices(B)["header"] for any B. It is read first because B and N size the rest.
+        header = self.packed_context_i32[:6].tolist()
+        B, N = int(header[0]), int(header[1])
+        do_compression = bool(int(header[2]))
+        compression_method = CompressionMethod(int(header[3]))
+        compression_chunk_size = int(header[4])
+        compression_ratio = int(header[5]) / 1_000_000.0  # undo the 1e6 fixed-point scaling
+
+        pack_slices_64 = self.packed_i64_slices(B, N)
+        pack_slices_32 = self.packed_i32_slices(B)
+        max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
+        # Must match _master_build_prefill: max_seqlen_{q,k} = max(Ls), not cu_q.max()
+        # (which equals total batch tokens N and breaks varlen attention on peers).
+        full_budget = max_seq_len * self.num_kv_heads
+        keep_budget = int(self.packed_context_i32[pack_slices_32["retain"]].max().item())
+        pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
+        max_retain = min(full_budget, keep_budget + pad_search_budget)  # same formula as rank 0
+        prefill_args = PrefillBatchArguments(
+            B=B,
+            N=N,
+            do_compression=do_compression,
+            compression_method=compression_method,
+            compression_chunk_size=compression_chunk_size,
+            seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
+            input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
+            positions=self.packed_context_i64[pack_slices_64["positions"]],
+            cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
+            cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
+            max_seqlen_q=max_seq_len,
+            max_seqlen_k=max_seq_len,
+            batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
+            max_tokens_to_retain=max_retain,
+            PHI=self.PHI,
+            context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
+            max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
+            protected_first=self.packed_context_i32[
+                pack_slices_32["protected_first"]
+            ].tolist(),
+            protected_last=self.packed_context_i32[
+                pack_slices_32["protected_last"]
+            ].tolist(),
+            compression_ratio=compression_ratio,
+        )
+        return prefill_args
+
+    @torch.inference_mode()
+    def build_prefill_args(
+        self,
+        seqs: Optional[List[Sequence]] = None,
+        batch_compression_params: Optional[BatchCompressionParams] = None,
+    ) -> PrefillBatchArguments:
+        if self.rank != 0:  # peers ignore both arguments and unpack rank 0's broadcast
+            return self._peer_receive_prefill()
+        return self._master_build_prefill(seqs, batch_compression_params)
+
+    def broadcast(self):
+        if self.world_size <= 1:  # nothing to share with a single rank
+            return None
+        return broadcast_from_tp_rank0(
+            self.packed_context_i64, use_tp_group=self._use_tp_group
+        )
+
+    @staticmethod
+    def packed_i64_slices(B: int, N: int):
+        # Back-to-back segments of the int64 buffer:
+        # seq_ids (B) | input_ids (N) | positions (N) | max_new_tokens (B).
+        names = ("seq_ids", "input_ids", "positions", "max_new_tokens")
+        ends = (B, B + N, B + 2 * N, 2 * B + 2 * N)
+        starts = (0, *ends[:-1])
+        return {key: slice(lo, hi) for key, lo, hi in zip(names, starts, ends)}
+
+    @staticmethod
+    def packed_i32_slices(B: int):
+        """Return the named slices of the packed int32 buffer for a batch of ``B``.
+
+        Segments sit back to back in the order listed below: a fixed 6-entry
+        header, two ``B + 1``-entry cumulative-length segments, then four
+        ``B``-entry per-sequence segments.
+        """
+        segment_widths = (
+            ("header", 6),
+            ("cu_q", B + 1),
+            ("cu_k", B + 1),
+            ("retain", B),
+            ("context_lens", B),
+            ("protected_first", B),
+            ("protected_last", B),
+        )
+        layout = {}
+        cursor = 0
+        # Each segment starts where the previous one ended.
+        for segment, width in segment_widths:
+            layout[segment] = slice(cursor, cursor + width)
+            cursor += width
+        return layout
+
+
+
+@dataclass
+class DecodeBatchOutput:
+    output_tokens: Optional[torch.Tensor]  # presumably the sampled token ids — TODO confirm
+    output_seq_ids: Optional[torch.Tensor]  # presumably the matching sequence ids — TODO confirm
+
+
+@dataclass
+class DecodeBatchArguments:
+    # Accumulator for decode tensors; every tensor field grows along dim 0 via ``update``.
+
+    batch_mapping: Optional[torch.Tensor] = None
+    token_ids: Optional[torch.Tensor] = None
+    positions: Optional[torch.Tensor] = None
+    max_ctx_lens: Optional[torch.Tensor] = None
+    seq_ids: Optional[torch.Tensor] = None
+    temps: Optional[torch.Tensor] = None
+    desired_batch_occupancy: int = -1
+    num_stashed_batches: int = 0
+
+    def update(
+        self,
+        batch_mapping,
+        token_ids,
+        positions,
+        max_ctx_lens,
+        seq_ids,
+        temps=None,
+        desired_batch_occupancy: int = None,
+    ):
+        """Append one batch along dim 0 and return ``self``.
+
+        A field that is still ``None`` takes a clone of the incoming tensor;
+        otherwise the incoming tensor is concatenated after the stored one.
+        ``temps`` and ``desired_batch_occupancy`` are left untouched when the
+        corresponding argument is ``None``.
+        """
+        incoming = {
+            "batch_mapping": batch_mapping,
+            "token_ids": token_ids,
+            "positions": positions,
+            "max_ctx_lens": max_ctx_lens,
+            "seq_ids": seq_ids,
+        }
+        if temps is not None:
+            incoming["temps"] = temps
+        for field_name, new_part in incoming.items():
+            stored = getattr(self, field_name)
+            if stored is None:
+                merged = new_part.clone()
+            else:
+                merged = torch.cat([stored, new_part], dim=0)
+            setattr(self, field_name, merged)
+
+        if desired_batch_occupancy is not None:
+            self.desired_batch_occupancy = desired_batch_occupancy
+
+        return self
diff --git a/vllm/kvprune_legacy_save/utils/context.py b/vllm/kvprune_legacy_save/utils/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d44a34658d665ce2d238613f5fa3fbc5cf201bf
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/context.py
@@ -0,0 +1,109 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+
+# Import from compression_config, not compression.__init__, to avoid circular imports
+# (compression -> compactor -> context -> compression).
+from vllm.kvprune.compression.compression_config import CompressionMethod
+from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
+
+
+@dataclass
+class CompressionContext:
+    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
+
+    compression_chunk_size: int = -1
+    batch_tokens_to_retain: torch.Tensor | None = None
+    max_tokens_to_retain: int = 0
+    context_lens: List[int] | None = None
+    PHI: torch.Tensor | None = None
+
+    # Compactor (optional hyperparameters aligned with kvpress ``CompactorPress``)
+    sketch_dimension: int = 48
+    sink_size_start: int = 8
+    sink_size_end: int = 4
+    compactor_blending: Optional[float] = None
+    # Same as kvpress: used when ``compactor_blending`` is unset (from the request's compression_ratio)
+    compression_ratio: Optional[float] = None
+
+    protected_first_tokens: List[int] | None = None
+    protected_last_tokens: List[int] | None = None
+
+    # CriticalAdaKV
+    wo_weight: Optional[torch.Tensor] = None
+    critical_ada_epsilon: float = 1e-4
+    critical_ada_first_stage_ratio: float = 0.5
+    critical_ada_alpha_safeguard: float = 0.2
+
+
+@dataclass
+class Context:
+    is_prefill: bool = False
+    do_compression: bool = False
+
+    cu_seqlens_q: torch.Tensor | None = None
+    cu_seqlens_k: torch.Tensor | None = None
+    # Set in ModelRunner.run_prefill before forward — avoids D2H inside compactor kernels.
+    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None  # host-side tuple form of cu_seqlens_q
+    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None  # host-side tuple form of cu_seqlens_k
+    max_seqlen_q: int = 0
+    max_seqlen_k: int = 0
+    batch_mapping: torch.Tensor | None = None
+    max_bh_len: int = 0
+
+    compression_context: CompressionContext | None = None
+    STORE_STREAM: torch.cuda.Stream | None = None  # optional CUDA stream; None by default
+
+    key_split: int | None = None
+    attention_schedule: KvpruneAttentionSchedule = (
+        KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
+    )
+
+
+_CONTEXT = Context()  # module-wide current context; rebound by set_context / reset_context
+
+
+def get_context():
+    return _CONTEXT  # the live module-level object, not a copy
+
+
+def set_context(
+    *,
+    is_prefill,
+    do_compression=False,
+    cu_seqlens_q=None,
+    cu_seqlens_k=None,
+    cu_seqlens_q_host: Optional[Tuple[int, ...]] = None,
+    cu_seqlens_k_host: Optional[Tuple[int, ...]] = None,
+    max_seqlen_q=0,
+    max_seqlen_k=0,
+    batch_mapping=None,
+    max_bh_len=0,
+    compression_context: CompressionContext = None,
+    STORE_STREAM=None,
+    key_split=None,
+    attention_schedule=KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE,
+):
+    global _CONTEXT  # the previous context is replaced wholesale, never merged
+    _CONTEXT = Context(
+        is_prefill=is_prefill,
+        do_compression=do_compression,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        cu_seqlens_q_host=cu_seqlens_q_host,
+        cu_seqlens_k_host=cu_seqlens_k_host,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        batch_mapping=batch_mapping,
+        max_bh_len=max_bh_len,
+        compression_context=compression_context,
+        STORE_STREAM=STORE_STREAM,
+        key_split=key_split,
+        attention_schedule=attention_schedule,
+    )
+
+
+def reset_context():
+    global _CONTEXT
+    _CONTEXT = Context()  # fresh instance with every field at its dataclass default
diff --git a/vllm/kvprune_legacy_save/utils/helpers.py b/vllm/kvprune_legacy_save/utils/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e833b885ec2cc2372b1a267a7b361b535fd9d938
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/helpers.py
@@ -0,0 +1,35 @@
+from collections.abc import Callable
+
+import torch
+
+
+def maybe_execute_in_stream(
+    fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
+):
+    """Call ``fn(*args, **kwargs)``, on ``STORE_STREAM`` when one is given."""
+    if STORE_STREAM is None:
+        return fn(*args, **kwargs)
+
+    # Tensor inputs: positional, keyword, and a bound tensor method's ``__self__``.
+    inputs = [a for a in (*args, *kwargs.values()) if isinstance(a, torch.Tensor)]
+    bound_to = getattr(fn, "__self__", None)
+    if isinstance(bound_to, torch.Tensor):
+        inputs.append(bound_to)
+
+    STORE_STREAM.wait_stream(torch.cuda.default_stream())
+    # Some PyTorch builds don't make `torch.cuda.Stream` a context manager;
+    # `torch.cuda.stream(stream)` is the portable fallback.
+    if hasattr(STORE_STREAM, "__enter__"):
+        stream_ctx = STORE_STREAM
+    else:
+        stream_ctx = torch.cuda.stream(STORE_STREAM)
+    with stream_ctx:
+        output = fn(*args, **kwargs)
+    for tensor in inputs:
+        tensor.record_stream(STORE_STREAM)
+    # Outputs (tuple members or a lone tensor) are recorded on the default stream.
+    results = output if isinstance(output, tuple) else (output,)
+    for item in results:
+        if isinstance(item, torch.Tensor):
+            item.record_stream(torch.cuda.default_stream())
+    return output
diff --git a/vllm/kvprune_legacy_save/utils/kv_dist.py b/vllm/kvprune_legacy_save/utils/kv_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e13120cc909c77fe28a61bfc1ea18ab928cdd8
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/kv_dist.py
@@ -0,0 +1,35 @@
+"""Distributed helpers for kvprune when embedded in vLLM (use TP process group)."""
+
+from __future__ import annotations
+
+import torch
+import torch.distributed as dist
+
+
+def broadcast_from_tp_rank0(
+    tensor: torch.Tensor, *, use_tp_group: bool
+) -> None:
+    """Broadcast ``tensor`` from rank 0 of the chosen process group.
+
+    ``use_tp_group=True`` routes through vLLM's tensor-parallel group (imported
+    lazily), keeping the collective off DP/PP ranks when the default group is
+    global. ``use_tp_group=False`` uses the default ``torch.distributed`` group,
+    which is the case for standalone compactor subprocesses where
+    world == tensor parallel size.
+    """
+    if use_tp_group:
+        from vllm.distributed.parallel_state import get_tp_group
+
+        get_tp_group().broadcast(tensor, src=0)
+    else:
+        dist.broadcast(tensor, src=0)
+
+
+def barrier_sync(*, use_tp_group: bool) -> None:
+    """Barrier over vLLM's TP group if ``use_tp_group``, else over the default group."""
+    if use_tp_group:
+        from vllm.distributed.parallel_state import get_tp_group
+
+        get_tp_group().barrier()
+    else:
+        dist.barrier()
diff --git a/vllm/kvprune_legacy_save/utils/layout_bridge.py b/vllm/kvprune_legacy_save/utils/layout_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..31321b2f7cf31db79880ecaef8e3601c8ff87662
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/layout_bridge.py
@@ -0,0 +1,167 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Bridge vLLM paged KV layout to compactor Triton kernels.
+
+vLLM FlashAttention KV cache is shaped
+ [num_blocks, block_size, num_kv_heads, head_dim].
+Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table
+ global_page_table[batch, kv_head, logical_page] -> physical_page_id
+where each physical page holds ``block_size`` consecutive rows belonging to that
+KV head only.
+
+When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows:
+ row_index = physical_block_id * block_size + offset_in_block.
+
+When ``num_kv_heads > 1``, we permute to head-major
+``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to
+``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies
+a disjoint row range in the flat buffer. The page table is built so each
+logical compression page maps to ``global_row // PAGE_SIZE`` in that layout
+(see ``build_page_table_head_major``).
+"""
+
+from __future__ import annotations
+
+import torch
+
+
+def _cdiv(n: int, d: int) -> int:
+    """Return ``n / d`` rounded up to an integer (intended for positive ``d``)."""
+    biased_numerator = n + d - 1
+    return biased_numerator // d
+
+
+def flatten_kv_cache_head_major(
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Lay out ``[nb, bs, H, D]`` caches as ``[H*nb*bs, D]`` in head-major order.
+
+    Each cache is permuted to ``[H, nb, bs, D]``, made contiguous (so the
+    result may not share storage with the input) and reshaped to 2-D.
+
+    Raises:
+        ValueError: if the two caches do not have the same shape.
+    """
+    if key_cache.shape != value_cache.shape:
+        raise ValueError("key_cache and value_cache must match")
+    num_blocks, block_size, num_heads, head_dim = key_cache.shape
+    total_rows = num_heads * num_blocks * block_size
+
+    def _head_major_rows(cache: torch.Tensor) -> torch.Tensor:
+        # Heads become the slowest-varying axis before flattening.
+        head_major = cache.permute(2, 0, 1, 3).contiguous()
+        return head_major.reshape(total_rows, head_dim)
+
+    return _head_major_rows(key_cache), _head_major_rows(value_cache)
+
+
+def write_head_major_flat_to_interleaved(
+    k_flat: torch.Tensor,
+    v_flat: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+) -> None:
+    """Copy ``[H*nb*bs, D]`` head-major flats back into ``[nb, bs, H, D]`` caches.
+
+    The target shape is taken from ``key_cache``; both caches are overwritten
+    in place and nothing is returned.
+    """
+    num_blocks, block_size, num_heads, head_dim = key_cache.shape
+    head_major_shape = (num_heads, num_blocks, block_size, head_dim)
+    # Build both views first so a shape mismatch fails before any write.
+    k_head_major, v_head_major = [
+        flat.view(head_major_shape) for flat in (k_flat, v_flat)
+    ]
+    key_cache.copy_(k_head_major.permute(1, 2, 0, 3))
+    value_cache.copy_(v_head_major.permute(1, 2, 0, 3))
+
+
+def build_page_table_head_major(
+    block_table: torch.Tensor,
+    num_kv_heads: int,
+    num_blocks: int,
+    block_size: int,
+    page_size: int,
+    max_batches: int,
+) -> torch.Tensor:
+    """Build the ``[max_batches, H, max_chain]`` page table for head-major flat KV.
+
+    For every (batch, head) pair the non-negative block ids of ``block_table``
+    are chained in table order; negative ids are skipped without leaving a
+    gap. Each chained entry is ``global_row // page_size`` where
+    ``global_row = head * num_blocks * block_size + block_id * block_size
+    + page_offset`` indexes the head-major flat buffer (see
+    ``flatten_kv_cache_head_major``). Slots past the end of a chain, and
+    batch rows past ``block_table.shape[0]``, stay zero.
+
+    The block table is transferred to the host once and validated once per
+    row; each batch row of the result is then written with a single slice
+    assignment instead of one scalar read/write per entry and per head.
+
+    Args:
+        block_table: 2-D tensor of block ids, one row per request.
+        num_kv_heads: Number of KV heads ``H`` in the flat buffer.
+        num_blocks: Number of blocks per head in the flat buffer.
+        block_size: Rows per block.
+        page_size: Rows per compactor page.
+        max_batches: Leading dimension of the returned table.
+
+    Returns:
+        An ``int32`` tensor on ``block_table.device``.
+
+    Raises:
+        ValueError: if ``block_table`` has more rows than ``max_batches`` or
+            contains a block id ``>= num_blocks``.
+    """
+    bsz, max_blocks = block_table.shape
+    if bsz > max_batches:
+        raise ValueError("batch size exceeds max_batches for page table")
+    num_pages_per_block = _cdiv(block_size, page_size)
+    max_chain = max_blocks * num_pages_per_block
+    out = torch.zeros(
+        (max_batches, num_kv_heads, max_chain),
+        dtype=torch.int32,
+        device=block_table.device,
+    )
+    if num_kv_heads == 0:
+        # Nothing to fill (and, as before, nothing is validated).
+        return out
+
+    rows_per_head = num_blocks * block_size
+    # Row offsets of the pages that start inside one block.
+    page_offsets = [
+        p * page_size
+        for p in range(num_pages_per_block)
+        if p * page_size < block_size
+    ]
+    # Single device-to-host copy instead of one ``.item()`` sync per entry.
+    host_table = block_table.to(torch.int64).tolist()
+
+    for b, block_ids in enumerate(host_table):
+        # Page start rows relative to the beginning of a head's row range.
+        rows_in_head = []
+        for blk_i, bid in enumerate(block_ids):
+            if bid < 0:
+                continue
+            if bid >= num_blocks:
+                raise ValueError(
+                    f"block_table[{b},{blk_i}]={bid} out of range "
+                    f"num_blocks={num_blocks}"
+                )
+            first_row = bid * block_size
+            rows_in_head.extend(first_row + off for off in page_offsets)
+        if not rows_in_head:
+            continue
+        chains = [
+            [(h * rows_per_head + row) // page_size for row in rows_in_head]
+            for h in range(num_kv_heads)
+        ]
+        out[b, :, : len(rows_in_head)] = torch.tensor(
+            chains, dtype=torch.int32, device=out.device
+        )
+    return out
+
+
+def flatten_kv_cache_plane(
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Flatten single-head ``[num_blocks, block_size, 1, D]`` caches to 2-D.
+
+    The result has shape ``[num_blocks * block_size, D]`` and is contiguous.
+    This row layout matches compactor indexing only for one KV head, which
+    is why any other head count is rejected.
+
+    Raises:
+        ValueError: if ``num_kv_heads`` is not 1, the cache shapes differ, or
+            the caches' head dimension is not 1.
+    """
+    if num_kv_heads != 1:
+        raise ValueError(
+            "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout"
+        )
+    if key_cache.shape != value_cache.shape:
+        raise ValueError("key_cache and value_cache must match")
+    num_blocks, block_size, heads_in_cache, head_dim = key_cache.shape
+    if heads_in_cache != 1:
+        raise ValueError("expected num_kv_heads==1")
+
+    flat_shape = (num_blocks * block_size, head_dim)
+    flattened = []
+    for cache in (key_cache, value_cache):
+        flat = cache.reshape(flat_shape)
+        flattened.append(flat if flat.is_contiguous() else flat.contiguous())
+    return flattened[0], flattened[1]
+
+
+def block_table_to_global_page_table(
+    block_table: torch.Tensor,
+    num_kv_heads: int,
+    max_batches: int,
+) -> torch.Tensor:
+    """Build a ``[max_batches, HKV, num_logical_pages]`` int32 page table.
+
+    Every KV head receives the same physical block ids as ``block_table``
+    (a ``[num_reqs_padded, max_num_blocks]`` tensor); batch rows beyond the
+    table's own row count are left at zero.
+
+    Raises:
+        ValueError: if ``block_table`` has more rows than ``max_batches``.
+    """
+    bsz, max_lp = block_table.shape
+    if bsz > max_batches:
+        raise ValueError("batch size exceeds max_batches for page table")
+    page_table = torch.zeros(
+        (max_batches, num_kv_heads, max_lp),
+        dtype=torch.int32,
+        device=block_table.device,
+    )
+    block_ids = block_table.to(torch.int32)[:bsz]
+    # Broadcast the per-request row across the head axis in one assignment.
+    page_table[:bsz] = block_ids[:, None, :]
+    return page_table
+
+
+def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor:
+    """Return the identity map from local batch index to global batch row.
+
+    The result is the int32 tensor ``[0, 1, ..., num_reqs - 1]`` on ``device``.
+    """
+    identity_rows = torch.arange(num_reqs, device=device, dtype=torch.int32)
+    return identity_rows
diff --git a/vllm/kvprune_legacy_save/utils/sequence.py b/vllm/kvprune_legacy_save/utils/sequence.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a7934cd1ca66a963a874c87a9eee1ccbf9cda3c
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/sequence.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from itertools import count
+from typing import List
+
+from vllm.kvprune.compression.compression_config import SequenceCompressionParams
+from vllm.kvprune.config.sampling_params import SamplingParams
+
+
+class SequenceStatus(Enum):
+    """Scheduling state recorded on a :class:`Sequence`."""
+
+    # Values are spelled out; they equal what ``auto()`` assigns in this order.
+    WAITING = 1
+    RUNNING = 2
+    FINISHED = 3
+
+
+@dataclass
+class Sequence:
+    """A single user request / sequence being generated.
+
+    Holds the prompt and completion token ids, the per-request sampling and
+    compression parameters, the scheduling status, an id drawn from a
+    class-level counter, and a running count of processed tokens.
+    """
+
+    # Shared id source. It carries no annotation, so it is not a dataclass field.
+    _counter = count()
+
+    prompt_token_ids: List[int]
+    completion_token_ids: List[int] = field(default_factory=list)
+    sampling_params: SamplingParams = field(default_factory=SamplingParams)
+    compression_params: SequenceCompressionParams = field(
+        default_factory=SequenceCompressionParams
+    )
+    status: SequenceStatus = SequenceStatus.WAITING
+
+    # Filled in automatically on construction; not an ``__init__`` argument.
+    seq_id: int = field(
+        init=False,
+        default_factory=lambda: next(Sequence._counter),
+    )
+    num_tokens_processed: int = 0
+
+    @property
+    def num_prompt_tokens(self) -> int:
+        """Number of token ids in the prompt."""
+        return len(self.prompt_token_ids)
+
+    @property
+    def num_generated_tokens(self) -> int:
+        """Number of completion token ids recorded so far."""
+        return len(self.completion_token_ids)
+
+    def add_new_token(self, token_id: int) -> None:
+        """Record one generated token and advance ``num_tokens_processed``.
+
+        The first generated token additionally credits the whole prompt to
+        ``num_tokens_processed``.
+        """
+        is_first_completion_token = self.num_generated_tokens == 0
+        if is_first_completion_token:
+            self.num_tokens_processed += self.num_prompt_tokens
+        self.completion_token_ids.append(token_id)
+        self.num_tokens_processed += 1
+
+    def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
+        """Return the per-layer retention budget, never less than 1.
+
+        The budget is ``compression_ratio * num_prompt_tokens * num_kv_heads``
+        truncated to an integer.
+        """
+        budget = int(
+            self.compression_params.compression_ratio
+            * self.num_prompt_tokens
+            * num_kv_heads
+        )
+        return budget if budget > 1 else 1
+
+    def __getstate__(self):
+        """Return the pickle state as a plain dict; token lists are copied."""
+        state = {
+            "prompt_token_ids": list(self.prompt_token_ids),
+            "completion_token_ids": list(self.completion_token_ids),
+        }
+        for name in (
+            "sampling_params",
+            "compression_params",
+            "status",
+            "seq_id",
+            "num_tokens_processed",
+        ):
+            state[name] = getattr(self, name)
+        return state
+
+    def __setstate__(self, state):
+        """Restore every attribute from ``state``; token lists are copied."""
+        for name in ("prompt_token_ids", "completion_token_ids"):
+            setattr(self, name, list(state[name]))
+        for name in (
+            "sampling_params",
+            "compression_params",
+            "status",
+            "seq_id",
+            "num_tokens_processed",
+        ):
+            setattr(self, name, state[name])
+
+    @property
+    def prompt_len(self) -> int:
+        """Alias of :attr:`num_prompt_tokens`."""
+        return self.num_prompt_tokens
+
+    @property
+    def completion_len(self) -> int:
+        """Alias of :attr:`num_generated_tokens`."""
+        return self.num_generated_tokens
diff --git a/vllm/kvprune_legacy_save/utils/tp_collectives.py b/vllm/kvprune_legacy_save/utils/tp_collectives.py
new file mode 100644
index 0000000000000000000000000000000000000000..855792aa8f524dcf94e8655a307d9fcc64721261
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/tp_collectives.py
@@ -0,0 +1,48 @@
+"""Tensor-parallel collectives for kvprune (match vLLM TP process group when embedded)."""
+
+from __future__ import annotations
+
+import torch
+import torch.distributed as dist
+
+
+def tensor_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
+    """All-reduce ``tensor`` across tensor-parallel ranks and return it.
+
+    ``tensor`` is always updated in place, so call sites that ignore the
+    return value still observe the reduced values.
+
+    * If ``torch.distributed`` is not initialized, or the world size is 1,
+      ``tensor`` is returned untouched.
+    * When vLLM's :mod:`vllm.distributed.parallel_state` reports that model
+      parallelism is initialized (kvprune running inside a vLLM GPU worker),
+      the reduction goes through
+      :func:`~vllm.distributed.communication_op.tensor_model_parallel_all_reduce`,
+      i.e. the same TP group as the main model. That call is
+      **out-of-place** and returns a new tensor; call sites such as
+      :class:`~vllm.kvprune.layers.linear.RowParallelLinear` historically
+      discarded the return value, which left per-rank partial sums in place
+      under TP>1 (wrong activations, logits and tokens). The reduced result
+      is therefore copied back into ``tensor``.
+    * Otherwise (standalone kvprune subprocesses, where the default process
+      group is the tensor-parallel group) the reduction falls back to
+      :func:`torch.distributed.all_reduce` on the default group.
+
+    Only a failure to import vLLM's parallel state selects the fallback. An
+    error raised by the TP collective itself propagates to the caller: it
+    must not be masked by a second all-reduce over the default group, which
+    may span non-TP ranks.
+    """
+    if not dist.is_initialized() or dist.get_world_size() <= 1:
+        return tensor
+
+    try:
+        from vllm.distributed.communication_op import (
+            tensor_model_parallel_all_reduce as vllm_tp_all_reduce,
+        )
+        from vllm.distributed.parallel_state import model_parallel_is_initialized
+    except ImportError:
+        # vLLM's parallel state is unavailable: standalone mode.
+        use_vllm_tp_group = False
+    else:
+        use_vllm_tp_group = model_parallel_is_initialized()
+
+    if not use_vllm_tp_group:
+        dist.all_reduce(tensor)
+        return tensor
+
+    reduced = vllm_tp_all_reduce(tensor)
+    if reduced is not tensor:
+        # Out-of-place result: materialize the cross-rank sum in the caller's
+        # buffer so RowParallel / VocabParallel outputs are not left partial.
+        tensor.copy_(reduced)
+    return tensor
diff --git a/vllm/kvprune_legacy_save/utils/tp_utils.py b/vllm/kvprune_legacy_save/utils/tp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e829f1f400ba5a162a9c6a1164f708fa229cfa9
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/tp_utils.py
@@ -0,0 +1,40 @@
+"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""
+
+from __future__ import annotations
+
+import torch.distributed as dist
+
+
+def tensor_parallel_rank_for_sharding() -> int:
+    """Return this process's rank within the tensor-parallel group.
+
+    vLLM's tensor-parallel rank is used when it can be obtained (so the value
+    matches vLLM weight shards when embedded). If importing or calling the
+    vLLM getter raises, the rank of the default ``torch.distributed`` group
+    is returned instead, or ``0`` when ``torch.distributed`` is not
+    initialized.
+    """
+    try:
+        from vllm.distributed.parallel_state import (
+            get_tensor_model_parallel_rank as _vllm_tp_rank,
+        )
+
+        rank = int(_vllm_tp_rank())
+    except Exception:
+        # vLLM parallel state unavailable: standalone kvprune.
+        rank = int(dist.get_rank()) if dist.is_initialized() else 0
+    return rank
+
+
+def tensor_parallel_world_size_for_sharding() -> int:
+    """Return the world size of the tensor-parallel group.
+
+    vLLM's tensor-parallel world size is used when it can be obtained. If
+    importing or calling the vLLM getter raises, the default
+    ``torch.distributed`` world size is returned instead, or ``1`` when
+    ``torch.distributed`` is not initialized.
+    """
+    try:
+        from vllm.distributed.parallel_state import (
+            get_tensor_model_parallel_world_size as _vllm_tp_world_size,
+        )
+
+        world_size = int(_vllm_tp_world_size())
+    except Exception:
+        # vLLM parallel state unavailable: standalone kvprune.
+        world_size = int(dist.get_world_size()) if dist.is_initialized() else 1
+    return world_size
+
+
+def kv_heads_shard_divisor() -> int:
+    """Return the world size used to shard KV heads.
+
+    This is exactly :func:`tensor_parallel_world_size_for_sharding` (the TP
+    group size when vLLM is loaded).
+    """
+    divisor = tensor_parallel_world_size_for_sharding()
+    return divisor
diff --git a/vllm/kvprune_legacy_save/utils/triton_compat.py b/vllm/kvprune_legacy_save/utils/triton_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..65a459c0bddeaf38d594177abc2e0bfb07533b8e
--- /dev/null
+++ b/vllm/kvprune_legacy_save/utils/triton_compat.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import inspect
+from typing import Any, Callable, Mapping
+
+import torch
+
+
+def _filter_kwargs_for_callable(
+    fn: Callable[..., Any], kwargs: Mapping[str, Any]
+) -> dict[str, Any]:
+    """Return the subset of ``kwargs`` that ``fn`` can receive by keyword.
+
+    * If the signature of ``fn`` cannot be inspected, nothing is filtered.
+    * If ``fn`` declares ``**kwargs`` it accepts any keyword, so nothing is
+      filtered either.
+    * Otherwise only names of positional-or-keyword and keyword-only
+      parameters are kept. Names that merely coincide with a positional-only
+      or ``*args`` parameter are dropped, because passing them by keyword
+      would raise ``TypeError``.
+
+    A new ``dict`` is always returned; ``kwargs`` is never mutated.
+    """
+    try:
+        params = inspect.signature(fn).parameters
+    except (TypeError, ValueError):
+        return dict(kwargs)
+
+    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
+        return dict(kwargs)
+
+    keyword_kinds = (
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.KEYWORD_ONLY,
+    )
+    accepted = {name for name, p in params.items() if p.kind in keyword_kinds}
+    return {k: v for k, v in kwargs.items() if k in accepted}
+
+
+def autotune(*, configs, key, **kwargs):
+    """
+    Compatibility wrapper around `triton.autotune`.
+
+    Some Triton builds (e.g., custom vendor builds) may not support newer
+    keyword arguments such as `cache_results`. Extra keyword arguments are
+    therefore filtered against the signature of the `triton.autotune` found
+    at call time before being forwarded; `configs` and `key` are always
+    passed through.
+    """
+    import triton
+
+    tuner = triton.autotune
+    supported_kwargs = _filter_kwargs_for_callable(tuner, kwargs)
+    return tuner(configs=configs, key=key, **supported_kwargs)
+
+
+def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
+    """
+    Install `alloc_fn` through `triton.set_allocator` when Triton provides it.
+
+    Returns True if the allocator was set, False if this Triton build has no
+    `set_allocator` (in which case nothing happens).
+    """
+    import triton
+
+    install = getattr(triton, "set_allocator", None)
+    if install is not None:
+        install(alloc_fn)
+        return True
+    return False
+
+
+def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
+    """
+    Host-side check that a CUDA device's capability is at least (major, minor).
+
+    Works even when `tl.target_info` is absent. Returns False when CUDA is
+    unavailable. With `device=None` the current device is queried, falling
+    back to device 0 if the current device cannot be determined.
+    """
+    if not torch.cuda.is_available():
+        return False
+    target = device
+    if target is None:
+        try:
+            target = torch.cuda.current_device()
+        except Exception:
+            target = 0
+    # Tuples compare lexicographically: major first, then minor.
+    return torch.cuda.get_device_capability(target) >= (major, minor)
+
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index b5fa071df2f77fe885da2a5e18d121de24ab4d4d..d3072127b693d07a68728bbaf323edbee3231e66 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -660,8 +660,16 @@ class RocmPlatform(Platform):
):
compilation_config.custom_ops.append("+rotary_embedding")
- # Default dispatch to rocm's sparse_attn_indexer implementation
- compilation_config.custom_ops.append("+sparse_attn_indexer")
+ # Only DeepSeek MLA models use sparse_attn_indexer. Avoid enabling it
+ # for unrelated models (e.g. Qwen3), which only produces a
+ # "has no effect" warning during compilation checks.
+ model_type = getattr(vllm_config.model_config.hf_config, "model_type", None)
+ if (
+ model_type in {"deepseek_v2", "deepseek_v3"}
+ and "+sparse_attn_indexer" not in compilation_config.custom_ops
+ and "-sparse_attn_indexer" not in compilation_config.custom_ops
+ ):
+ compilation_config.custom_ops.append("+sparse_attn_indexer")
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 33f7090e4e3d2a30609ba2ca2ca16d2c8b98dd5a..fb49241e54971b4793505d1f5665a7a33f81ea7f 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -248,7 +248,12 @@ def apply_top_k_top_p(
if p is None and k is None:
return logits
- if HAS_TRITON and logits.shape[0] >= 8:
+ # Triton top-k/top-p can VMFault on some ROCm stacks; PyTorch path is stable.
+ if (
+ HAS_TRITON
+ and logits.shape[0] >= 8
+ and not current_platform.is_rocm()
+ ):
return apply_top_k_top_p_triton(logits, k, p)
# Use pytorch sort implementation for small batch sizes.
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 6d117175b1ad8d017ed51ec006c772321c1d421b..eb47a4ac24d04c47e628d56e244e2447e116c11c 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -719,6 +719,13 @@ class Worker(WorkerBase):
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
+ def kvprune_v1_compressed_generate(self, payload: dict[str, Any]) -> dict[str, Any]:
+     """KV-prune compactor path for tensor parallel size > 1 (all ranks collective).
+
+     Delegates to ``run_kvprune_tp_compressed_generate`` with this worker and
+     ``payload`` and returns its result unchanged.
+     """
+     # Imported on first use so loading the worker module does not pull in
+     # the kvprune integration.
+     from vllm.kvprune.integration.v1_tp_runner import (
+         run_kvprune_tp_compressed_generate as _run_compressed_generate,
+     )
+
+     result = _run_compressed_generate(self, payload)
+     return result
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
diff --git a/vllm/version.py b/vllm/version.py
index 63095f8bce1ea4fe4800deaac7b2c7f05b38ff11..cfe38b4a19746acb834f9d2adac37f490a5e1287 100644
--- a/vllm/version.py
+++ b/vllm/version.py
@@ -1,19 +1,22 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
- from ._version import __version__, __version_tuple__
+ __version__ = "0.18.1"
+ __version_tuple__ = (0, 18, 1)
+ __hcu_version__ = f'0.18.1+das.dtk2604'
+
+ from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
- warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2)
-
+ # Include the caught exception in the warning text.
+ warnings.warn(f"Failed to read commit hash:\n{e}",
+ RuntimeWarning,
+ stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
-
-
+
+
def _prev_minor_version_was(version_str):
- """Check whether a given version matches the previous minor version.
+ '''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
- """
+ '''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
- assert __version_tuple__[0] == 0
+ # assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
- """For the purpose of testing, return a previous minor version number."""
+ '''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"