Commit 2b7160c6 authored by chenzk

vllm kvprune:v1.0.0

parent fa718036
......@@ -21,59 +21,149 @@ For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
## About
vLLM is a fast and easy-to-use library for LLM inference and serving.
This build adds KV cache pruning, a model compression feature, on top of official vLLM.
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
vLLM prunes the KV cache with:
vLLM is fast with:
- [**SNAPKV**](https://arxiv.org/pdf/2404.14469)
- [**COMPACTOR**](https://arxiv.org/pdf/2507.08143)
- [**CRITICALADAKV**](https://arxiv.org/pdf/2502.03805)
- State-of-the-art serving throughput
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
- Speculative decoding
- Chunked prefill
vLLM is flexible and easy to use with:
- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor, pipeline, data and expert parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
- Prefix caching support
- Multi-LoRA support
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
- Embedding Models (e.g., E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
- Transformer-like LLMs (e.g., Qwen3/Llama)
## Getting Started
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
## Env
```bash
pip install vllm
cd vllm
python use_existing_torch.py
# then add "torch" to the `requires` list in pyproject.toml
export SETUPTOOLS_SCM_PRETEND_VERSION_FOR_VLLM="0.6.0"
pip install -e . --no-build-isolation -v -i https://mirrors.aliyun.com/pypi/simple/
pip install numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/
```
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
More related libraries:
- flash_attn-2.8.3+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
- torchvision-0.24.0+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
- triton-3.5.1+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
## Quick Start
Basic Chat Generation with Compression:
```bash
python test.py --schedule pdtriton
```
test.py:
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# PYTHONPATH=/home/vllm-project/vllm python test.py --schedule pdtriton
from __future__ import annotations

import argparse
import os
import sys
from multiprocessing import freeze_support


def _apply_kvprune_attention_env(schedule: str | None) -> None:
    """Map CLI -> VLLM_KVPRUNE_ATTENTION_SCHEDULE (fa_triton | pdtriton | pdfa)."""
    if not schedule:
        return
    os.environ["VLLM_KVPRUNE_ATTENTION_SCHEDULE"] = schedule


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--schedule",
        type=str,
        default="pdtriton",
        choices=("fa_triton", "pdtriton", "pdfa"),
        help=(
            # Trailing spaces keep the concatenated help text readable.
            "fa_triton=FA prefill + Triton decode; "
            "pdtriton=Triton prefill + Triton decode; "
            "pdfa=FA prefill + FA decode (page KV writing is Triton)"
        ),
    )
    args, _unknown = parser.parse_known_args()
    _apply_kvprune_attention_env(args.schedule)

    # Imported only after the schedule environment variable has been set.
    from transformers import AutoTokenizer

    from vllm import CompressionParams, LLM, SamplingParams

    model_id = "Qwen/Qwen3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        repetition_penalty=1.05,
        max_tokens=512,
    )
    llm = LLM(
        model=model_id,
        tensor_parallel_size=4,
        max_model_len=8192,
        gpu_memory_utilization=0.85,
        kvprune_compression=True,
    )
    prompt = (
        "Write a 200-word English prompt for a creative writing task. The prompt should be "
        "a single coherent paragraph without any bullet points, numbered lists, or markdown "
        "formatting. It should describe a specific scenario, character, or conflict, and end "
        "with a clear question that invites the writer to continue the story. Do not use any "
        "special symbols or line breaks. The tone can be mysterious, tense, or reflective. "
        "After the paragraph, include the question on the same line directly following the "
        "period, without hitting enter."
    )
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    # One CompressionParams entry per prompt.
    compression = [
        CompressionParams(
            compression_ratio=0.5,
            compression_method="snapkv",
        ),
    ]
    outputs = llm.generate(
        [text],
        sampling_params=sampling_params,
        compression=compression,
    )
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


if __name__ == "__main__":
    freeze_support()
    main()
```
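The `test.py` listing above requests `compression_method="snapkv"` with `compression_ratio=0.5`, but this commit view does not show how the fork selects which KV entries survive. As a rough illustration only, the sketch below follows the general SnapKV idea: score each prefix position by the attention it receives from a trailing observation window, then keep the top-scoring positions plus the window itself. The function name, the `window` parameter, and the reading of `compression_ratio` as "fraction of positions kept" are assumptions, and the pooling step from the SnapKV paper is omitted for brevity.

```python
from __future__ import annotations


def snapkv_keep_indices(
    attn: list[list[float]],
    compression_ratio: float,
    window: int,
) -> list[int]:
    """Return the sorted key positions to keep for a single attention head.

    attn[i][j] is the attention weight from the i-th query of the trailing
    observation window to key position j (hypothetical input layout).
    """
    seq_len = len(attn[0])
    prefix_len = seq_len - window
    # Assumption: the ratio is the fraction of positions kept, and the
    # observation window is always retained.
    budget = max(window, int(seq_len * compression_ratio))
    # Each window query "votes" for the prefix positions it attends to.
    scores = [sum(row[j] for row in attn) for j in range(prefix_len)]
    n_prefix = min(prefix_len, budget - window)
    top = sorted(range(prefix_len), key=lambda j: scores[j], reverse=True)[:n_prefix]
    return sorted(top) + list(range(prefix_len, seq_len))


if __name__ == "__main__":
    rows = [
        [0.05, 0.30, 0.05, 0.05, 0.25, 0.05, 0.15, 0.10],
        [0.05, 0.25, 0.05, 0.05, 0.30, 0.05, 0.10, 0.15],
    ]
    print(snapkv_keep_indices(rows, compression_ratio=0.5, window=2))
```

With eight positions, a window of two, and a ratio of 0.5, the sketch keeps the two most-attended prefix positions and the two window positions.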
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
## Contributing
We welcome and value any contributions and collaborations.
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
## Citation
......
<!-- markdownlint-disable MD001 MD041 -->
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png">
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-light.png" width=55%>
</picture>
</p>
<h3 align="center">
Easy, fast, and cheap LLM serving for everyone
</h3>
<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
</p>
🔥 We have built a vllm website to help you get started with vllm. Please visit [vllm.ai](https://vllm.ai) to learn more.
For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
---
## About
vLLM is a fast and easy-to-use library for LLM inference and serving.
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
vLLM is fast with:
- State-of-the-art serving throughput
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
- Speculative decoding
- Chunked prefill
vLLM is flexible and easy to use with:
- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor, pipeline, data and expert parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
- Prefix caching support
- Multi-LoRA support
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
- Embedding Models (e.g., E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
## Getting Started
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
```bash
pip install vllm
```
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
## Contributing
We welcome and value any contributions and collaborations.
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
## Citation
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
```bibtex
@inproceedings{kwon2023efficient,
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
year={2023}
}
```
## Contact Us
<!-- --8<-- [start:contact-us] -->
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
- For collaborations and partnerships, please contact us at [collaboration@vllm.ai](mailto:collaboration@vllm.ai)
<!-- --8<-- [end:contact-us] -->
## Media Kit
- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)
......@@ -8,6 +8,16 @@
#include "cuda_vec_utils.cuh"
#include "dispatch_utils.h"
// ROCm/HIP often assumes at most 256 threads per block unless the kernel
// declares otherwise; launching more triggers runtime warnings / UB. NVIDIA
// CUDA builds keep the original 1024 cap. No __launch_bounds__ on templated
// kernels here — HIP/clang can fail to compile those (see act_and_mul_kernel).
#ifdef USE_ROCM
#define VLLM_ACTIVATION_GATE_MAX_THREADS 256
#else
#define VLLM_ACTIVATION_GATE_MAX_THREADS 1024
#endif
namespace vllm {
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
......@@ -170,7 +180,7 @@ packed_gelu_tanh_kernel(const packed_t& val) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
dim3 block(std::min(d / vec_size, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
......@@ -191,7 +201,7 @@ packed_gelu_tanh_kernel(const packed_t& val) {
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
......@@ -387,7 +397,7 @@ __global__ void swigluoai_and_mul_kernel(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
dim3 block(std::min(d / vec_size, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
......@@ -414,7 +424,7 @@ __global__ void swigluoai_and_mul_kernel(
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
......@@ -429,7 +439,7 @@ __global__ void swigluoai_and_mul_kernel(
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
......@@ -520,7 +530,7 @@ __global__ void activation_kernel(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
dim3 block(std::min(d / vec_size, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
......@@ -535,7 +545,7 @@ __global__ void activation_kernel(
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, VLLM_ACTIVATION_GATE_MAX_THREADS)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
......
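The hunks above replace the hard-coded `1024` with `VLLM_ACTIVATION_GATE_MAX_THREADS`, so the block size becomes the smaller of the per-token work items and a platform cap. The arithmetic can be checked with a small Python sketch. Here `is_rocm` is a stand-in for the compile-time `USE_ROCM` macro, and the example values of `d` and `vec_size` are illustrative assumptions, not values taken from the kernels.

```python
def activation_block_threads(d: int, vec_size: int, use_vec: bool, is_rocm: bool) -> int:
    """Mirror `std::min(d / vec_size, CAP)` and `std::min(d, CAP)` from the diff."""
    # CAP corresponds to VLLM_ACTIVATION_GATE_MAX_THREADS: 256 on ROCm, else 1024.
    max_threads = 256 if is_rocm else 1024
    # The vectorized path launches one thread per packed vector of elements.
    work_items = d // vec_size if use_vec else d
    return min(work_items, max_threads)


if __name__ == "__main__":
    for rocm in (True, False):
        print(rocm, activation_block_threads(12288, 8, use_vec=True, is_rocm=rocm))
```

Small hidden sizes are unaffected by the change; only launches that previously exceeded 256 threads per block shrink on ROCm.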
......@@ -74,7 +74,10 @@ inline __device__ void apply_rotary_embedding(
}
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
// HIP/ROCm commonly enforces max 256 threads/block unless explicitly raised.
// Keep blockDim.x <= 256 here so launch always matches compiler bounds (avoids
// UB when __launch_bounds__ and host code get out of sync across rebuilds).
__global__ __launch_bounds__(256, 1) void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
......@@ -162,7 +165,7 @@ void rotary_embedding(
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 256));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
......
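Lowering the rotary-embedding block cap from 512 to 256 keeps the launch within the new `__launch_bounds__(256, 1)` annotation. The kernel body is not part of this diff; assuming it walks its `num_heads * rot_dim / 2` work items with a block-stride loop, a smaller block means more loop iterations per thread rather than skipped work. The helper below is a hypothetical sketch of that arithmetic.

```python
def rotary_launch(num_heads: int, rot_dim: int, max_threads: int = 256) -> tuple[int, int]:
    """Return (threads per block, loop iterations per thread) for one token."""
    work_items = num_heads * rot_dim // 2
    if work_items <= 0:
        raise ValueError("num_heads * rot_dim / 2 must be positive")
    block = min(work_items, max_threads)
    # Ceiling division: assumes a block-stride loop covers the remaining items.
    iterations = -(-work_items // block)
    return block, iterations


if __name__ == "__main__":
    print(rotary_launch(32, 128))        # new cap of 256
    print(rotary_launch(32, 128, 512))   # previous cap of 512
```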
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
python -P test-vllmkvprune.py
"""
from __future__ import annotations

import os
import sys
from multiprocessing import freeze_support

from transformers import AutoTokenizer

from vllm import CompressionParams, LLM, SamplingParams


def main() -> None:
    model_id = "Qwen/Qwen3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        repetition_penalty=1.05,
        max_tokens=512,
    )
    # TP=1: the compactor shares weights in-process. For TP>1, use
    # ``qwen3_kvprune_tp4_inference.py``, or set ``tensor_parallel_size>=2``
    # yourself and pass ``compression`` (this goes through collective_rpc plus a
    # per-GPU ModelRunner).
    llm = LLM(
        model=model_id,
        tensor_parallel_size=1,
        max_model_len=8192,
        gpu_memory_utilization=0.85,
        kvprune_compression=True,
    )
    prompt = "Give me a short introduction to large language models."
    messages = [
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    # Pruning: when compression_ratio < 1 the compactor path is used; each
    # prompt takes one CompressionParams entry.
    compression = [
        CompressionParams(
            compression_ratio=0.5,
            compression_method="compactor",
        ),
    ]
    outputs = llm.generate(
        [text],
        sampling_params=sampling_params,
        compression=compression,
    )
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


if __name__ == "__main__":
    freeze_support()
    main()
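The script above passes one `CompressionParams` entry for its single prompt. For a batch, the same pairing can be built programmatically. The sketch below uses a hypothetical `CompressionSpec` dataclass as a stand-in for `vllm.CompressionParams` so that it runs without the fork installed; the accepted ratio range `(0, 1]` is an assumption, not something this commit documents.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class CompressionSpec:
    """Hypothetical stand-in for vllm.CompressionParams (same two field names)."""

    compression_ratio: float
    compression_method: str


def per_prompt_compression(
    prompts: list[str], ratio: float, method: str
) -> list[CompressionSpec]:
    """Build one compression entry per prompt, in prompt order."""
    # Assumed valid range; the fork may accept other values.
    if not 0.0 < ratio <= 1.0:
        raise ValueError(f"compression_ratio must be in (0, 1], got {ratio}")
    return [CompressionSpec(ratio, method) for _ in prompts]


if __name__ == "__main__":
    batch = ["first prompt", "second prompt"]
    specs = per_prompt_compression(batch, 0.5, "compactor")
    print(len(specs) == len(batch), specs[0])
```

Keeping the list the same length as the prompt list makes a mismatch easy to catch before calling `llm.generate`.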
......@@ -6,7 +6,7 @@ requires = [
"packaging>=24.2",
"setuptools>=77.0.3,<81.0.0",
"setuptools-scm>=8.0",
"torch == 2.10.0",
"torch",
"wheel",
"jinja2",
]
......@@ -171,4 +171,3 @@ ser = "ser"
ure = "ure"
[tool.uv]
no-build-isolation-package = ["torch"]
......@@ -4,7 +4,6 @@ ninja
packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools-scm>=8
torch==2.10.0
wheel
jinja2>=3.1.6
regex
......
......@@ -3,8 +3,6 @@ ninja
packaging>=24.2
setuptools==77.0.3 # this version can reuse CMake build dir
setuptools-scm>=8
torch==2.10.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.10.0; platform_machine == "aarch64" or platform_system == "Darwin" or platform_machine == "ppc64le"
wheel
jinja2>=3.1.6
regex
......@@ -6,16 +6,9 @@ setuptools==77.0.3 # this version can reuse CMake build dir
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
# Dependencies for CPUs
torch==2.10.0+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.10.0; platform_machine == "aarch64" or platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "riscv64"
# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
torchaudio; platform_machine != "s390x" and platform_machine != "riscv64"
# required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "s390x" and platform_machine != "riscv64"
# Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64"
# Use this to gather CPU info and optimize based on ARM Neoverse cores
......
......@@ -4,10 +4,6 @@
numba == 0.61.2 # Required for N-gram speculative decoding
# Dependencies for NVIDIA GPUs
torch==2.10.0
torchaudio==2.10.0
# These must be updated alongside torch
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.6
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
......
# Common dependencies
-r common.txt
--extra-index-url https://download.pytorch.org/whl/rocm7.1
torch==2.10.0
torchvision==0.25.0
torchaudio==2.10.0
triton==3.6.0
cmake>=3.26.1,<4
packaging>=24.2
......
......@@ -69,8 +69,6 @@ multiprocess==0.70.16
# Required for v1/metrics/test_engine_logger_apis.py
ray[cgraph,default]>=2.48.0
torchgeo==0.7.0
# via terratorch
# MTEB Benchmark Test
mteb[bm25s]>=2, <3
......@@ -85,7 +83,6 @@ fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.
# Required for suffix decoding test
arctic-inference == 0.1.1
# Required for Nemotron test
open-clip-torch==2.32.0
# Required for isaac Multi-Modal generation test
perceptron==0.1.4
# Required for the multi-modal models test
......@@ -99,9 +96,7 @@ huggingface-hub==0.36.2
# Pin Mistral Common
mistral-common[image,audio]==1.10.0
# Required for Prithvi tests
terratorch==1.2.2
# Required for Prithvi tests
segmentation-models-pytorch==0.5.0
# Required for Prithvi tests
imagehash==4.3.2
# Required for bitsandbytes quantization test
......
......@@ -23,6 +23,4 @@ timm>=1.0.17
amd-quark>=0.8.99
# Other necessary dependencies
torch == 2.10.0
torchvision == 0.25.0
flash_attn == 2.8.3
\ No newline at end of file
......@@ -16,7 +16,6 @@ blobfile # required for kimi-vl test
einops # required for MPT, qwen-vl
httpx
librosa # required for audio tests
vector_quantize_pytorch # required for minicpmo_26 test
vocos # required for minicpmo_26 test
peft>=0.15.0 # required for phi-4-mm test
pqdm
......@@ -26,14 +25,10 @@ soundfile # required for audio tests
jiwer # required for audio tests
tblib # for pickling test exceptions
timm >=1.0.17 # required for internvl and gemma3n-mm test
torch==2.10.0
torchaudio==2.10.0
torchvision==0.25.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.1 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]>=0.4.11 # required for model evaluation test
......@@ -61,17 +56,13 @@ fastsafetensors>=0.2.2 # 0.2.2 contains important fixes for multi-GPU mem usage
instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0
terratorch >= 1.2.2 # Required for Prithvi tests
imagehash # Required for Prithvi tests
segmentation-models-pytorch > 0.4.0 # Required for Prithvi tests
gpt-oss >= 0.0.7; python_version > '3.11'
perceptron # required for isaac test
kaldi-native-fbank >= 1.18.7 # required for fireredasr2 test
# Newer versions of datasets require torchcoded, that makes the tests fail in CI because of a missing library.
# Older versions are in conflict with teerratorch requirements.
datasets>=3.3.0,<=3.6.0
openpyxl # required for perf comparison excel report
......
# This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 --python-version 3.12
absl-py==2.1.0
# via
# rouge-score
......@@ -25,11 +24,9 @@ aiohttp-cors==0.8.1
aiosignal==1.4.0
# via aiohttp
albucore==0.0.16
# via terratorch
albumentations==1.4.6
# via
# -r requirements/test.in
# terratorch
alembic==1.16.4
# via optuna
annotated-doc==0.0.4
......@@ -165,7 +162,6 @@ cryptography==46.0.5
# msal
# pyjwt
cuda-bindings==12.9.4
# via torch
cuda-pathfinder==1.3.3
# via cuda-bindings
cupy-cuda12x==13.6.0
......@@ -189,7 +185,6 @@ decorator==5.1.1
decord==0.6.0
# via -r requirements/test.in
diffusers==0.36.0
# via terratorch
dill==0.3.8
# via
# datasets
......@@ -210,12 +205,8 @@ einops==0.8.1
# via
# -r requirements/test.in
# encodec
# terratorch
# torchgeo
# vector-quantize-pytorch
# vocos
einx==0.3.0
# via vector-quantize-pytorch
email-validator==2.2.0
# via pydantic
encodec==0.1.1
......@@ -239,11 +230,9 @@ filelock==3.16.1
# diffusers
# huggingface-hub
# ray
# torch
# transformers
# virtualenv
fiona==1.10.1
# via torchgeo
fonttools==4.55.0
# via matplotlib
fqdn==1.5.1
......@@ -261,17 +250,13 @@ fsspec==2024.12.0
# fastparquet
# huggingface-hub
# lightning
# pytorch-lightning
# tacoreader
# torch
ftfy==6.3.1
# via open-clip-torch
genai-perf==0.0.16
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
geopandas==1.0.1
# via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
......@@ -320,7 +305,6 @@ h11==0.14.0
h2==4.3.0
# via httpx
h5py==3.13.0
# via terratorch
harfile==0.3.0
# via schemathesis
hf-xet==1.1.7
......@@ -345,11 +329,8 @@ huggingface-hub==0.36.2
# datasets
# diffusers
# evaluate
# open-clip-torch
# peft
# segmentation-models-pytorch
# sentence-transformers
# terratorch
# timm
# tokenizers
# transformers
......@@ -406,7 +387,6 @@ jinja2==3.1.6
# datamodel-code-generator
# genai-perf
# lm-eval
# torch
jiwer==3.0.5
# via -r requirements/test.in
jmespath==1.0.1
......@@ -421,7 +401,6 @@ joblib==1.4.2
jsonargparse==4.46.0
# via
# lightning
# terratorch
jsonlines==4.0.0
# via lm-eval
jsonnet==0.21.0
......@@ -445,7 +424,6 @@ kaleido==0.2.1
kiwisolver==1.4.7
# via matplotlib
kornia==0.8.1
# via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4
......@@ -458,19 +436,13 @@ librosa==0.10.2.post1
# via -r requirements/test.in
lightly==1.5.22
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.6.1
# via
# terratorch
# torchgeo
lightning-utilities==0.14.3
# via
# lightning
# pytorch-lightning
# torchmetrics
llvmlite==0.44.0
# via numba
lm-eval==0.4.11
......@@ -496,7 +468,6 @@ matplotlib==3.9.2
# -r requirements/test.in
# lightning
# pycocotools
# torchgeo
mbstrdecoder==1.1.3
# via
# dataproperty
......@@ -535,7 +506,6 @@ mypy-extensions==1.0.0
networkx==3.2.1
# via
# scikit-image
# torch
nltk==3.9.1
# via rouge-score
num2words==0.5.14
......@@ -591,18 +561,13 @@ numpy==2.2.6
# scikit-image
# scikit-learn
# scipy
# segmentation-models-pytorch
# shapely
# soxr
# statsmodels
# tensorboard
# tensorboardx
# tensorizer
# terratorch
# tifffile
# torchgeo
# torchmetrics
# torchvision
# transformers
# tritonclient
# vocos
......@@ -611,46 +576,30 @@ nvidia-cublas-cu12==12.9.1.4
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.9.79
# via torch
nvidia-cuda-nvrtc-cu12==12.9.86
# via torch
nvidia-cuda-runtime-cu12==12.9.79
# via torch
nvidia-cudnn-cu12==9.10.2.21
# via torch
nvidia-cufft-cu12==11.4.1.4
# via torch
nvidia-cufile-cu12==1.14.1.1
# via torch
nvidia-curand-cu12==10.3.10.19
# via torch
nvidia-cusolver-cu12==11.7.5.82
# via torch
nvidia-cusparse-cu12==12.5.10.65
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.7.1
# via torch
nvidia-nccl-cu12==2.27.5
# via torch
nvidia-nvjitlink-cu12==12.9.86
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvshmem-cu12==3.4.5
# via torch
nvidia-nvtx-cu12==12.9.79
# via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
open-clip-torch==2.32.0
# via -r requirements/test.in
openai-harmony==0.0.4
# via gpt-oss
......@@ -709,14 +658,12 @@ packaging==24.2
# pyogrio
# pytest
# pytest-rerunfailures
# pytorch-lightning
# ray
# rioxarray
# scikit-image
# statsmodels
# tensorboard
# tensorboardx
# torchmetrics
# transformers
# typepy
# wandb
......@@ -730,7 +677,6 @@ pandas==2.2.3
# geopandas
# statsmodels
# tacoreader
# torchgeo
# xarray
pathspec==0.12.1
# via black
......@@ -755,10 +701,7 @@ pillow==10.4.0
# mistral-common
# perceptron
# scikit-image
# segmentation-models-pytorch
# tensorboard
# torchgeo
# torchvision
platformdirs==4.3.6
# via
# black
......@@ -817,7 +760,6 @@ pyarrow==23.0.0
# datasets
# genai-perf
# tacoreader
# terratorch
pyasn1==0.6.1
# via
# pyasn1-modules
......@@ -825,7 +767,6 @@ pyasn1==0.6.1
pyasn1-modules==0.4.2
# via google-auth
pycocotools==2.0.8
# via terratorch
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
......@@ -864,7 +805,6 @@ pyproj==3.7.1
# via
# geopandas
# rioxarray
# torchgeo
pyrate-limiter==3.7.0
# via schemathesis
pystemmer==3.0.0
......@@ -902,7 +842,6 @@ pytest-subtests==0.14.1
pytest-timeout==2.3.1
# via -r requirements/test.in
python-box==7.3.2
# via terratorch
python-dateutil==2.9.0.post0
# via
# arrow
......@@ -913,7 +852,6 @@ python-dateutil==2.9.0.post0
# typepy
python-rapidjson==1.20
# via tritonclient
pytorch-lightning==2.5.2
# via
# lightly
# lightning
......@@ -938,7 +876,6 @@ pyyaml==6.0.2
# omegaconf
# optuna
# peft
# pytorch-lightning
# ray
# responses
# schemathesis
......@@ -951,8 +888,6 @@ rapidfuzz==3.12.1
rasterio==1.4.3
# via
# rioxarray
# terratorch
# torchgeo
ray==2.48.0
# via -r requirements/test.in
redis==5.2.0
......@@ -965,7 +900,6 @@ regex==2024.9.11
# via
# diffusers
# nltk
# open-clip-torch
# sacrebleu
# tiktoken
# transformers
......@@ -1007,10 +941,8 @@ rich==13.9.4
# lightning
# mteb
# perceptron
# terratorch
# typer
rioxarray==0.19.0
# via terratorch
rouge-score==0.1.2
# via lm-eval
rpds-py==0.20.1
......@@ -1020,7 +952,6 @@ rpds-py==0.20.1
rsa==4.9.1
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.15.7
# via -r requirements/test.in
runai-model-streamer-azure==0.15.7
......@@ -1037,9 +968,7 @@ safetensors==0.4.5
# via
# accelerate
# diffusers
# open-clip-torch
# peft
# segmentation-models-pytorch
# timm
# transformers
schemathesis==3.39.15
......@@ -1047,7 +976,6 @@ schemathesis==3.39.15
scikit-image==0.25.2
# via
# albumentations
# terratorch
scikit-learn==1.5.2
# via
# albumentations
......@@ -1055,7 +983,6 @@ scikit-learn==1.5.2
# lm-eval
# mteb
# sentence-transformers
# terratorch
scipy==1.13.1
# via
# albumentations
......@@ -1068,11 +995,8 @@ scipy==1.13.1
# sentence-transformers
# statsmodels
# vocos
segmentation-models-pytorch==0.5.0
# via
# -r requirements/test.in
# terratorch
# torchgeo
sentence-transformers==5.2.0
# via
# -r requirements/test.in
......@@ -1084,11 +1008,9 @@ setuptools==77.0.3
# lightning-utilities
# pytablewriter
# tensorboard
# torch
shapely==2.1.1
# via
# geopandas
# torchgeo
shellingham==1.5.4
# via
# perceptron
......@@ -1141,13 +1063,11 @@ structlog==25.4.0
sympy==1.13.3
# via
# einx
# torch
tabledata==1.3.3
# via pytablewriter
tabulate==0.9.0
# via sacrebleu
tacoreader==0.5.6
# via terratorch
tblib==3.1.0
# via -r requirements/test.in
tcolorpy==0.1.6
......@@ -1158,7 +1078,6 @@ tenacity==9.1.2
# lm-eval
# plotly
tensorboard==2.20.0
# via terratorch
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
......@@ -1168,15 +1087,12 @@ tensorizer==2.10.1
termcolor==3.1.0
# via
# gpt-oss
# terratorch
terratorch==1.2.2
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
# terratorch
tiktoken==0.12.0
# via
# gpt-oss
......@@ -1185,10 +1101,6 @@ tiktoken==0.12.0
timm==1.0.17
# via
# -r requirements/test.in
# open-clip-torch
# segmentation-models-pytorch
# terratorch
# torchgeo
tokenizers==0.22.0
# via
# -r requirements/test.in
......@@ -1197,7 +1109,6 @@ tomli==2.2.1
# via schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.10.0+cu129
# via
# -r requirements/test.in
# accelerate
......@@ -1208,43 +1119,22 @@ torch==2.10.0+cu129
# lightly
# lightning
# mteb
# open-clip-torch
# peft
# pytorch-lightning
# runai-model-streamer
# segmentation-models-pytorch
# sentence-transformers
# tensorizer
# terratorch
# timm
# torchaudio
# torchgeo
# torchmetrics
# torchvision
# vector-quantize-pytorch
# vocos
torchaudio==2.10.0+cu129
# via
# -r requirements/test.in
# encodec
# vocos
torchgeo==0.7.0
# via terratorch
torchmetrics==1.7.4
# via
# lightning
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.25.0+cu129
# via
# -r requirements/test.in
# lightly
# open-clip-torch
# segmentation-models-pytorch
# terratorch
# timm
# torchgeo
tqdm==4.67.3
# via
# datasets
......@@ -1255,15 +1145,11 @@ tqdm==4.67.3
# lm-eval
# mteb
# nltk
# open-clip-torch
# optuna
# peft
# pqdm
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers
# tacoreader
# terratorch
# transformers
transformers==4.57.5
# via
......@@ -1275,7 +1161,6 @@ transformers==4.57.5
transformers-stream-generator==0.0.5
# via -r requirements/test.in
triton==3.6.0
# via torch
tritonclient==2.64.0
# via -r requirements/test.in
typepy==1.3.2
......@@ -1316,12 +1201,9 @@ typing-extensions==4.15.0
# pydantic
# pydantic-core
# pydantic-extra-types
# pytorch-lightning
# sentence-transformers
# sqlalchemy
# starlette
# torch
# torchgeo
# typer
# typeshed-client
# typing-inspection
......@@ -1344,14 +1226,12 @@ urllib3==2.2.3
# tritonclient
uvicorn==0.35.0
# via gpt-oss
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
wandb==0.24.2
# via terratorch
wcwidth==0.2.13
# via ftfy
webcolors==24.11.1
......
......@@ -10,9 +10,5 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.61.2 # Required for N-gram speculative decoding
--extra-index-url=https://download.pytorch.org/whl/xpu
torch==2.10.0+xpu
torchaudio
torchvision
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.3/vllm_xpu_kernels-0.1.3-cp38-abi3-linux_x86_64.whl
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
# Keep v1 CUDA-graph tests valid: fork default for ``LLM(kvprune_compression=None)``
# is controlled by ``VLLM_KVPRUNE_COMPRESSION_DEFAULT`` (see ``env_override``).
os.environ.setdefault("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "0")
os.environ.setdefault("VLLM_KVPRUNE_RELEASE_V1_KV", "0")
import contextlib
import pathlib
from copy import deepcopy
......
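The two `os.environ.setdefault` calls above only supply a value when the variable is not already set, so a value exported in the shell still wins over the test default. A minimal sketch of that behavior follows; the helper name `apply_env_defaults` is hypothetical and not part of vLLM, and how the fork actually reads these variables is not shown in this diff.

```python
import os


def apply_env_defaults(defaults: dict[str, str], env=os.environ) -> dict[str, str]:
    """Set each variable only if it is absent; return the effective values.

    Hypothetical helper for illustration. ``env`` can be any mutable mapping,
    which makes the behavior easy to check without touching the real process
    environment.
    """
    # setdefault() leaves an existing value untouched and returns whatever
    # value is in effect afterwards.
    return {name: env.setdefault(name, value) for name, value in defaults.items()}


if __name__ == "__main__":
    fake_env = {"VLLM_KVPRUNE_COMPRESSION_DEFAULT": "1"}  # pretend the user exported this
    effective = apply_env_defaults(
        {
            "VLLM_KVPRUNE_COMPRESSION_DEFAULT": "0",
            "VLLM_KVPRUNE_RELEASE_V1_KV": "0",
        },
        env=fake_env,
    )
    print(effective)
```

Passing a plain dictionary as `env` keeps the example side-effect free; in a real `conftest.py` the calls would target `os.environ` directly, as in the hunk above.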
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""LongBench evaluation via vLLM ``LLM.generate`` + kvprune compression (same folder layout as RULER)."""
from __future__ import annotations

import json
import logging
import os
import sys
from pathlib import Path

from datasets import concatenate_datasets, load_dataset

_SCRIPT_DIR = Path(__file__).resolve().parent
if str(_SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPT_DIR))

from longbench_metrics import dataset2metric  # noqa: E402

from vllm import LLM, SamplingParams  # noqa: E402
from vllm.kvprune.integration.compression_params import CompressionParams  # noqa: E402


def _hf_tokenizer(llm: LLM):
    tok = llm.get_tokenizer()
    return getattr(tok, "tokenizer", tok)


def messages_to_prompts(
    llm: LLM,
    messages: list[list[dict]],
    *,
    add_generation_prompt: bool,
    enable_thinking: bool,
) -> list[str]:
    inner = _hf_tokenizer(llm)
    out: list[str] = []
    kw: dict = {}
    if enable_thinking:
        kw["enable_thinking"] = True
    for conv in messages:
        text = inner.apply_chat_template(
            conv,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **kw,
        )
        out.append(text)
    return out
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s"
    )
    cfg_dir = _SCRIPT_DIR / "longbench_config"
    prompts = json.load(open(cfg_dir / "dataset2prompt.json", "r", encoding="utf-8"))
    max_gen_lens = json.load(open(cfg_dir / "dataset2maxlen.json", "r", encoding="utf-8"))

    datasets = [
        "narrativeqa",
        "qasper",
        "multifieldqa_en",
        "hotpotqa",
        "2wikimqa",
        "musique",
        "gov_report",
        "qmsum",
        "multi_news",
        "trec",
        "triviaqa",
        "samsum",
        "passage_retrieval_en",
        "passage_count",
        "lcc",
        "repobench-p",
    ]
    dataset = concatenate_datasets(
        [
            load_dataset("THUDM/LongBench", n, split="test", trust_remote_code=True)
            for n in datasets
        ]
    ).shuffle(seed=42)

    dset_names = [
        item["dataset"] if item["dataset"][-2:] != "_e" else item["dataset"][:-2]
        for item in dataset
    ]
    gen_lengths = [max_gen_lens[dset_name] for dset_name in dset_names]
    messages = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompts[dset_name].format(**item)},
        ]
        for dset_name, item in zip(dset_names, dataset)
    ]

    model = os.environ.get("KVPRUNE_EVAL_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
    tp = int(os.environ.get("KVPRUNE_EVAL_TP", "2"))
    seq_ratio = float(os.environ.get("KVPRUNE_SEQ_COMPRESSION_RATIO", "0.25"))

    llm = LLM(
        model=model,
        max_num_seqs=64,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=tp,
        max_model_len=128000,
        kvprune_compression=True,
    )
    text_prompts = messages_to_prompts(
        llm,
        messages,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    sampling_params = [
        SamplingParams(max_tokens=g, temperature=0.00001) for g in gen_lengths
    ]
    n = len(text_prompts)
    compression = [
        CompressionParams(
            compression_ratio=seq_ratio,
            compression_method="compactor",
            protected_first_tokens=8,
            protected_last_tokens=64,
        )
    ] * n
    outputs = llm.generate(text_prompts, sampling_params, compression=compression)
    responses = [o.outputs[0].text for o in outputs]

    results: dict = {}
    for dset_name, prediction, item in zip(dset_names, responses, dataset):
        results.setdefault(dset_name, [])
        pred = prediction
        if dset_name in ["trec", "triviaqa", "samsum", "lsht"]:
            pred = pred.lstrip("\n").split("\n")[0]
        score = 0.0
        for ground_truth in item["answers"]:
            score = max(
                score,
                dataset2metric[dset_name](
                    pred, ground_truth, all_classes=item["all_classes"]
                ),
            )
        results[dset_name].append(score)

    all_sum, all_count = 0, 0
    for task, scores in results.items():
        avg = sum(scores) / len(scores)
        print(task, f"{avg:.2f}")
        all_sum += sum(scores)
        all_count += len(scores)
    print(f"ALL: {all_sum / all_count:.2f}")
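The scoring loop in this script keeps, for each example, the best metric value over all reference answers, then prints a per-task mean and a pooled mean over every example. The sketch below isolates that aggregation so it can be checked without a model. It assumes a hypothetical `exact_match` metric in place of the `dataset2metric` functions, and `aggregate_scores` is an illustrative name, not part of the commit.

```python
from collections import defaultdict
from typing import Callable


def exact_match(prediction: str, ground_truth: str) -> float:
    """Hypothetical stand-in metric; not one of the LongBench metrics."""
    return float(prediction.strip().lower() == ground_truth.strip().lower())


def aggregate_scores(
    records: list[tuple[str, str, list[str]]],
    metric: Callable[[str, str], float] = exact_match,
) -> tuple[dict[str, float], float]:
    """Return (per-task mean, pooled mean) for (task, prediction, answers) records."""
    per_task: dict[str, list[float]] = defaultdict(list)
    for task, prediction, answers in records:
        # Best score over all acceptable references; 0.0 when there are none.
        best = max((metric(prediction, a) for a in answers), default=0.0)
        per_task[task].append(best)

    task_means = {task: sum(s) / len(s) for task, s in per_task.items()}
    # The pooled mean weights every example equally, so tasks with more
    # examples contribute more than they would in a mean of task means.
    all_scores = [x for s in per_task.values() for x in s]
    pooled = sum(all_scores) / len(all_scores) if all_scores else 0.0
    return task_means, pooled


if __name__ == "__main__":
    demo = [
        ("trec", "Paris", ["paris", "Paris, France"]),
        ("trec", "Rome", ["Milan"]),
        ("qasper", "yes", ["yes"]),
    ]
    means, overall = aggregate_scores(demo)
    for task, mean in means.items():
        print(task, f"{mean:.2f}")
    print(f"ALL: {overall:.2f}")
```

Because the pooled figure is example-weighted, it can differ from the average of the per-task means whenever the tasks have different sizes, which is worth keeping in mind when comparing the `ALL` line across dataset subsets.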