"src/vscode:/vscode.git/clone" did not exist on "98c5e5da31dd70facf92970074be49501cd5e20b"
Commit 51679bbd authored by zhuwenwen

resolve merge conflicts

parents 4095d0db 1af090b5
# This script is run by buildkite to run the benchmarks and upload the results to buildkite
set -ex
set -o pipefail
# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
# run benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
server_pid=$!
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# wait for server to start, timeout after 600 seconds
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
--model meta-llama/Llama-2-7b-chat-hf \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt
bench_serving_exit_code=$?
kill $server_pid
# write the results into a markdown file
echo "### Latency Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line
echo "### Throughput Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines
# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
# exit with the exit code of the benchmarks
if [ $bench_latency_exit_code -ne 0 ]; then
exit $bench_latency_exit_code
fi
if [ $bench_throughput_exit_code -ne 0 ]; then
exit $bench_throughput_exit_code
fi
if [ $bench_serving_exit_code -ne 0 ]; then
exit $bench_serving_exit_code
fi
# In this file, you can add more tests to run either by adding a new step or
# adding a new command to an existing step. See the existing steps below for
# examples of the different options.
# This file is fed into the Jinja template in `test-template.j2` to generate
# the final pipeline yaml file.
steps:
- label: Regression Test
  command: pytest -v -s test_regression.py
  working_dir: "/vllm-workspace/tests" # optional

- label: AsyncEngine Test
  command: pytest -v -s async_engine

- label: Distributed Test
  command: pytest -v -s test_comm_ops.py
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only 1 or 2 GPUs are supported for now

- label: Engine Test
  command: pytest -v -s engine

- label: Entrypoints Test
  command: pytest -v -s entrypoints

- label: Kernels Test
  command: pytest -v -s kernels
  soft_fail: true

- label: Models Test
  commands:
  - pytest -v -s models --forked
  soft_fail: true

- label: Prefix Caching Test
  commands:
  - pytest -v -s prefix_caching

- label: Samplers Test
  command: pytest -v -s samplers --forked

- label: Worker Test
  command: pytest -v -s worker

- label: LoRA Test
  command: pytest -v -s lora

- label: Benchmarks
  working_dir: "/vllm-workspace/.buildkite"
  commands:
  - pip install aiohttp
  - bash run-benchmarks.sh
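For local debugging it can help to see what one of these steps turns into once `test-template.j2` (below) wraps it in a Kubernetes pod. A rough, hypothetical local equivalent of the Engine Test step is sketched here; the commit tag, GPU flag, shm size, and HF_TOKEN value are placeholders, and it assumes you can pull the CI image:

# Hypothetical local run of the "Engine Test" step; not part of the pipeline itself.
# <commit-sha> and <your-hf-token> are placeholders.
docker run --rm --gpus 1 --shm-size=8g \
  -e HF_TOKEN="<your-hf-token>" \
  us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:<commit-sha> \
  bash -c 'cd /vllm-workspace/tests && pytest -v -s engine'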
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
{% set default_num_gpu = 1 %}
{% set default_working_dir = "/vllm-workspace/tests" %}
steps:
- label: ":docker: build image"
commands:
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- wait
{% for step in steps %}
- label: "{{ step.label }}"
agents:
queue: kubernetes
soft_fail: {{ step.soft_fail or false }}
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
plugins:
- kubernetes:
podSpec:
volumes:
- name: dshm
emptyDir:
medium: Memory
containers:
- image: "{{ docker_image }}"
command: ["bash"]
args:
- "-c"
- "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
resources:
requests:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
limits:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
env:
- name: HF_TOKEN
valueFrom:
secretKeyRef:
name: hf-token-secret
key: token
volumeMounts:
- mountPath: /dev/shm
name: dshm
{% endfor %}
@@ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt
# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
+# Make sure punica is built for the release (for LoRA)
+export VLLM_INSTALL_PUNICA_KERNELS=1
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
@@ -28,4 +28,4 @@ jobs:
        pip install toml==0.10.2
    - name: Running yapf
      run: |
-        yapf --diff --recursive vllm tests
+        yapf --diff --recursive .
@@ -181,3 +181,6 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
+
+# Benchmark dataset
+*.json
+# The vLLM Dockerfile is used to construct vLLM image that can be directly used
+# to run the OpenAI compatible server.
+#################### BASE BUILD IMAGE ####################
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
-    && apt-get install -y python3-pip
+    && apt-get install -y python3-pip git
WORKDIR /workspace
@@ -14,8 +18,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-dev.txt
+#################### BASE BUILD IMAGE ####################
-# image to build pytorch extensions
+#################### EXTENSION BUILD IMAGE ####################
FROM dev AS build
# install build dependencies
@@ -30,6 +36,7 @@ COPY requirements.txt requirements.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py
+# cuda arch list used by torch
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions
@@ -38,20 +45,30 @@ ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
+# make sure punica kernels are built (for LoRA)
+ENV VLLM_INSTALL_PUNICA_KERNELS=1
RUN python3 setup.py build_ext --inplace
+#################### EXTENSION Build IMAGE ####################
+#################### TEST IMAGE ####################
# image to run unit testing suite
FROM dev AS test
# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY tests tests
-COPY vllm vllm
+WORKDIR /vllm-workspace
+# ADD is used to preserve directory structure
+ADD . /vllm-workspace/
+COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
+# ignore build dependencies installation because we are using pre-compiled extensions
+RUN rm pyproject.toml
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
+#################### TEST IMAGE ####################
-ENTRYPOINT ["python3", "-m", "pytest", "tests"]
+#################### RUNTIME BASE IMAGE ####################
# use CUDA base as CUDA runtime dependencies are already installed via pip
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
@@ -63,14 +80,10 @@ WORKDIR /workspace
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements.txt
+#################### RUNTIME BASE IMAGE ####################
-FROM vllm-base AS vllm
-COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY vllm vllm
-EXPOSE 8000
-ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
+#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
@@ -81,4 +94,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+#################### OPENAI API SERVER ####################
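With the Dockerfile split into named stages, which image you get depends on the --target you pass. A minimal sketch of building and running the OpenAI-compatible server stage, assuming the Dockerfile above sits at the repository root (the tag, GPU flags, and model name are illustrative):

# Build only the vllm-openai stage; build args mirror the CI invocation earlier in this commit.
DOCKER_BUILDKIT=1 docker build --target vllm-openai \
  --build-arg max_jobs=8 --build-arg nvcc_threads=8 \
  -t vllm-openai:dev .

# The entrypoint is python3 -m vllm.entrypoints.openai.api_server, so extra
# arguments are passed straight to the server (it listens on port 8000 by default).
docker run --rm --gpus all -p 8000:8000 \
  vllm-openai:dev --model meta-llama/Llama-2-7b-chat-hf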
-FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
+# default base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+FROM $BASE_IMAGE
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+RUN echo "Base image is $BASE_IMAGE"
+# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+# this does not always work for all rocm versions
+RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
+    echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+ARG FA_BRANCH="3d2b6f5"
+RUN echo "FA_BRANCH is $FA_BRANCH"
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
@@ -37,17 +57,23 @@ RUN mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
    && cd flash-attention \
-    && git checkout 3d2b6f5 \
+    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
-    && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
-    && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
+    && export GPU_ARCHS=${FA_GFX_ARCHS} \
+    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
    && python3 setup.py install \
    && cd ..
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip
-RUN pip install xformers==0.0.23 --no-deps
+RUN python3 -m pip install xformers==0.0.23 --no-deps
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually removed it so that later steps of numpy upgrade can continue
+RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
RUN cd /app \
    && cd vllm \
...
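Because the base image and flash-attention knobs are now build arguments, the ROCm image can target either supported base without editing the file. A hedged sketch, assuming the file is saved under its usual name Dockerfile.rocm (image tags are illustrative):

# Default build: ROCm 6.0 base, flash-attention pinned to FA_BRANCH.
docker build -f Dockerfile.rocm -t vllm-rocm:6.0 .

# ROCm 5.7 base instead; this also triggers the hipify patch step that is
# guarded by the BASE_IMAGE check inside the flash-attention build.
docker build -f Dockerfile.rocm \
  --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
  -t vllm-rocm:5.7 .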
@@ -41,7 +41,7 @@ python3 setup.py install
+ If pip install downloads are slow, you can add a mirror source: -i https://pypi.tuna.tsinghua.edu.cn/simple/
## Verification
-- python -c "import vllm; print(vllm.\_\_version__)": the version number is kept in sync with the official release; query the package version, e.g. 0.2.7
+- python -c "import vllm; print(vllm.\_\_version__)": the version number is kept in sync with the official release; query the package version, e.g. 0.3.0
## Known Issue
-
...
@@ -16,8 +16,18 @@ Easy, fast, and cheap LLM serving for everyone
---
+**The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**
+We are thrilled to announce our second vLLM Meetup!
+The vLLM team will share recent updates and roadmap.
+We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
+Please register [here](https://lu.ma/ygxbpzhl) and join us!
+---
*Latest News* 🔥
-- [2023/12] Added ROCm support to vLLM.
+- [2024/01] Added ROCm 6.0 support to vLLM.
+- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
@@ -36,7 +46,7 @@ vLLM is fast with:
- Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
-- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
+- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
- Optimized CUDA kernels
vLLM is flexible and easy to use with:
@@ -47,6 +57,8 @@ vLLM is flexible and easy to use with:
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs and AMD GPUs
+- (Experimental) Prefix caching support
+- (Experimental) Multi-lora support
vLLM seamlessly supports many Hugging Face models, including the following architectures:
@@ -68,6 +80,8 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
+- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
...
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
    )
    sampling_params = SamplingParams(
@@ -65,7 +66,9 @@ def main(args: argparse.Namespace):
    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
-            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+            profile_dir = Path(
+                "."
+            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=args.profile_result_dir)
        return
@@ -115,6 +118,13 @@ if __name__ == '__main__':
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=['auto', 'fp8_e5m2'],
+        default='auto',
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        '--profile',
        action='store_true',
@@ -123,9 +133,7 @@ if __name__ == '__main__':
        '--profile-result-dir',
        type=str,
        default=None,
-        help=(
-            'path to save the pytorch profiler output. Can be visualized '
-            'with ui.perfetto.dev or Tensorboard.'
-        ))
+        help=('path to save the pytorch profiler output. Can be visualized '
+              'with ui.perfetto.dev or Tensorboard.'))
    args = parser.parse_args()
    main(args)
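With the new flag, the latency benchmark can exercise the FP8 E5M2 KV-cache path as well as the default. A usage sketch (the model name is illustrative; the flag only changes behavior where the FP8 kernels are actually compiled in, and further down in this commit the fp8_e5m2 launcher branches in attention_kernels.cu are still commented out):

# Baseline: KV cache stored in the model dtype.
python3 benchmarks/benchmark_latency.py --model meta-llama/Llama-2-7b-chat-hf

# Same run with the KV cache quantized to FP8 E5M2.
python3 benchmarks/benchmark_latency.py --model meta-llama/Llama-2-7b-chat-hf \
    --kv-cache-dtype fp8_e5m2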
@@ -24,6 +24,7 @@ from typing import AsyncGenerator, List, Tuple
import aiohttp
import numpy as np
+from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -40,15 +41,10 @@ def sample_requests(
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
@@ -96,15 +92,9 @@ async def get_request(
        await asyncio.sleep(interval)
-async def send_request(
-    backend: str,
-    api_url: str,
-    prompt: str,
-    prompt_len: int,
-    output_len: int,
-    best_of: int,
-    use_beam_search: bool,
-) -> None:
+async def send_request(backend: str, model: str, api_url: str, prompt: str,
+                       prompt_len: int, output_len: int, best_of: int,
+                       use_beam_search: bool, pbar: tqdm) -> None:
    request_start_time = time.perf_counter()
    headers = {"User-Agent": "Benchmark Client"}
@@ -120,6 +110,8 @@ async def send_request(
            "ignore_eos": True,
            "stream": False,
        }
+        if model is not None:
+            pload["model"] = model
    elif backend == "tgi":
        assert not use_beam_search
        params = {
@@ -137,7 +129,8 @@ async def send_request(
    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
-            async with session.post(api_url, headers=headers, json=pload) as response:
+            async with session.post(api_url, headers=headers,
+                                    json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
@@ -151,10 +144,12 @@ async def send_request(
    request_end_time = time.perf_counter()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
+    pbar.update(1)
async def benchmark(
    backend: str,
+    model: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
@@ -162,13 +157,15 @@ async def benchmark(
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
+    pbar = tqdm(total=len(input_requests))
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(backend, api_url, prompt,
-                                                prompt_len, output_len,
-                                                best_of, use_beam_search))
+        task = asyncio.create_task(
+            send_request(backend, model, api_url, prompt, prompt_len,
+                         output_len, best_of, use_beam_search, pbar))
        tasks.append(task)
    await asyncio.gather(*tasks)
+    pbar.close()
def main(args: argparse.Namespace):
@@ -176,13 +173,15 @@ def main(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)
-    api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
    benchmark_start_time = time.perf_counter()
-    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
-                          args.use_beam_search, args.request_rate))
+    asyncio.run(
+        benchmark(args.backend, args.model, api_url, input_requests,
+                  args.best_of, args.use_beam_search, args.request_rate))
    benchmark_end_time = time.perf_counter()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
@@ -196,10 +195,8 @@ def main(args: argparse.Namespace):
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
-    avg_per_output_token_latency = np.mean([
-        latency / output_len
-        for _, output_len, latency in REQUEST_LATENCY
-    ])
+    avg_per_output_token_latency = np.mean(
+        [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")
@@ -207,27 +204,46 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
-    parser.add_argument("--backend", type=str, default="vllm",
+    parser.add_argument("--backend",
+                        type=str,
+                        default="vllm",
                        choices=["vllm", "tgi"])
+    parser.add_argument("--protocol",
+                        type=str,
+                        default="http",
+                        choices=["http", "https"])
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--endpoint", type=str, default="/generate")
+    parser.add_argument("--model", type=str, default=None)
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                        help="Path to the dataset.")
-    parser.add_argument("--tokenizer", type=str, required=True,
+    parser.add_argument("--tokenizer",
+                        type=str,
+                        required=True,
                        help="Name or path of the tokenizer.")
-    parser.add_argument("--best-of", type=int, default=1,
+    parser.add_argument("--best-of",
+                        type=int,
+                        default=1,
                        help="Generates `best_of` sequences per prompt and "
                        "returns the best one.")
    parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                        help="Number of prompts to process.")
-    parser.add_argument("--request-rate", type=float, default=float("inf"),
+    parser.add_argument("--request-rate",
+                        type=float,
+                        default=float("inf"),
                        help="Number of requests per second. If this is inf, "
                        "then all the requests are sent at time 0. "
                        "Otherwise, we use Poisson process to synthesize "
                        "the request arrival times.")
    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument('--trust-remote-code', action='store_true',
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
                        help='trust remote code from huggingface')
    args = parser.parse_args()
    main(args)
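The new --model, --endpoint, and --protocol flags are what let run-benchmarks.sh above point this script at the OpenAI-compatible /v1/completions route instead of the legacy /generate route; for example, mirroring that invocation:

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Llama-2-7b-chat-hf \
    --tokenizer meta-llama/Llama-2-7b-chat-hf \
    --endpoint /v1/completions \
    --num-prompts 20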
@@ -71,6 +71,7 @@ def run_vllm(
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
+    kv_cache_dtype: str,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(
@@ -83,6 +84,7 @@ def run_vllm(
        dtype=dtype,
        max_model_len=max_model_len,
        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
    )
    # Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
                              args.quantization, args.tensor_parallel_size,
                              args.seed, args.n, args.use_beam_search,
                              args.trust_remote_code, args.dtype,
-                              args.max_model_len, args.enforce_eager)
+                              args.max_model_len, args.enforce_eager,
+                              args.kv_cache_dtype)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ if __name__ == "__main__":
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
...
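The throughput benchmark threads the same option through run_vllm(); a usage sketch (arguments mirror the CI run earlier in this commit, and as with the latency script the fp8_e5m2 value only has an effect where the FP8 kernels are enabled):

python3 benchmarks/benchmark_throughput.py \
    --backend vllm \
    --input-len 256 --output-len 256 \
    --kv-cache-dtype fp8_e5m2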
+from typing import Optional
import argparse
import random
import time
import torch
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops
NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
+    kv_cache_dtype: Optional[str] = None,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
    # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
-    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
-    value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
-                              device="cuda")
-    value_cache.uniform_(-scale, scale)
+    key_caches, value_caches = create_kv_caches_with_random(
+        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+        dtype)
+    key_cache, value_cache = key_caches[0], value_caches[0]
    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def main(
                block_size,
                max_context_len,
                alibi_slopes,
+                kv_cache_dtype,
            )
    elif version == "v2":
        ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def main(
                block_size,
                max_context_len,
                alibi_slopes,
+                kv_cache_dtype,
            )
    else:
        raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ if __name__ == '__main__':
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    print(args)
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
-    dtype_to_torch_dtype = {
-        "half": torch.half,
-        "bfloat16": torch.bfloat16,
-        "float": torch.float,
-    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ if __name__ == '__main__':
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
-        dtype=dtype_to_torch_dtype[args.dtype],
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
+        kv_cache_dtype=args.kv_cache_dtype,
    )
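The kernel micro-benchmark accepts the same choice, so the quantized paged-attention path can be timed in isolation once it is compiled in; a sketch, assuming the script keeps its usual benchmarks/kernels/ location:

python3 benchmarks/kernels/benchmark_paged_attention.py \
    --version v2 \
    --batch-size 64 \
    --dtype half \
    --kv-cache-dtype fp8_e5m2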
@@ -4,3 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
// #include "dtype_bfloat16.cuh"
+// #include "dtype_fp8_e5m2.cuh"
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#include <algorithm> #include <algorithm>
...@@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -79,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template<
typename scalar_t, typename scalar_t,
typename cache_t,
int HEAD_SIZE, int HEAD_SIZE,
int BLOCK_SIZE, int BLOCK_SIZE,
int NUM_THREADS, int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -145,6 +148,9 @@ __device__ void paged_attention_kernel( ...@@ -145,6 +148,9 @@ __device__ void paged_attention_kernel(
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
...@@ -176,7 +182,7 @@ __device__ void paged_attention_kernel( ...@@ -176,7 +182,7 @@ __device__ void paged_attention_kernel(
// x == THREAD_GROUP_SIZE * VEC_SIZE // x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time. // Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX; float qk_max = -FLT_MAX;
// Iterate over the key blocks. // Iterate over the key blocks.
...@@ -202,14 +208,24 @@ __device__ void paged_attention_kernel( ...@@ -202,14 +208,24 @@ __device__ void paged_attention_kernel(
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride
+ physical_block_offset * x; + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#else
assert(false);
#endif
} else {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} }
}
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
...@@ -282,6 +298,9 @@ __device__ void paged_attention_kernel( ...@@ -282,6 +298,9 @@ __device__ void paged_attention_kernel(
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type; using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
...@@ -307,14 +326,25 @@ __device__ void paged_attention_kernel( ...@@ -307,14 +326,25 @@ __device__ void paged_attention_kernel(
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride; + kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); V_vec v_vec;
if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
assert(false);
#endif
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
if (block_idx == num_context_blocks - 1) { if (block_idx == num_context_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs. // we should explicitly zero out the values since they may contain NaNs.
...@@ -395,14 +425,16 @@ __device__ void paged_attention_kernel( ...@@ -395,14 +425,16 @@ __device__ void paged_attention_kernel(
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template< template<
typename scalar_t, typename scalar_t,
typename cache_t,
int HEAD_SIZE, int HEAD_SIZE,
int BLOCK_SIZE, int BLOCK_SIZE,
int NUM_THREADS> int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -412,7 +444,7 @@ __global__ void paged_attention_v1_kernel(
const int q_stride, const int q_stride,
const int kv_block_stride, const int kv_block_stride,
const int kv_head_stride) { const int kv_head_stride) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, /* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
...@@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel( ...@@ -421,17 +453,19 @@ __global__ void paged_attention_v1_kernel(
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template<
typename scalar_t, typename scalar_t,
typename cache_t,
int HEAD_SIZE, int HEAD_SIZE,
int BLOCK_SIZE, int BLOCK_SIZE,
int NUM_THREADS, int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel( ...@@ -441,7 +475,7 @@ __global__ void paged_attention_v2_kernel(
const int q_stride, const int q_stride,
const int kv_block_stride, const int kv_block_stride,
const int kv_head_stride) { const int kv_head_stride) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride); q_stride, kv_block_stride, kv_head_stride);
...@@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -550,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
shared_mem_size); \ IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
<<<grid, block, shared_mem_size, stream>>>( \ IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \ out_ptr, \
query_ptr, \ query_ptr, \
key_cache_ptr, \ key_cache_ptr, \
...@@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -571,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel(
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template< template<
typename T, typename T,
typename CACHE_T,
int BLOCK_SIZE, int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
int NUM_THREADS = 128> int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& out,
...@@ -602,8 +638,8 @@ void paged_attention_v1_launcher( ...@@ -602,8 +638,8 @@ void paged_attention_v1_launcher(
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* context_lens_ptr = context_lens.data_ptr<int>();
...@@ -647,8 +683,8 @@ void paged_attention_v1_launcher( ...@@ -647,8 +683,8 @@ void paged_attention_v1_launcher(
} }
} }
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
paged_attention_v1_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
out, \ out, \
query, \ query, \
key_cache, \ key_cache, \
...@@ -662,16 +698,16 @@ void paged_attention_v1_launcher( ...@@ -662,16 +698,16 @@ void paged_attention_v1_launcher(
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V1_LAUNCHER(T, 8); \ CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
break; \ break; \
case 16: \ case 16: \
CALL_V1_LAUNCHER(T, 16); \ CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
break; \ break; \
case 32: \ case 32: \
CALL_V1_LAUNCHER(T, 32); \ CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
...@@ -689,20 +725,36 @@ void paged_attention_v1( ...@@ -689,20 +725,36 @@ void paged_attention_v1(
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& context_lens, // [num_seqs]
int block_size, int block_size,
int max_context_len, int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) { if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float); CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) { } else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
// } else if (query.dtype() == at::ScalarType::BFloat16) { // } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); // CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
} }
// } else if (kv_cache_dtype == "fp8_e5m2") {
// if (query.dtype() == at::ScalarType::Float) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::Half) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
// } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
// } else {
// TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
// }
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \ <<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \ exp_sums_ptr, \
max_logits_ptr, \ max_logits_ptr, \
...@@ -730,7 +782,9 @@ void paged_attention_v1( ...@@ -730,7 +782,9 @@ void paged_attention_v1(
template< template<
typename T, typename T,
typename CACHE_T,
int BLOCK_SIZE, int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
int NUM_THREADS = 128, int NUM_THREADS = 128,
int PARTITION_SIZE = 512> int PARTITION_SIZE = 512>
void paged_attention_v2_launcher( void paged_attention_v2_launcher(
...@@ -768,8 +822,8 @@ void paged_attention_v2_launcher( ...@@ -768,8 +822,8 @@ void paged_attention_v2_launcher(
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* context_lens_ptr = context_lens.data_ptr<int>();
...@@ -816,8 +870,8 @@ void paged_attention_v2_launcher( ...@@ -816,8 +870,8 @@ void paged_attention_v2_launcher(
} }
} }
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
paged_attention_v2_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
out, \ out, \
exp_sums, \ exp_sums, \
max_logits, \ max_logits, \
...@@ -834,16 +888,16 @@ void paged_attention_v2_launcher( ...@@ -834,16 +888,16 @@ void paged_attention_v2_launcher(
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V2_LAUNCHER(T, 8); \ CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
break; \ break; \
case 16: \ case 16: \
CALL_V2_LAUNCHER(T, 16); \ CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
break; \ break; \
case 32: \ case 32: \
CALL_V2_LAUNCHER(T, 32); \ CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
...@@ -864,16 +918,31 @@ void paged_attention_v2( ...@@ -864,16 +918,31 @@ void paged_attention_v2(
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    // } else if (query.dtype() == at::ScalarType::BFloat16) {
    //   CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  // } else if (kv_cache_dtype == "fp8_e5m2") {
  //   if (query.dtype() == at::ScalarType::Float) {
  //     CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
  //   } else if (query.dtype() == at::ScalarType::Half) {
  //     CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
  //   } else if (query.dtype() == at::ScalarType::BFloat16) {
  //     CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
  //   } else {
  //     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  //   }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#undef WARP_SIZE
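The three-level dispatch above (block-size switch, launcher macro, template instantiation) can be hard to follow in macro form. The snippet below is a standalone, hedged sketch of the same pattern with simplified stand-in names (demo_launcher, DEMO_CALL, DEMO_CALL_BLOCK_SIZE are illustrative only, not part of this commit): a runtime block_size selects one of a small set of compile-time template instantiations, and unsupported sizes fail loudly.

// Standalone illustration (assumption: simplified stand-ins, not the real vLLM launcher)
// of the dispatch pattern used by CALL_V2_LAUNCHER_BLOCK_SIZE above.
#include <cstdio>
#include <stdexcept>

template<typename T, int BLOCK_SIZE, bool IS_FP8_E5M2_KV_CACHE>
void demo_launcher() {
  std::printf("launch: block_size=%d, fp8_kv_cache=%d\n",
              BLOCK_SIZE, int(IS_FP8_E5M2_KV_CACHE));
}

#define DEMO_CALL(T, BLOCK_SIZE, IS_FP8) \
  demo_launcher<T, BLOCK_SIZE, IS_FP8>();

#define DEMO_CALL_BLOCK_SIZE(T, IS_FP8)                          \
  switch (block_size) {                                          \
    case 8:  DEMO_CALL(T, 8,  IS_FP8); break;                    \
    case 16: DEMO_CALL(T, 16, IS_FP8); break;                    \
    case 32: DEMO_CALL(T, 32, IS_FP8); break;                    \
    default: throw std::runtime_error("Unsupported block size"); \
  }

int main() {
  int block_size = 16;                // runtime value, e.g. from the cache config
  DEMO_CALL_BLOCK_SIZE(float, false)  // expands to the switch above
  return 0;
}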
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif
namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};
template<>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};
template<>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};
template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2
} // namespace vllm
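A small host-side sanity sketch of what these specializations buy (the primary Vec template is assumed here to mirror the one in attention_generic.cuh; none of this is part of the commit): each Vec<uint8_t, N>::Type is an N-byte integer or vector type, so N packed fp8 values can be moved with a single aligned load or store.

// Hedged sketch, compiled with nvcc: re-declare the same specializations and check
// that each packed type is exactly N bytes wide.
#include <cstdint>
#include <cuda_runtime.h>   // uint2

template<typename T, int VEC_SIZE> struct Vec {};   // assumed primary template shape

template<> struct Vec<uint8_t, 1> { using Type = uint8_t;  };
template<> struct Vec<uint8_t, 2> { using Type = uint16_t; };
template<> struct Vec<uint8_t, 4> { using Type = uint32_t; };
template<> struct Vec<uint8_t, 8> { using Type = uint2;    };

static_assert(sizeof(Vec<uint8_t, 1>::Type) == 1, "1 fp8 value  -> 1 byte");
static_assert(sizeof(Vec<uint8_t, 2>::Type) == 2, "2 fp8 values -> 2 bytes");
static_assert(sizeof(Vec<uint8_t, 4>::Type) == 4, "4 fp8 values -> 4 bytes");
static_assert(sizeof(Vec<uint8_t, 8>::Type) == 8, "8 fp8 values -> 8 bytes");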
@@ -20,7 +20,8 @@ void reshape_and_cache(
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping,
  const std::string& kv_cache_dtype);

void gather_cached_kv(
  torch::Tensor& key,
@@ -28,3 +29,8 @@ void gather_cached_kv(
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

// Just for unittest
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache);
@@ -4,6 +4,7 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
@@ -34,7 +35,7 @@ void swap_blocks(
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
@@ -131,7 +132,7 @@ void copy_blocks(
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -143,12 +144,12 @@ void copy_blocks(
namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
@@ -185,19 +186,45 @@ __global__ void reshape_and_cache_kernel(
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace vllm
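The fp8 branch above goes through fp8_e5m2_unscaled::vec_conversion, which builds on the CUDA fp8 conversion helpers from <cuda_fp8.h> (CUDA 11.8+). As a hedged, standalone illustration only (this is not the vLLM helper itself), a single-value e5m2 round trip with the raw intrinsics looks roughly like this:

// Sketch only: quantize one float to fp8 e5m2 (1 byte) and dequantize it again.
// e5m2 keeps 2 mantissa bits, so the round trip is intentionally lossy.
#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <cstdio>

int main() {
  float x = 0.3f;
  // float -> fp8 e5m2, saturating to the finite range, stored as one byte.
  __nv_fp8_storage_t q = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E5M2);
  // fp8 -> half -> float for inspection.
  __half_raw hr = __nv_cvt_fp8_to_halfraw(q, __NV_E5M2);
  float back = __half2float(__half(hr));
  std::printf("%.4f -> e5m2 byte 0x%02x -> %.4f\n", x, (unsigned)q, back);
  return 0;
}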
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x);
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
@@ -212,23 +239,25 @@ void reshape_and_cache(
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}
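Note that Half tensors are dispatched as uint16_t (and the cache side as uint8_t for fp8) above: the caching kernels only copy or convert raw storage words, so reinterpret_cast between the element types is bit-preserving. A minimal, hedged sketch of that equivalence (standalone, not vLLM code):

// Show that an fp16 value and its uint16_t reinterpretation share the same bits.
#include <cuda_fp16.h>
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  __half h = __float2half(1.5f);
  uint16_t bits;
  std::memcpy(&bits, &h, sizeof(bits));              // same 16 bits, different static type
  std::printf("fp16(1.5) raw bits: 0x%04x\n", bits); // expected: 0x3e00
  return 0;
}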
namespace vllm {
@@ -373,7 +402,7 @@ void gather_cached_kv(
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
@@ -391,3 +420,55 @@ void gather_cached_kv(
        x);
    });
}
namespace vllm {
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
}
}
} // namespace vllm
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
block_stride);
void convert_fp8_e5m2(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
{
int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0);
dim3 grid(num_blocks);
dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8_E5M2(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8_E5M2(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
}
}
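Since convert_fp8_e5m2 is exposed "just for unittest", here is a hedged C++ usage sketch of the round trip (the tensor shapes and the local declaration are illustrative assumptions; in the real build the declaration comes from the cache-ops header and the op is normally driven from the Python test bindings instead):

// Round-trip check: fp16 cache -> fp8_e5m2 bytes -> fp16 cache.
#include <torch/torch.h>
#include <iostream>

// Declaration as added in this commit (normally provided by the cache ops header).
void convert_fp8_e5m2(torch::Tensor& src_cache, torch::Tensor& dst_cache);

void fp8_e5m2_roundtrip_example() {
  // Key-cache-like layout [num_blocks, num_heads, head_size/x, block_size, x]; sizes are arbitrary.
  auto src = torch::randn({8, 4, 16, 16, 8},
                          torch::dtype(torch::kFloat16).device(torch::kCUDA));
  auto quantized = torch::empty_like(src, src.options().dtype(torch::kUInt8)); // 1 byte per element
  auto recovered = torch::empty_like(src);

  convert_fp8_e5m2(src, quantized);        // fp16 -> fp8_e5m2 (stored as uint8)
  convert_fp8_e5m2(quantized, recovered);  // fp8_e5m2 -> fp16

  // e5m2 has only 2 mantissa bits, so expect coarse agreement.
  std::cout << "max abs error: "
            << (recovered - src).abs().max().item<float>() << std::endl;
}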