"docs/vscode:/vscode.git/clone" did not exist on "b0925b38789bb3b20dcc39e229fcfe12a311e487"
Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -16,7 +16,7 @@ add_compile_options(-w) ...@@ -16,7 +16,7 @@ add_compile_options(-w)
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
# #
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
# Supported NVIDIA architectures. # Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
...@@ -34,7 +34,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 ...@@ -34,7 +34,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch # requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm # versions are derived from Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
# #
...@@ -68,6 +68,39 @@ endif() ...@@ -68,6 +68,39 @@ endif()
# #
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")
# Define _core_C extension
# built for (almost) every target platform, (excludes TPU and Neuron)
set(VLLM_EXT_SRC
"csrc/core/torch_bindings.cpp")
define_gpu_extension_target(
_core_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI)
add_dependencies(default _core_C)
# #
# Forward the non-CUDA device extensions to external CMake scripts. # Forward the non-CUDA device extensions to external CMake scripts.
# #
...@@ -76,7 +109,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND ...@@ -76,7 +109,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
if (VLLM_TARGET_DEVICE STREQUAL "cpu") if (VLLM_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
else() else()
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") return()
endif() endif()
return() return()
endif() endif()
...@@ -134,7 +167,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -134,7 +167,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
# #
# Define extension targets # Define other extension targets
# #
# #
...@@ -160,12 +193,13 @@ set(VLLM_EXT_SRC ...@@ -160,12 +193,13 @@ set(VLLM_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent) include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
FetchContent_Declare( FetchContent_Declare(
cutlass cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0 # CUTLASS 3.5.1
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
GIT_PROGRESS TRUE
) )
FetchContent_MakeAvailable(cutlass) FetchContent_MakeAvailable(cutlass)
...@@ -174,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -174,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
...@@ -204,7 +239,7 @@ define_gpu_extension_target( ...@@ -204,7 +239,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
...@@ -226,76 +261,7 @@ define_gpu_extension_target( ...@@ -226,76 +261,7 @@ define_gpu_extension_target(
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
#
# _punica_C extension
#
set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/torch_bindings.cpp")
#
# Copy GPU compilation flags+update for punica
#
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
"-D__CUDA_NO_HALF2_OPERATORS__")
#
# Filter out CUDA architectures < 8.0 for punica.
#
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
set(VLLM_PUNICA_GPU_ARCHES)
foreach(ARCH ${VLLM_GPU_ARCHES})
string_to_ver(CODE_VER ${ARCH})
if (CODE_VER GREATER_EQUAL 8.0)
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
endif()
endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()
if (VLLM_PUNICA_GPU_ARCHES)
define_gpu_extension_target(
_punica_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
endif()
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
...@@ -304,12 +270,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") ...@@ -304,12 +270,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling moe extension.") message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C) add_dependencies(default _moe_C)
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
# there are supported target arches.
if (VLLM_PUNICA_GPU_ARCHES AND
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
message(STATUS "Enabling punica extension.")
add_dependencies(default _punica_C)
endif()
endif() endif()
...@@ -42,6 +42,7 @@ WORKDIR /workspace ...@@ -42,6 +42,7 @@ WORKDIR /workspace
# install build and runtime dependencies # install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt python3 -m pip install -r requirements-cuda.txt
...@@ -78,6 +79,7 @@ COPY setup.py setup.py ...@@ -78,6 +79,7 @@ COPY setup.py setup.py
COPY cmake cmake COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt COPY CMakeLists.txt CMakeLists.txt
COPY requirements-common.txt requirements-common.txt COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml COPY pyproject.toml pyproject.toml
COPY vllm vllm COPY vllm vllm
...@@ -88,8 +90,6 @@ ENV MAX_JOBS=${max_jobs} ...@@ -88,8 +90,6 @@ ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc # number of threads used by nvcc
ARG nvcc_threads=8 ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
ARG buildkite_commit ARG buildkite_commit
ENV BUILDKITE_COMMIT=${buildkite_commit} ENV BUILDKITE_COMMIT=${buildkite_commit}
...@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb ...@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
FROM ubuntu:22.04 AS cpu-test-1 FROM ubuntu:22.04 AS cpu-test-1
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
...@@ -13,8 +13,9 @@ RUN pip install intel-openmp ...@@ -13,8 +13,9 @@ RUN pip install intel-openmp
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD" ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
RUN echo 'ulimit -c 0' >> ~/.bashrc
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
RUN pip install --upgrade pip \ RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
......
# The vLLM Dockerfile is used to construct vLLM image that can be directly used # The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server. # to run the OpenAI compatible server.
FROM ubuntu:20.04 AS dev FROM ubuntu:22.04 AS dev
RUN apt-get update -y && \ RUN apt-get update -y && \
apt-get install -y python3-pip git apt-get install -y python3-pip git
...@@ -13,12 +13,15 @@ COPY requirements-common.txt /workspace/vllm/ ...@@ -13,12 +13,15 @@ COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/ COPY requirements-openvino.txt /workspace/vllm/
COPY vllm/ /workspace/vllm/vllm COPY vllm/ /workspace/vllm/vllm
COPY csrc/core /workspace/vllm/csrc/core
COPY cmake/utils.cmake /workspace/vllm/cmake/
COPY CMakeLists.txt /workspace/vllm/
COPY setup.py /workspace/vllm/ COPY setup.py /workspace/vllm/
# install build requirements # install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend # build vLLM with OpenVINO backend
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
COPY examples/ /workspace/vllm/examples COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks COPY benchmarks/ /workspace/vllm/benchmarks
......
...@@ -53,10 +53,10 @@ RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(whic ...@@ -53,10 +53,10 @@ RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(whic
# Install torch == 2.5.0 on ROCm # Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \ *"rocm-6.1"*) \
python3 -m pip uninstall -y torch torchaudio torchvision \ python3 -m pip uninstall -y torch torchvision \
&& python3 -m pip install --no-cache-dir --pre \ && python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ torch==2.5.0.dev20240726 \
torchvision==0.20.0.dev20240710 \ torchvision==0.20.0.dev20240726 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
*) ;; esac *) ;; esac
...@@ -127,19 +127,11 @@ FROM base AS final ...@@ -127,19 +127,11 @@ FROM base AS final
# Import the vLLM development directory from the build context # Import the vLLM development directory from the build context
COPY . . COPY . .
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac
# Package upgrades for useful functionality or to avoid dependency issues # Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0 # Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning # Silences the HF Tokenizers warning
......
ARG NIGHTLY_DATE="20240713" ARG NIGHTLY_DATE="20240726"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE FROM $BASE_IMAGE
...@@ -12,6 +12,9 @@ RUN pip install "numpy<2" ...@@ -12,6 +12,9 @@ RUN pip install "numpy<2"
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Fix FastAPI dependence
RUN pip install "starlette<0.38.0"
# Build vLLM. # Build vLLM.
COPY . /workspace/vllm COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu" ENV VLLM_TARGET_DEVICE="tpu"
......
include LICENSE include LICENSE
include requirements-adag.txt
include requirements-common.txt include requirements-common.txt
include requirements-cuda.txt include requirements-cuda.txt
include requirements-rocm.txt include requirements-rocm.txt
......
...@@ -82,7 +82,7 @@ VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install ...@@ -82,7 +82,7 @@ VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
+ 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/ + 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
## 验证 ## 验证
- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.5.3.post1 - python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.5.4
## Known Issue ## Known Issue
- -
......
...@@ -16,16 +16,8 @@ Easy, fast, and cheap LLM serving for everyone ...@@ -16,16 +16,8 @@ Easy, fast, and cheap LLM serving for everyone
--- ---
**The Fifth vLLM Bay Area Meetup (July 24th 5pm-8pm PT)**
We are excited to announce our fifth vLLM Meetup!
Join us to hear the vLLM's recent updates and the upcoming roadmap.
Additionally, our collaborators from AWS will be presenting their insights and experiences in deploying vLLM.
Register now [here](https://lu.ma/lp0gyjqr) and be part of the event!
---
*Latest News* 🔥 *Latest News* 🔥
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
...@@ -47,7 +39,7 @@ vLLM is fast with: ...@@ -47,7 +39,7 @@ vLLM is fast with:
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
- Optimized CUDA kernels - Optimized CUDA kernels
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/3924) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)). **Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
vLLM is flexible and easy to use with: vLLM is flexible and easy to use with:
......
...@@ -13,7 +13,7 @@ from weight_shapes import WEIGHT_SHAPES ...@@ -13,7 +13,7 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1] DEFAULT_TP_SIZES = [1]
...@@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ...@@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
timers = [] timers = []
# pytorch impl # pytorch impl - bfloat16
timers.append( timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl, torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales")) "pytorch_bf16_bf16_bf16_matmul-no-scales"))
# pytorch impl - float16
timers.append(
bench_fn(a.to(dtype=torch.float16, device="cuda"),
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
torch.float16, label, sub_label, pytorch_mm_impl,
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
# cutlass impl # cutlass impl
timers.append( timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
......
...@@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES ...@@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize) MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize) marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights) gptq_pack, gptq_quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
...@@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True] ...@@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True]
def bench_run(results: List[benchmark.Measurement], model: str, def bench_run(results: List[benchmark.Measurement], model: str,
act_order: bool, is_k_full: bool, num_bits: int, group_size: int, act_order: bool, is_k_full: bool, quant_type: ScalarType,
size_m: int, size_k: int, size_n: int): group_size: int, size_m: int, size_k: int, size_n: int):
label = "Quant Matmul" label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, " sub_label = ("{}, act={} k_full={}, q={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, "MKN=({}x{}x{})".format(model, act_order, is_k_full,
group_size, size_m, size_k, size_n)) str(quant_type), group_size, size_m,
size_k, size_n))
print(f"Testing: {sub_label}") print(f"Testing: {sub_label}")
...@@ -50,16 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, ...@@ -50,16 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str,
marlin_g_idx, marlin_g_idx,
marlin_sort_indices, marlin_sort_indices,
marlin_rand_perm, marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order) ) = marlin_quantize(b, quant_type, group_size, act_order)
# Marlin_24 quant # Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant # GPTQ quant
(w_ref, q_w, s, g_idx, (w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order) rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" # For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing # so that group ids are increasing
...@@ -73,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str, ...@@ -73,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str,
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL) GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
globals = { globals = {
# Gen params # Gen params
"num_bits": num_bits, "quant_type": quant_type,
"group_size": group_size, "group_size": group_size,
"size_m": size_m, "size_m": size_m,
"size_n": size_n, "size_n": size_n,
...@@ -87,6 +92,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, ...@@ -87,6 +92,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
"marlin_w_ref": marlin_w_ref, "marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w, "marlin_q_w": marlin_q_w,
"marlin_s": marlin_s, "marlin_s": marlin_s,
"marlin_zp": marlin_zp,
"marlin_g_idx": marlin_g_idx, "marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices, "marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm, "marlin_rand_perm": marlin_rand_perm,
...@@ -125,19 +131,29 @@ def bench_run(results: List[benchmark.Measurement], model: str, ...@@ -125,19 +131,29 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp16",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_gemm", description="gptq_marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
...@@ -147,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, ...@@ -147,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
...@@ -183,12 +199,13 @@ def main(args): ...@@ -183,12 +199,13 @@ def main(args):
) > 0 and is_k_full not in args.limit_k_full: ) > 0 and is_k_full not in args.limit_k_full:
continue continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: for quant_type in query_marlin_supported_quant_types(
if len(args.limit_num_bits False):
) > 0 and num_bits not in args.limit_num_bits: if len(args.limit_num_bits) > 0 and \
quant_type.size_bits not in args.limit_num_bits:
continue continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
if len( if len(
args.limit_group_size args.limit_group_size
) > 0 and group_size not in args.limit_group_size: ) > 0 and group_size not in args.limit_group_size:
...@@ -202,8 +219,8 @@ def main(args): ...@@ -202,8 +219,8 @@ def main(args):
for size_m in args.batch_sizes: for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full, bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k, quant_type, group_size, size_m,
size_n) size_k, size_n)
compare = benchmark.Compare(results) compare = benchmark.Compare(results)
compare.print() compare.print()
......
...@@ -175,7 +175,7 @@ if __name__ == '__main__': ...@@ -175,7 +175,7 @@ if __name__ == '__main__':
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 192, 256], choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128) default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
......
...@@ -94,7 +94,7 @@ if __name__ == '__main__': ...@@ -94,7 +94,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 192, 256], choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128) default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument("--dtype",
......
...@@ -83,6 +83,8 @@ endif() ...@@ -83,6 +83,8 @@ endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
list(APPEND LIBS "numa")
# #
# Define extension targets # Define extension targets
...@@ -95,6 +97,7 @@ set(VLLM_EXT_SRC ...@@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp" "csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp" "csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp" "csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp") "csrc/cpu/torch_bindings.cpp")
...@@ -104,11 +107,11 @@ define_gpu_extension_target( ...@@ -104,11 +107,11 @@ define_gpu_extension_target(
DESTINATION vllm DESTINATION vllm
LANGUAGE CXX LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3 USE_SABI 3
WITH_SOABI WITH_SOABI
) )
add_custom_target(default)
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
add_dependencies(default _C) add_dependencies(default _C)
...@@ -186,7 +186,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) ...@@ -186,7 +186,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
# #
# The torch cmake setup hardcodes the detected architecture flags in # The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis, e.g. for the `punica` extension. # can't modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from # So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target # Since it's not possible to use `target_compiler_options` for adding target
......
...@@ -65,6 +65,7 @@ DEFAULT_CONDA_PATTERNS = { ...@@ -65,6 +65,7 @@ DEFAULT_CONDA_PATTERNS = {
"optree", "optree",
"nccl", "nccl",
"transformers", "transformers",
"zmq",
} }
DEFAULT_PIP_PATTERNS = { DEFAULT_PIP_PATTERNS = {
...@@ -77,6 +78,7 @@ DEFAULT_PIP_PATTERNS = { ...@@ -77,6 +78,7 @@ DEFAULT_PIP_PATTERNS = {
"onnx", "onnx",
"nccl", "nccl",
"transformers", "transformers",
"zmq",
} }
......
...@@ -819,7 +819,8 @@ void paged_attention_v1_launcher( ...@@ -819,7 +819,8 @@ void paged_attention_v1_launcher(
if(num_heads!=num_kv_heads){ if(num_heads!=num_kv_heads){
num_threads =256; num_threads =256;
} }
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
...@@ -963,7 +964,7 @@ void paged_attention_v2_launcher( ...@@ -963,7 +964,7 @@ void paged_attention_v2_launcher(
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
......
...@@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { ...@@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
// #else // #else
return __bfloat1622float2(val); return __bfloat1622float2(val);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
...@@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { ...@@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
// #else // #else
return __bfloat162bfloat162(val); return __bfloat162bfloat162(val);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
// Vector addition. // Vector addition.
...@@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
return __hadd(a, b); return __hadd(a, b);
#endif #endif
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
...@@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { ...@@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
// #else // #else
return __hadd2(a, b); return __hadd2(a, b);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
...@@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
// #else // #else
return __hmul(a, b); return __hmul(a, b);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
template <> template <>
...@@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { ...@@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
// #else // #else
return __hmul2(a, b); return __hmul2(a, b);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
template <> template <>
...@@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, ...@@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
// #else // #else
return __hfma2(a, b, c); return __hfma2(a, b, c);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
...@@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, ...@@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
// #else // #else
return __hfma2(bf162bf162(a), b, c); return __hfma2(bf162bf162(a), b, c);
// #endif // #endif
__builtin_unreachable(); // Suppress missing return statement warning
} }
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
......
...@@ -48,6 +48,9 @@ ...@@ -48,6 +48,9 @@
} else if (HEADDIM == 128) { \ } else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \ constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \ } else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \ constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
......
...@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, ...@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype); const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);
// Just for unittest // Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment