diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..7f9e6d720fae5e3881c922172fca8fdb82d39890 --- /dev/null +++ b/.clang-format @@ -0,0 +1,26 @@ +BasedOnStyle: Google +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left + +# Reordering #include statements can (and currently will) introduce errors +SortIncludes: false + +# Style choices +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +IndentPPDirectives: BeforeHash + +IncludeCategories: + - Regex: '^<' + Priority: 4 + - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' + Priority: 3 + - Regex: '^"(qoda|\.\.)/' + Priority: 2 + - Regex: '.*' + Priority: 1 diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000000000000000000000000000000000..b7a9fdb4e05a872cd7f3e2515b09ddbd6635388e --- /dev/null +++ b/.coveragerc @@ -0,0 +1,47 @@ +[run] +# Track the installed vllm package (this is what actually gets imported during tests) +# Use wildcard pattern to match the installed location +source = + vllm + */dist-packages/vllm + */site-packages/vllm +omit = + */tests/* + */test_* + */__pycache__/* + */build/* + */dist/* + */vllm.egg-info/* + */third_party/* + */examples/* + */benchmarks/* + */docs/* + +[paths] +# Map all possible vllm locations to a canonical "vllm" path +# This ensures coverage.combine properly merges data from different test runs +source = + vllm + /vllm-workspace/src/vllm + /vllm-workspace/vllm + */site-packages/vllm + */dist-packages/vllm + +[report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + if settings.DEBUG + raise AssertionError + raise NotImplementedError + if 0: + if __name__ == .__main__.: + class .*\bProtocol\): + @(abc\.)?abstractmethod + +[html] +directory = htmlcov + +[xml] +output = coverage.xml diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..3863656915d035c3831614a5fcba05e09699542f --- /dev/null +++ b/.dockerignore @@ -0,0 +1,33 @@ +/.venv +/build +dist +vllm/*.so + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +.mypy_cache + +# Distribution / packaging +.Python +/build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +/dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000000000000000000000000000000000..5a601d00cef8b9720cd6b078da6d1a14c5fff072 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,4 @@ +# Migrate from `yapf` & `isort` to `ruff` +d6953beb91da4e9c99be4c0a1304a2d24189535c +# Convert `Optional[x]` to `x | None` and `Union[x, y]` to `x | y` +8fcaaf6a165e661f63fc51be906bc05b0767332f diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..795071bd77f737e977fd790e0cadd0c39e174b86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,245 @@ +# version file generated by setuptools-scm +/vllm/_version.py + +# vllm-flash-attn built from source +vllm/vllm_flash_attn/* +!vllm/vllm_flash_attn/__init__.py +!vllm/vllm_flash_attn/flash_attn_interface.py + +# OpenAI triton kernels copied from source +vllm/third_party/triton_kernels/* + +# FlashMLA interface copied from source +vllm/third_party/flashmla/flash_mla_interface.py + +# triton jit +.triton + +# Byte-compiled / 
optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +/.deps/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# generated files +**/generated/** + +# uv +uv.lock + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site +docs/argparse +docs/examples/* +!docs/examples/README.md + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ + +# VSCode +.vscode/ + +# Claude +CLAUDE.md +.claude/ + +# Codex +AGENTS.md +.codex/ + +# Cursor +.cursor/ + +# DS Store +.DS_Store + +# Results +*.csv + +# Python pickle files +*.pkl + +# Sphinx documentation +_build/ + +# vim swap files +*.swo +*.swp + +# hip files generated by PyTorch +*.hip +*_hip* +hip_compat.h + +# Benchmark dataset +benchmarks/**/*.json + +# Linting +actionlint +shellcheck*/ + +# Ignore moe/marlin_moe gen code +csrc/moe/marlin_moe_wna16/kernel_* + +# Ignore ep_kernels_workspace folder +ep_kernels_workspace/ + +# Allow tracked library source folders under submodules (e.g., benchmarks/lib) +!vllm/benchmarks/lib/ + +# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto) +vllm/grpc/vllm_engine_pb2.py +vllm/grpc/vllm_engine_pb2_grpc.py +vllm/grpc/vllm_engine_pb2.pyi + +# Ignore generated cpu headers +csrc/cpu/cpu_attn_dispatch_generated.h diff --git a/.markdownlint.yaml b/.markdownlint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..937487f47364d04a8f8cb02a6bd0b8ac467362ec --- /dev/null +++ b/.markdownlint.yaml @@ -0,0 +1,11 @@ +MD007: + indent: 4 +MD013: false +MD024: + siblings_only: true +MD031: + list_items: false +MD033: false +MD046: false +MD052: false +MD059: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33460222ec10daa0b76f4500a813c63da399cdd7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,158 @@ +default_install_hook_types: + - pre-commit + - commit-msg +default_stages: + - pre-commit # Run locally + - manual # Run in CI +exclude: 'vllm/third_party/.*' +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 + hooks: + - id: ruff-check + args: [--output-format, github, --fix] + - id: ruff-format +- repo: https://github.com/crate-ci/typos + rev: v1.38.1 + hooks: + - id: typos + args: [--force-exclude] +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v21.1.2 + hooks: + - id: clang-format + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' + types_or: [c++, cuda] + args: [--style=file, --verbose] +- repo: https://github.com/igorshubovych/markdownlint-cli + rev: v0.45.0 + hooks: + - id: markdownlint + exclude: '.*\.inc\.md' + stages: [manual] # Only run in CI +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint +- repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.9.1 + hooks: + - id: pip-compile + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28, --python-version, "3.12"] + files: ^requirements/test\.(in|txt)$ +- repo: local + hooks: + - id: format-torch-nightly-test + name: reformat nightly_torch_test.txt to be in sync with test.in + language: python + entry: python tools/pre_commit/generate_nightly_torch_test.py + files: ^requirements/test\.(in|txt)$ + - id: mypy-local + name: Run mypy locally for lowest supported Python version + entry: python tools/pre_commit/mypy.py 0 "3.10" + stages: [pre-commit] # Don't run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] + - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is 
less awkward + name: Run mypy for Python 3.10 + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.11 + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.12 + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.13 + entry: python tools/pre_commit/mypy.py 1 "3.13" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: shellcheck + name: Lint shell scripts + entry: tools/pre_commit/shellcheck.sh + language: script + types: [shell] + - id: png-lint + name: Lint PNG exports from excalidraw + entry: tools/pre_commit/png-lint.sh + language: script + types: [png] + - id: signoff-commit + name: Sign-off Commit + entry: bash + args: + - -c + - | + if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" + fi + language: system + verbose: true + stages: [commit-msg] + - id: check-spdx-header + name: Check SPDX headers + entry: python tools/pre_commit/check_spdx_header.py + language: python + types: [python] + - id: check-root-lazy-imports + name: Check root lazy imports + entry: python tools/pre_commit/check_init_lazy_imports.py + language: python + types: [python] + - id: check-filenames + name: Check for spaces in all filenames + entry: bash + args: + - -c + - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' + language: system + always_run: true + pass_filenames: false + - id: update-dockerfile-graph + name: Update Dockerfile dependency graph + entry: tools/pre_commit/update-dockerfile-graph.sh + language: script + - id: check-forbidden-imports + name: Check for forbidden imports + entry: python tools/pre_commit/check_forbidden_imports.py + language: python + types: [python] + additional_dependencies: [regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/pre_commit/validate_config.py + language: python + additional_dependencies: [regex] + - id: validate-docker-versions + name: Validate docker/versions.json matches Dockerfile + entry: python tools/generate_versions_json.py --check + language: python + files: ^docker/(Dockerfile|versions\.json)$ + pass_filenames: false + additional_dependencies: [dockerfile-parse] + - id: attention-backend-docs + name: Check attention backend documentation is up to date + entry: python tools/pre_commit/generate_attention_backend_docs.py --check + language: python + - id: check-boolean-context-manager + name: Check for boolean ops in with-statements + entry: python tools/pre_commit/check_boolean_context_manager.py + language: python + types: [python] + # Keep `suggestion` last + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. 
To skip a specific hook, prefix the commit command with SKIP=."' + language: system + verbose: true + pass_filenames: false + # Insert new entries above the `suggestion` entry diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f372a3fb8cc9c7bc745630dd454379cddbe8d9ed --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,22 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.12" + jobs: + post_checkout: + - git fetch origin main --unshallow --no-tags --filter=blob:none || true + pre_create_environment: + - pip install uv + create_environment: + - uv venv $READTHEDOCS_VIRTUALENV_PATH + install: + - uv pip install --python $READTHEDOCS_VIRTUALENV_PATH/bin/python --no-cache-dir -r requirements/docs.txt + +mkdocs: + configuration: mkdocs.yaml + fail_on_warning: true diff --git a/.shellcheckrc b/.shellcheckrc new file mode 100644 index 0000000000000000000000000000000000000000..f3b6eedf8d907ca8cefdc6266fe2ee04130cf564 --- /dev/null +++ b/.shellcheckrc @@ -0,0 +1,9 @@ +# rules currently disabled: +# +# SC1091 (info): Not following: was not specified as input (see shellcheck -x) +# SC2004 (style): $/${} is unnecessary on arithmetic variables. +# SC2129 (style): Consider using { cmd1; cmd2; } >> file instead of individual redirects. +# SC2155 (warning): Declare and assign separately to avoid masking return values. +# SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails. +# +disable=SC1091,SC2004,SC2129,SC2155,SC2164 diff --git a/.yapfignore b/.yapfignore new file mode 100644 index 0000000000000000000000000000000000000000..38158259032a69d0c44cd0e34d23fca8948a5a33 --- /dev/null +++ b/.yapfignore @@ -0,0 +1,2 @@ +collect_env.py +vllm/model_executor/layers/fla/ops/*.py diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..65df275cd3148d1236d0053aeef2b4affc2a4e9f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,1199 @@ +cmake_minimum_required(VERSION 3.26) + +# When building directly using CMake, make sure you run the install step +# (it places the .so files in the correct location). +# +# Example: +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. .. +# cmake --build . --target install +# +# If you want to only build one target, make sure to install it manually: +# cmake --build . --target _C +# cmake --install . --component _C +project(vllm_extensions LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + + +# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) +set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") +message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") + +include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + +# Suppress potential warnings about unused manually-specified variables +set(ignoreMe "${VLLM_PYTHON_PATH}") + +# Prevent installation of dependencies (cutlass) by default. +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + +# +# Supported python versions. These versions will be searched in order, the +# first match will be selected. These should be kept in sync with setup.py. +# +set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") + +# Supported AMD GPU architectures. 
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") + +# ROCm installation prefix. Default to /opt/rocm but allow override via +# -DROCM_PATH=/your/rocm/path when invoking cmake. +if(NOT DEFINED ROCM_PATH) + set(ROCM_PATH "/opt/rocm" CACHE PATH "ROCm installation prefix") +else() + set(ROCM_PATH ${ROCM_PATH} CACHE PATH "ROCm installation prefix" FORCE) +endif() +# +# Supported/expected torch versions for CUDA/ROCm. +# +# Currently, having an incorrect pytorch version results in a warning +# rather than an error. +# +# Note: the CUDA torch version is derived from pyproject.toml and various +# requirements.txt files and should be kept consistent. The ROCm torch +# versions are derived from docker/Dockerfile.rocm +# +set(TORCH_SUPPORTED_VERSION_CUDA "2.10.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.10.0") + +# +# Try to find python package with an executable that exactly matches +# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions. +# +if (VLLM_PYTHON_EXECUTABLE) + find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}") +else() + message(FATAL_ERROR + "Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version" + " before running cmake configure.") +endif() + +# +# Update cmake's `CMAKE_PREFIX_PATH` with torch location. +# +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +# Ensure the 'nvcc' command is in the PATH +find_program(NVCC_EXECUTABLE nvcc) +if (CUDA_FOUND AND NOT NVCC_EXECUTABLE) + message(FATAL_ERROR "nvcc not found") +endif() + +# +# Import torch cmake configuration. +# Torch also imports CUDA (and partially HIP) languages with some customizations, +# so there is no need to do this explicitly with check_language/enable_language, +# etc. +# +find_package(Torch REQUIRED) + +# Supported NVIDIA architectures. +# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined +if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0") +elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") +else() + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") +endif() + +# +# Forward the non-CUDA device extensions to external CMake scripts. +# +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND + NOT VLLM_TARGET_DEVICE STREQUAL "rocm") + if (VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) + else() + return() + endif() + return() +endif() + +# +# Set up GPU language and check the torch version and warn if it isn't +# what is expected. +# +if (NOT HIP_FOUND AND CUDA_FOUND) + set(VLLM_GPU_LANG "CUDA") + + if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} " + "expected for CUDA build, saw ${Torch_VERSION} instead.") + endif() +elseif(HIP_FOUND) + set(VLLM_GPU_LANG "HIP") + + # Importing torch recognizes and sets up some HIP/ROCm configuration but does + # not let cmake recognize .hip files. In order to get cmake to understand the + # .hip extension automatically, HIP must be enabled explicitly. 
+  enable_language(HIP)
+
+  # ROCm 5.X and 6.X
+  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+      Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM})
+    message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
+      "expected for ROCm build, saw ${Torch_VERSION} instead.")
+  endif()
+else()
+  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
+endif()
+
+
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  #
+  # For cuda we want to be able to control which architectures we compile for on
+  # a per-file basis in order to cut down on compile time. So here we extract
+  # the set of architectures we want to compile for and remove them from the
+  # CMAKE_CUDA_FLAGS so that they are not applied globally.
+  #
+  clear_cuda_arches(CUDA_ARCH_FLAGS)
+  extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
+  message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
+  # Filter the target architectures by the supported archs
+  # since for some files we will build for all CUDA_ARCHS.
+  cuda_archs_loose_intersection(CUDA_ARCHS
+    "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
+  message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
+else()
+  #
+  # For other GPU targets override the GPU architectures detected by cmake/torch
+  # and filter them by the supported versions for the current language.
+  # The final set of arches is stored in `VLLM_GPU_ARCHES`.
+  #
+  override_gpu_arches(VLLM_GPU_ARCHES
+    ${VLLM_GPU_LANG}
+    "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
+endif()
+
+#
+# Query torch for additional GPU compilation flags for the given
+# `VLLM_GPU_LANG`.
+# The final set of arches is stored in `VLLM_GPU_FLAGS`.
+#
+get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})
+
+#
+# Set nvcc parallelism.
+#
+if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
+  list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
+endif()
+
+#
+# Set compression mode for CUDA >=13.x.
+#
+if(VLLM_GPU_LANG STREQUAL "CUDA" AND
+   DEFINED CMAKE_CUDA_COMPILER_VERSION AND
+   CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+  list(APPEND VLLM_GPU_FLAGS "--compress-mode=size")
+endif()
+
+#
+# Set CUDA include flags for CXX compiler.
+#
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include")
+  if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl")
+  endif()
+endif()
+
+#
+# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
+# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
+# Each dependency that produces build artifacts should override its BINARY_DIR to avoid
+# conflicts between build types. It should instead be set to ${CMAKE_BINARY_DIR}/.
+#
+include(FetchContent)
+file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
+message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
+
+if(VLLM_GPU_LANG STREQUAL "HIP")
+  #
+  # Overriding the default -O set up by cmake, adding ggdb3 for the most verbose debug info
+  #
+  set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3")
+  set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3")
+
+  #
+  # Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates
+  # a lot of warnings that always mask real issues. Suppressing until this is properly addressed.
+ # + set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") +endif() + +# +# Define other extension targets +# + +# +# cumem_allocator extension +# + +set(VLLM_CUMEM_EXT_SRC + "csrc/cumem_allocator.cpp") + +set_gencode_flags_for_srcs( + SRCS "${VLLM_CUMEM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") + message(STATUS "Enabling cumem allocator extension.") + if(VLLM_GPU_LANG STREQUAL "CUDA") + # link against cuda driver library + list(APPEND CUMEM_LIBS CUDA::cuda_driver) + else() + # link against rocm driver library. Prefer an absolute path to + # libamdhip64.so inside ${ROCM_PATH}/lib if available, otherwise fall + # back to linking by name "amdhip64". + find_library(AMDHIP64_LIB + NAMES amdhip64 libamdhip64.so + PATHS ${ROCM_PATH}/lib + NO_DEFAULT_PATH) + if(AMDHIP64_LIB) + message(STATUS "Found libamdhip64 at ${AMDHIP64_LIB}") + list(APPEND CUMEM_LIBS ${AMDHIP64_LIB}) + else() + message(WARNING "libamdhip64 not found in ${ROCM_PATH}/lib; falling back to linking 'amdhip64' by name") + list(APPEND CUMEM_LIBS amdhip64) + endif() + endif() + define_extension_target( + cumem_allocator + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_CUMEM_EXT_SRC} + LIBRARIES ${CUMEM_LIBS} + USE_SABI 3.8 + WITH_SOABI) +endif() + +# +# _C extension +# + +set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/cache_kernels.cu" + "csrc/cache_kernels_fused.cu" + "csrc/attention/paged_attention_v1.cu" + "csrc/attention/paged_attention_v2.cu" + "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" + "csrc/layernorm_quant_kernels.cu" + "csrc/sampler.cu" + "csrc/topk.cu" + "csrc/cuda_view.cu" + "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" + "csrc/cuda_utils_kernels.cu" + "csrc/custom_all_reduce.cu" + "csrc/torch_bindings.cpp") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "v4.2.1") + + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) + endif() + + if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) + else() + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # Please keep this in sync with CUTLASS_REVISION line above. + GIT_TAG ${CUTLASS_REVISION} + GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
+ # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE + ) + endif() + FetchContent_MakeAvailable(cutlass) + + list(APPEND VLLM_EXT_SRC + "csrc/quantization/awq/gemm_kernels.cu" + "csrc/permute_cols.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" + "csrc/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/cutlass_extensions/common.cpp" + "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/quantization/w8a8/int8/per_token_group_quant.cu") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + + # Only build Marlin kernels if we are building for at least some compatible archs. + # Keep building Marlin for 9.0 as there are some group sizes and shapes that + # are not supported by Machete yet. + + # marlin arches for fp16 output + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # marlin has limited support for turing + cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}") + # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) + cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + # marlin arches for other files + cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") + + if (MARLIN_OTHER_ARCHS) + + # + # For the Marlin kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/marlin/generate_kernels.py) + file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") + + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} + RESULT_VARIABLE marlin_generation_result + OUTPUT_VARIABLE marlin_generation_result + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ) + + if (NOT marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin generation failed." 
+ " Result: \"${marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") + else() + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin generate script hash and arch" FORCE) + message(STATUS "Marlin generation completed successfully.") + endif() + else() + message(STATUS "Marlin generation script has not changed, skipping generation.") + endif() + + if (MARLIN_ARCHS) + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_float16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_bfloat16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_BF16_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) + endif() + + if (MARLIN_SM75_ARCHS) + file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/marlin/sm75_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_SM75_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC}) + endif() + + if (MARLIN_FP8_ARCHS) + file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/marlin/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC}) + endif() + + set(MARLIN_SRCS + "csrc/quantization/marlin/marlin.cu" + "csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu" + "csrc/quantization/marlin/gptq_marlin_repack.cu" + "csrc/quantization/marlin/awq_marlin_repack.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_SRCS}" + CUDA_ARCHS "${MARLIN_OTHER_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_SRCS} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + + message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}") + else() + message(STATUS "Not building Marlin kernels as no compatible archs found" + " in CUDA target architectures") + endif() + + # Only build AllSpark kernels if we are building for at least some compatible archs. 
+ cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") + if (ALLSPARK_ARCHS) + set(ALLSPARK_SRCS + "csrc/quantization/gptq_allspark/allspark_repack.cu" + "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${ALLSPARK_SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") + else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") + endif() + + + set(SCALED_MM_3X_ARCHS) + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + + # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. 
CUTLASS 3.x) require + # CUDA 12.8 or later + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + + # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) + # require CUDA 12.8 or later + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # + # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) + # kernels for the remaining archs that are not already built for 3x. 
+ # (Build 8.9 for FP8) + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}") + # subtract out the archs that are already built for 3x + list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) + if (SCALED_MM_2X_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") + message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") + else() + if (SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c2x as all archs are already built" + " for and covered by scaled_mm_c3x") + else() + message(STATUS "Not building scaled_mm_c2x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # + # 2:4 Sparse Kernels + + # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor + # require CUDA 12.2 or later (and only work on Hopper). + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) + set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") + message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) + message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " + "if you intend on running FP8 sparse quantized models on Hopper.") + else() + message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require + # CUDA 12.8 or later + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() + + # FP4 Archs and flags + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" + 
"csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() + + # CUTLASS MLA Archs and flags + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") + else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) + endif() + + # CUTLASS MoE kernels + + # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works + # on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled + # if it's possible to compile MoE kernels that use its output. 
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). 
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS) + set(SRCS + "csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" + "csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1") + message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 + AND ES_MXFP8_GROUPED_MM_ARCHS) + message(STATUS "Not building ES MXFP8 grouped kernels as CUDA Compiler version is " + "not >= 12.8.") + else() + message(STATUS "Not building ES MXFP8 grouped kernels as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) + set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu") + set_gencode_flags_for_srcs( + SRCS "${DSV3_FUSED_A_GEMM_SRC}" + CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") + list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) + message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") + else() + message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " + "in CUDA target architectures.") + endif() + + # moe_data.cu is used by all CUTLASS MoE kernels. + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + message(STATUS "Not building moe_data as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") + else() + message(STATUS "Not building moe_data as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + + # + # Machete kernels + + # The machete kernels only work on hopper and require CUDA 12.0 or later. 
+ # Only build Machete kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS) + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MACHETE_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) + file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) + + message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") + message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} + OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." + " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} + CACHE STRING "Last run machete generate script hash" FORCE) + message(STATUS "Machete generation completed successfully.") + endif() + else() + message(STATUS "Machete generation script has not changed, skipping generation.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + + # forward compatible + set_gencode_flags_for_srcs( + SRCS "${MACHETE_GEN_SOURCES}" + CUDA_ARCHS "${MACHETE_ARCHS}") + + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) + + message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND MACHETE_ARCHS) + message(STATUS "Not building Machete kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building Machete kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" + ) + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on 
running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + + # Hadacore kernels + cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + if(HADACORE_ARCHS) + set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${HADACORE_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building hadacore") + endif() + +# if CUDA endif +endif() + +if (VLLM_GPU_LANG STREQUAL "HIP") + # Add QuickReduce kernels + list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" + ) +# if ROCM endif +endif() + +message(STATUS "Enabling C extension.") +define_extension_target( + _C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + +# If CUTLASS is compiled on NVCC >= 12.5, it by default uses +# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the +# driver API. This causes problems when linking with earlier versions of CUDA. +# Setting this variable sidesteps the issue by calling the driver directly. +target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + +# +# _moe_C extension +# + +set(VLLM_MOE_EXT_SRC + "csrc/moe/torch_bindings.cpp" + "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/topk_softmax_kernels.cu") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu" + "csrc/moe/router_gemm.cu") +endif() + +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(MOE_PERMUTE_SRC + "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" + "csrc/moe/moe_permute_unpermute_op.cu") + + list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") +endif() + +set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(VLLM_MOE_WNA16_SRC + "csrc/moe/moe_wna16.cu") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_WNA16_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") + # moe marlin arches + # note that we always set `use_atomic_add=False` for moe marlin now, + # so we don't need 9.0 for bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # moe marlin has limited support for turing + cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}") + # moe marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + # moe marlin arches for other files + cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") + if (MARLIN_MOE_OTHER_ARCHS) + + # + # For the Marlin MOE kernels we automatically generate sources for various + # preselected input type pairs and schedules. 
+ # Generate sources: + set(MOE_MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) + file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") + + message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=$ENV{PYTHONPATH} + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} + RESULT_VARIABLE moe_marlin_generation_result + OUTPUT_VARIABLE moe_marlin_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ) + + if (NOT moe_marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin MOE generation failed." + " Result: \"${moe_marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") + else() + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin MOE generate script hash" FORCE) + message(STATUS "Marlin MOE generation completed successfully.") + endif() + else() + message(STATUS "Marlin MOE generation script has not changed, skipping generation.") + endif() + + if (MARLIN_MOE_ARCHS) + file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SRC}" + CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) + endif() + + if (MARLIN_MOE_SM75_ARCHS) + file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SM75_SRC}" + CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_SM75_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC}) + endif() + + if (MARLIN_MOE_FP8_ARCHS) + file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_FP8_SRC}" + CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_FP8_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) + endif() + + set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_OTHER_SRC}" + CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_OTHER_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}") + + message(STATUS "Building Marlin MOE kernels for archs: 
${MARLIN_MOE_OTHER_ARCHS}") + else() + message(STATUS "Not building Marlin MOE kernels as no compatible archs found" + " in CUDA target architectures") + endif() + + # DeepSeek V3 router GEMM kernel - requires SM90+ + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS) + set(DSV3_ROUTER_GEMM_SRC + "csrc/moe/dsv3_router_gemm_entry.cu" + "csrc/moe/dsv3_router_gemm_float_out.cu" + "csrc/moe/dsv3_router_gemm_bf16_out.cu") + set_gencode_flags_for_srcs( + SRCS "${DSV3_ROUTER_GEMM_SRC}" + CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}") + list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}") + message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}") + else() + message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found" + " (requires SM90+ and CUDA >= 12.0)") + endif() +endif() + +message(STATUS "Enabling moe extension.") +define_extension_target( + _moe_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_MOE_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI) + +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/skinny_gemms.cu" + "csrc/rocm/attention.cu") + + define_extension_target( + _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() + +# For CUDA and HIP builds also build the triton_kernels external package. +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") + include(cmake/external_projects/triton_kernels.cmake) +endif() + +# For CUDA we also build and ship some external projects. +if (VLLM_GPU_LANG STREQUAL "CUDA") + include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/qutlass.cmake) + + # vllm-flash-attn should be last as it overwrites some CMake functions + include(cmake/external_projects/vllm_flash_attn.cmake) +endif () diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..5268ff135c9d0d5b064dbe30aaa577e49071e33b --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,127 @@ + +# vLLM Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socioeconomic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official email address, +posting via an official social media account, or acting as an appointed +representative at an online or offline/IRL event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement in the #code-of-conduct +channel in the [vLLM Slack](https://slack.vllm.ai). +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.1, available at +[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). + +For answers to common questions about this code of conduct, see the +[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at +[Contributor Covenant translations](https://www.contributor-covenant.org/translations). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..2947aad75ee5613b4bdf8019b581620b8073149c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,3 @@ +# Contributing to vLLM + +You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing). diff --git a/DCO b/DCO new file mode 100644 index 0000000000000000000000000000000000000000..49b8cb0549267a8176467738b172a63d86eff436 --- /dev/null +++ b/DCO @@ -0,0 +1,34 @@ +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. 
diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..fb3cccbb4a9c156bc3aa0b08c8333e1d5340dcda --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,9 @@ +include LICENSE +include requirements/common.txt +include requirements/cuda.txt +include requirements/rocm.txt +include requirements/cpu.txt +include CMakeLists.txt + +recursive-include cmake * +recursive-include csrc * diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000000000000000000000000000000000..dfd4fa1ae04d499663b4b315a9fc4988408cbfc9 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,73 @@ +# Releasing vLLM + +vLLM releases offer a reliable version of the code base, packaged into a binary format that can be conveniently accessed via [PyPI](https://pypi.org/project/vllm). These releases also serve as key milestones for the development team to communicate with the community about newly available features, improvements, and upcoming changes that could affect users, including potential breaking changes. + +## Release Cadence and Versioning + +We aim to have a regular release every 2 weeks. Since v0.12.0, regular releases increment the minor version rather than patch version. The list of past releases can be found [here](https://vllm.ai/releases). + +Our version numbers are expressed in the form `vX.Y.Z`, where `X` is the major version, `Y` is the minor version, and `Z` is the patch version. They are incremented according to the following rules: + +* _Major_ releases are reserved for architectural milestones involving sweeping API changes, similar to PyTorch 2.0. +* _Minor_ releases correspond to regular releases, which include new features, bug fixes and other backwards-compatible changes. +* _Patch_ releases correspond to special releases for new models, as well as emergency patches for critical performance, functionality and security issues. + +This versioning scheme is similar to [SemVer](https://semver.org/) for compatibility purposes, except that backwards compatibility is only guaranteed for a limited number of minor releases (see our [deprecation policy](https://docs.vllm.ai/en/latest/contributing/deprecation_policy) for details). + +## Release Branch + +Each release is built from a dedicated release branch. + +* For _major_ and _minor_ releases, the release branch cut is performed 1-2 days before release is live. +* For _patch_ releases, previously cut release branch is reused. +* Release builds are triggered via push to RC tag like `vX.Y.Z-rc1`. This enables us to build and test multiple RCs for each release. +* Final tag: `vX.Y.Z` does not trigger the build but used for Release notes and assets. +* After branch cut is created, we monitor the main branch for any reverts and apply these reverts to a release branch. + +### Cherry-Pick Criteria + +After branch cut, we approach finalizing the release branch with clear criteria on what cherry picks are allowed in. Note: a cherry pick is a process to land a PR in the release branch after branch cut. 
These are typically limited to ensure that the team has sufficient time to complete a thorough round of testing on a stable code base. + +* Regression fixes - that address functional/performance regression against the most recent release (e.g. 0.7.0 for 0.7.1 release) +* Critical fixes - critical fixes for severe issue such as silent incorrectness, backwards compatibility, crashes, deadlocks, (large) memory leaks +* Fixes to new features introduced in the most recent release (e.g. 0.7.0 for 0.7.1 release) +* Documentation improvements +* Release branch specific changes (e.g. change version identifiers or CI fixes) + +Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes. + +## Manual validations + +### E2E Performance Validation + +Before each release, we perform end-to-end performance validation to ensure no regressions are introduced. This validation uses the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) on PyTorch CI. + +**Current Coverage:** + +* Models: Llama3, Llama4, and Mixtral +* Hardware: NVIDIA H100 and AMD MI300x +* _Note: Coverage may change based on new model releases and hardware availability_ + +**Performance Validation Process:** + +**Step 1: Get Access** +Request write access to the [pytorch/pytorch-integration-testing](https://github.com/pytorch/pytorch-integration-testing) repository to run the benchmark workflow. + +**Step 2: Review Benchmark Setup** +Familiarize yourself with the benchmark configurations: + +* [CUDA setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/cuda) +* [ROCm setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/rocm) + +**Step 3: Run the Benchmark** +Navigate to the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) and configure: + +* **vLLM branch**: Set to the release branch (e.g., `releases/v0.9.2`) +* **vLLM commit**: Set to the RC commit hash + +**Step 4: Review Results** +Once the workflow completes, benchmark results will be available on the [vLLM benchmark dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) under the corresponding branch and commit. + +**Step 5: Performance Comparison** +Compare the current results against the previous release to verify no performance regressions have occurred. Here is an +example of [v0.9.1 vs v0.9.2](https://hud.pytorch.org/benchmark/llms?startTime=Thu%2C%2017%20Apr%202025%2021%3A43%3A50%20GMT&stopTime=Wed%2C%2016%20Jul%202025%2021%3A43%3A50%20GMT&granularity=week&lBranch=releases/v0.9.1&lCommit=b6553be1bc75f046b00046a4ad7576364d03c835&rBranch=releases/v0.9.2&rCommit=a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f&repoName=vllm-project%2Fvllm&benchmarkName=&modelName=All%20Models&backendName=All%20Backends&modeName=All%20Modes&dtypeName=All%20DType&deviceName=All%20Devices&archName=All%20Platforms). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..d6319cdb1ac27215cd0a78ed47a408867e3ef434 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,50 @@ +# Security Policy + +## Reporting security issues + +Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). 
+ +## Issue triage + +Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). + +## Threat model + +Please see the [Security Guide in the vLLM documentation](https://docs.vllm.ai/en/latest/usage/security.html) for more information on vLLM's security assumptions and recommendations. + +Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. + +## Issue severity + +We will determine the risk of each issue, taking into account our experience dealing with past issues, versions affected, common defaults, and use cases. We use the following severity categories: + +### CRITICAL Severity + +Vulnerabilities that allow remote attackers to execute arbitrary code, take full control of the system, or significantly compromise confidentiality, integrity, or availability without any interaction or privileges needed, examples include remote code execution via network, deserialization issues that allow exploit chains. Generally those issues which are rated as CVSS ≥ 9.0. + +### HIGH Severity + +Serious security flaws that allow elevated impact—like RCE in specific, limited contexts or significant data loss—but require advanced conditions or some trust, examples include RCE in advanced deployment modes (e.g. multi-node), or high impact issues where some sort of privileged network access is required. These issues typically have CVSS scores between 7.0 and 8.9 + +### MODERATE Severity + +Vulnerabilities that cause denial of service or partial disruption, but do not allow arbitrary code execution or data breach and have limited impact. These issues have a CVSS rating between 4.0 and 6.9 + +### LOW Severity + +Minor issues such as informational disclosures, logging errors, non-exploitable flaws, or weaknesses that require local or high-privilege access and offer negligible impact. Examples include side channel attacks or hash collisions. These issues often have CVSS scores less than 4.0 + +## Prenotification policy + +For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we may prenotify certain organizations or vendors that ship vLLM. The purpose of this prenotification is to allow for a coordinated release of fixes for severe issues. + +* This prenotification will be in the form of a private email notification. It may also include adding security contacts to the GitHub security advisory, typically a few days before release. + +* If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis. + +* Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications + * Substantial internal deployment leveraging the upstream vLLM project. + * Established internal security teams and comprehensive compliance measures. + * Active and consistent contributions to the upstream vLLM project. + +* We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included. 
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..26896a77bf3c75b4f2396e9f37540dc5367c2908
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,20 @@
+# Benchmarks
+
+This directory used to contain vLLM's benchmark scripts and utilities for performance testing and evaluation.
+
+## Contents
+
+- **Serving benchmarks**: Scripts for testing online inference performance (latency, throughput)
+- **Throughput benchmarks**: Scripts for testing offline batch inference performance
+- **Specialized benchmarks**: Tools for testing specific features like structured output, prefix caching, long document QA, request prioritization, and multi-modal inference
+- **Dataset utilities**: Framework for loading and sampling from various benchmark datasets (ShareGPT, HuggingFace datasets, synthetic data, etc.)
+
+## Usage
+
+For detailed usage instructions, examples, and dataset information, see the [Benchmark CLI documentation](https://docs.vllm.ai/en/latest/benchmarking/cli/#benchmark-cli).
+
+For full CLI reference see:
+
+-
+-
+-
diff --git a/benchmarks/attention_benchmarks/README.md b/benchmarks/attention_benchmarks/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..788ce94f23fb8e275cdc931c451af63b1b52c704
--- /dev/null
+++ b/benchmarks/attention_benchmarks/README.md
@@ -0,0 +1,266 @@
+# vLLM Attention Benchmarking Suite
+
+Fast, flexible benchmarking for vLLM attention and MLA backends with an extended batch specification grammar.
+
+## Quick Start
+
+```bash
+cd benchmarks/attention_benchmarks
+
+# Run a pre-configured benchmark
+python benchmark.py --config configs/mla_decode.yaml
+python benchmark.py --config configs/mla_mixed_batch.yaml
+python benchmark.py --config configs/speculative_decode.yaml
+python benchmark.py --config configs/standard_attention.yaml
+python benchmark.py --config configs/reorder_threshold.yaml
+
+# Or run custom benchmarks
+python benchmark.py \
+    --backends flash flashinfer \
+    --batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \
+    --output-csv results.csv
+```
+
+## Simplified Batch Specification Grammar
+
+Express workloads concisely using query length and sequence length:
+
+```python
+"q2k"           # 2048-token prefill (q_len=2048, seq_len=2048)
+"q1s1k"         # Decode: 1 token with 1K sequence
+"8q1s1k"        # 8 decode requests
+"q4s1k"         # 4-token extend (e.g., spec decode)
+"2q2k_32q1s1k"  # Mixed: 2 prefills + 32 decodes
+"16q4s1k"       # 16 spec decode (4 tokens each)
+```
+
+### Grammar Rule
+
+```text
+Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
+
+- count: Number of identical requests (optional, default=1)
+- q_len: Query length (number of new tokens)
+- seq_len: Total sequence length (optional, defaults to q_len for prefill)
+- 'k': Multiplies value by 1024
+
+Mixed batches: Use _ to combine (e.g., "2q2k_32q1s1k")
+```
+
+**Note**: Decode, prefill, and spec decode are just different query lengths - no special syntax needed!
+
+## Pre-configured Benchmarks
+
+The suite includes several pre-configured YAML benchmark configurations:
+
+### MLA Decode Benchmark
+
+Tests pure decode performance across MLA backends with varying batch sizes and sequence lengths.
+
+```bash
+python benchmark.py --config configs/mla_decode.yaml
+```
+
+### MLA Mixed Batch Benchmark
+
+Tests chunked prefill performance with mixed prefill + decode batches.
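+
+As a rough illustration of the kind of mixed batch this configuration exercises, the grammar above can be expanded with the `parse_batch_spec` and `get_batch_stats` helpers from `batch_spec.py` in this suite. The spec string below is only an example, not necessarily what `configs/mla_mixed_batch.yaml` contains:
+
+```python
+from batch_spec import parse_batch_spec, get_batch_stats
+
+# Example mixed batch: 2 chunked prefills of 2048 tokens each plus
+# 32 single-token decodes against 1K contexts.
+requests = parse_batch_spec("2q2k_32q1s1k")
+stats = get_batch_stats(requests)
+
+print(stats["num_prefill"], stats["num_decode"])  # 2 32
+print(stats["total_tokens"])  # 2 * 2048 + 32 * 1 = 4128
+```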
+ +```bash +python benchmark.py --config configs/mla_mixed_batch.yaml +``` + +### Speculative Decoding Benchmark + +Tests speculative decode scenarios (K-token verification) and reorder_batch_threshold optimization. + +```bash +python benchmark.py --config configs/speculative_decode.yaml +``` + +### Standard Attention Benchmark + +Tests standard attention backends (Flash/Triton/FlashInfer) with pure prefill, decode, and mixed batches. + +```bash +python benchmark.py --config configs/standard_attention.yaml +``` + +### Reorder Threshold Study + +**Question:** At what query length does the prefill pipeline become faster than the decode pipeline? + +Tests query lengths from 1-1024 across 9 batch sizes to find the crossover point. Uses `decode_vs_prefill` mode to compare both pipelines for each query length. + +```bash +python benchmark.py --config configs/reorder_threshold.yaml +``` + +--- + +## Universal Benchmark + +The `benchmark.py` script handles **all** backends - both standard attention and MLA. + +### Standard Attention (Flash/Triton/FlashInfer) + +```bash +python benchmark.py \ + --backends flash triton flashinfer \ + --batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \ + --num-layers 10 \ + --repeats 5 \ + --output-csv results.csv +``` + +### MLA Backends + +```bash +# Compare all MLA backends +python benchmark.py \ + --backends cutlass_mla flashinfer_mla flashattn_mla flashmla \ + --batch-specs "64q1s1k" "64q1s4k" \ + --output-csv mla_results.csv +``` + +### Parameter Sweeps + +Use `--sweep-param` and `--sweep-values` to run parameter sweeps from the CLI: + +#### CUTLASS MLA num-splits Optimization + +**Question:** What is the optimal `num_kv_splits` for CUTLASS MLA? + +```bash +python benchmark.py \ + --backend cutlass_mla \ + --batch-specs "64q1s1k" "64q1s4k" "64q1s16k" \ + --sweep-param num_kv_splits \ + --sweep-values 1 2 4 8 16 \ + --output-json optimal_splits.json +``` + +#### Reorder Batch Threshold Optimization + +**Question:** What's the optimal `reorder_batch_threshold` for speculative decoding? + +```bash +python benchmark.py \ + --backend flashmla \ + --batch-specs "q4s1k" "q8s2k" \ + --sweep-param reorder_batch_threshold \ + --sweep-values 1 4 16 64 256 512 \ + --output-csv threshold_sweep.csv +``` + +### All Command-Line Options + +```text +--config CONFIG # Path to YAML config file (overrides other args) +--backends BACKEND [BACKEND ...] # flash, triton, flashinfer, cutlass_mla, + # flashinfer_mla, flashattn_mla, flashmla +--backend BACKEND # Single backend (alternative to --backends) +--batch-specs SPEC [SPEC ...] # Batch specifications using extended grammar + +# Model configuration +--num-layers N # Number of layers +--head-dim N # Head dimension +--num-q-heads N # Query heads +--num-kv-heads N # KV heads +--block-size N # Block size + +# Benchmark settings +--device DEVICE # Device (default: cuda:0) +--repeats N # Repetitions +--warmup-iters N # Warmup iterations +--profile-memory # Profile memory usage + +# Parameter sweeps +--sweep-param PARAM # Parameter name to sweep (e.g., num_kv_splits, + # reorder_batch_threshold) +--sweep-values N [N ...] 
# Values to sweep for the parameter + +# Output +--output-csv FILE # Save to CSV +--output-json FILE # Save to JSON +``` + +## Hardware Requirements + +| Backend | Hardware | +|---------|----------| +| Flash/Triton/FlashInfer | Any CUDA GPU | +| CUTLASS MLA | Blackwell (SM100+) | +| FlashAttn MLA | Hopper (SM90+) | +| FlashMLA | Hopper (SM90+) | +| FlashInfer-MLA | Any CUDA GPU | + +## Using MLA Runner Directly + +All MLA backends are available through `mla_runner.run_mla_benchmark()`: + +```python +from mla_runner import run_mla_benchmark +from common import BenchmarkConfig + +config = BenchmarkConfig( + backend="cutlass_mla", + batch_spec="64q1s4k", + num_layers=10, + head_dim=576, + num_q_heads=128, + num_kv_heads=1, + block_size=128, + device="cuda:0", + repeats=5, + warmup_iters=3, +) + +# CUTLASS MLA with specific num_kv_splits +result = run_mla_benchmark("cutlass_mla", config, num_kv_splits=4) +print(f"Time: {result.mean_time:.6f}s") + +# FlashInfer-MLA +result = run_mla_benchmark("flashinfer_mla", config) + +# FlashAttn MLA (Hopper SM90+) +result = run_mla_benchmark("flashattn_mla", config, reorder_batch_threshold=64) + +# FlashMLA (Hopper SM90+) +result = run_mla_benchmark("flashmla", config, reorder_batch_threshold=64) +``` + +## Python API + +```python +from batch_spec import parse_batch_spec, format_batch_spec, get_batch_stats +from common import BenchmarkConfig, BenchmarkResult, ResultsFormatter + +# Parse batch specs +requests = parse_batch_spec("2q2k_q4s1k_32q1s1k") +print(format_batch_spec(requests)) +# "2 prefill (2x2k), 1 extend (1xq4kv1k), 32 decode (32x1k)" + +# Get batch statistics +stats = get_batch_stats(requests) +print(f"Total tokens: {stats['total_tokens']}") +print(f"Num decode: {stats['num_decode']}, Num prefill: {stats['num_prefill']}") + +# Format results +formatter = ResultsFormatter() +formatter.save_csv(results, "output.csv") +formatter.save_json(results, "output.json") +``` + +## Tips + +**1. Warmup matters** - Use `--warmup-iters 10` for stable results + +**2. Multiple repeats** - Use `--repeats 20` for low variance + +**3. Save results** - Always use `--output-csv` or `--output-json` + +**4. Test incrementally** - Start with `--num-layers 1 --repeats 1` + +**5. Extended grammar** - Leverage spec decode, chunked prefill patterns + +**6. 
Parameter sweeps** - Use `--sweep-param` and `--sweep-values` to find optimal values
diff --git a/benchmarks/attention_benchmarks/__init__.py b/benchmarks/attention_benchmarks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d21288700a5997ae8d0c5569f95d43f3c02a3fd
--- /dev/null
+++ b/benchmarks/attention_benchmarks/__init__.py
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""vLLM Attention Benchmarking Suite."""
+
+from .batch_spec import (
+    BatchRequest,
+    format_batch_spec,
+    get_batch_stats,
+    parse_batch_spec,
+    reorder_for_flashinfer,
+    split_by_type,
+)
+from .common import (
+    BenchmarkConfig,
+    BenchmarkResult,
+    MockLayer,
+    ResultsFormatter,
+    get_attention_scale,
+    is_mla_backend,
+    setup_mla_dims,
+)
+
+__all__ = [
+    # Batch specification
+    "BatchRequest",
+    "parse_batch_spec",
+    "format_batch_spec",
+    "reorder_for_flashinfer",
+    "split_by_type",
+    "get_batch_stats",
+    # Benchmarking infrastructure
+    "BenchmarkConfig",
+    "BenchmarkResult",
+    "ResultsFormatter",
+    # Mock objects
+    "MockLayer",
+    # Utilities
+    "setup_mla_dims",
+    "get_attention_scale",
+    "is_mla_backend",
+]
diff --git a/benchmarks/attention_benchmarks/batch_spec.py b/benchmarks/attention_benchmarks/batch_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f15f1d8096e7b582db99f9e5537f7b4ac55c1b5
--- /dev/null
+++ b/benchmarks/attention_benchmarks/batch_spec.py
@@ -0,0 +1,268 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Simplified batch specification grammar for attention benchmarks.
+
+Grammar (underscore-separated segments):
+    Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
+
+    - count: Number of identical requests (optional, default=1)
+    - q_len: Query length (number of new tokens)
+    - seq_len: Total sequence length (optional, defaults to q_len for prefill)
+    - 'k' suffix: Multiplies value by 1024
+
+Common patterns:
+    - Prefill: q_len == seq_len (e.g., "q2k" → 2048 new tokens, 2048 seq)
+    - Decode: q_len == 1 (e.g., "q1s1k" → 1 token, 1024 seq length)
+    - Extend: q_len < seq_len (e.g., "q4s1k" → 4 tokens, 1024 seq length)
+
+Examples:
+    q2k          -> [(2048, 2048)]       # Prefill: 2048 tokens
+    q1s1k        -> [(1, 1024)]          # Decode: 1 token, 1K sequence
+    8q1s1k       -> [(1, 1024)] * 8      # 8 decode requests
+    q4s1k        -> [(4, 1024)]          # 4-token extend (spec decode)
+    2q1k_32q1s1k -> [(1024, 1024)] * 2 + [(1, 1024)] * 32  # Mixed batch
+    16q4s1k      -> [(4, 1024)] * 16     # 16 spec decode requests
+"""
+
+from collections import Counter
+from dataclasses import dataclass
+
+import regex as re
+
+
+@dataclass
+class BatchRequest:
+    """Represents a single request in a batch."""
+
+    q_len: int  # Query length (number of new tokens)
+    kv_len: int  # Total KV cache length
+
+    @property
+    def is_decode(self) -> bool:
+        """True if this is a decode request (q_len == 1)."""
+        return self.q_len == 1
+
+    @property
+    def is_prefill(self) -> bool:
+        """True if this is a pure prefill (q_len == kv_len)."""
+        return self.q_len == self.kv_len
+
+    @property
+    def is_extend(self) -> bool:
+        """True if this is context extension (q_len > 1, kv_len > q_len)."""
+        return self.q_len > 1 and self.kv_len > self.q_len
+
+    @property
+    def context_len(self) -> int:
+        """Context length (KV cache - query)."""
+        return self.kv_len - self.q_len
+
+    def as_tuple(self) -> tuple[int, int]:
+        """Return as (q_len, kv_len) tuple for compatibility."""
+        return (self.q_len, self.kv_len)
+
+
+def _parse_size(size_str: str, k_suffix: str) -> int:
+    """Parse size string with optional 'k' suffix."""
+    size = int(size_str)
+    return size * 1024 if k_suffix == "k" else size
+
+
+def parse_batch_spec(spec: str) -> list[BatchRequest]:
+    """
+    Parse batch specification string into list of BatchRequest objects.
+
+    Grammar: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
+
+    Args:
+        spec: Batch specification string (see module docstring for grammar)
+
+    Returns:
+        List of BatchRequest objects
+
+    Raises:
+        ValueError: If spec format is invalid
+    """
+    requests = []
+
+    for seg in spec.split("_"):
+        # Unified pattern: (<count>?) q<q_len>(k?) (s<seq_len>(k?))?
+        m = re.match(r"^(?:(\d+))?q(\d+)(k?)(?:s(\d+)(k?))?$", seg)
+        if m:
+            cnt = int(m.group(1)) if m.group(1) else 1
+            q_len = _parse_size(m.group(2), m.group(3))
+            kv_len = _parse_size(m.group(4), m.group(5)) if m.group(4) else q_len
+            requests.extend([BatchRequest(q_len=q_len, kv_len=kv_len)] * cnt)
+            continue
+
+        raise ValueError(f"Invalid batch spec segment: '{seg}'")
+
+    return requests
+
+
+def format_batch_spec(requests: list[BatchRequest]) -> str:
+    """
+    Format list of BatchRequest into human-readable string.
+
+    Groups requests by type and provides counts and sizes.
+
+    Args:
+        requests: List of BatchRequest objects
+
+    Returns:
+        Formatted string describing the batch
+    """
+    kinds = {
+        "prefill": [],
+        "extend": [],
+        "decode": [],
+    }
+
+    for req in requests:
+        tup = (req.q_len, req.kv_len)
+        if req.is_prefill:
+            kinds["prefill"].append(tup)
+        elif req.is_extend:
+            kinds["extend"].append(tup)
+        elif req.is_decode:
+            kinds["decode"].append(tup)
+
+    parts = []
+    for kind in ["prefill", "extend", "decode"]:
+        lst = kinds[kind]
+        if not lst:
+            continue
+
+        cnt_total = len(lst)
+        ctr = Counter(lst)
+        inner = []
+
+        for (q, kv), cnt in ctr.items():
+            if kind == "prefill":
+                size = f"{q // 1024}k" if q % 1024 == 0 else str(q)
+                inner.append(f"{cnt}x{size}")
+            elif kind == "decode":
+                size = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv)
+                inner.append(f"{cnt}x{size}")
+            else:  # extend
+                qstr = f"{q // 1024}k" if q % 1024 == 0 else str(q)
+                kstr = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv)
+                inner.append(f"{cnt}xq{qstr}kv{kstr}")
+
+        parts.append(f"{cnt_total} {kind} ({', '.join(inner)})")
+
+    return ", ".join(parts)
+
+
+def reorder_for_flashinfer(requests: list[BatchRequest]) -> list[BatchRequest]:
+    """
+    Reorder requests for FlashInfer: decode first, then prefill.
+
+    FlashInfer expects decode requests before prefill requests for
+    optimal performance.
+
+    Args:
+        requests: Original list of BatchRequest
+
+    Returns:
+        Reordered list with decode requests first
+    """
+    decodes = [r for r in requests if r.is_decode]
+    non_decodes = [r for r in requests if not r.is_decode]
+    return decodes + non_decodes
+
+
+def split_by_type(
+    requests: list[BatchRequest],
+) -> dict[str, list[BatchRequest]]:
+    """
+    Split requests by type for analysis.
+
+    Args:
+        requests: List of BatchRequest
+
+    Returns:
+        Dict with keys: 'decode', 'prefill', 'extend'
+    """
+    result = {
+        "decode": [],
+        "prefill": [],
+        "extend": [],
+    }
+
+    for req in requests:
+        if req.is_decode:
+            result["decode"].append(req)
+        elif req.is_prefill:
+            result["prefill"].append(req)
+        elif req.is_extend:
+            result["extend"].append(req)
+
+    return result
+
+
+def get_batch_stats(requests: list[BatchRequest]) -> dict:
+    """
+    Compute statistics about a batch.
+ + Args: + requests: List of BatchRequest + + Returns: + Dict with batch statistics + """ + by_type = split_by_type(requests) + + return { + "total_requests": len(requests), + "num_decode": len(by_type["decode"]), + "num_prefill": len(by_type["prefill"]), + "num_extend": len(by_type["extend"]), + "total_tokens": sum(r.q_len for r in requests), + "total_kv_cache": sum(r.kv_len for r in requests), + "max_q_len": max((r.q_len for r in requests), default=0), + "max_kv_len": max((r.kv_len for r in requests), default=0), + "avg_q_len": sum(r.q_len for r in requests) / len(requests) if requests else 0, + "avg_kv_len": ( + sum(r.kv_len for r in requests) / len(requests) if requests else 0 + ), + } + + +def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str: + """ + Classify a batch spec into a type string. + + Args: + batch_spec: Batch specification string (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k") + spec_decode_threshold: Max q_len to be considered spec-decode vs extend + + Returns: + Type string: "prefill", "decode", "spec-decode", "extend", or "mixed (types...)" + """ + requests = parse_batch_spec(batch_spec) + + # Classify each request + types_present = set() + for req in requests: + if req.is_decode: + types_present.add("decode") + elif req.is_prefill: + types_present.add("prefill") + elif req.is_extend: + # Distinguish spec-decode (small q_len) from extend (chunked prefill) + if req.q_len <= spec_decode_threshold: + types_present.add("spec-decode") + else: + types_present.add("extend") + + if len(types_present) == 1: + return types_present.pop() + elif len(types_present) > 1: + # Sort for consistent output + sorted_types = sorted(types_present) + return f"mixed ({'+'.join(sorted_types)})" + else: + return "unknown" diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..de56cbac8474b4ceb05e44d6705adaed67be49ea --- /dev/null +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -0,0 +1,895 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Universal vLLM Attention Benchmark + +Benchmark any attention backend with the extended grammar. +Supports standard attention (Flash/Triton/FlashInfer) and MLA backends. 
+ +Examples: + # Standard attention + python benchmark.py --backends flash flashinfer --batch-specs "q2k" "8q1s1k" + + # MLA backends + python benchmark.py --backends cutlass_mla flashinfer_mla --batch-specs "64q1s1k" + + # Parameter sweep (CLI) + python benchmark.py --backend cutlass_mla \ + --batch-specs "64q1s1k" \ + --sweep-param num_kv_splits \ + --sweep-values 1 4 8 16 + + # Parameter sweep (YAML config - recommended) + python benchmark.py --config configs/cutlass_numsplits.yaml +""" + +import argparse +import sys +from dataclasses import replace +from pathlib import Path + +import yaml +from rich.console import Console +from tqdm import tqdm + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from batch_spec import parse_batch_spec +from common import ( + BenchmarkConfig, + BenchmarkResult, + ModelParameterSweep, + ParameterSweep, + ResultsFormatter, + batch_spec_sort_key, + is_mla_backend, +) + + +def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: + """Run standard attention benchmark (Flash/Triton/FlashInfer).""" + from runner import run_attention_benchmark + + return run_attention_benchmark(config) + + +def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: + """Run MLA benchmark with appropriate backend.""" + from mla_runner import run_mla_benchmark as run_mla + + return run_mla(config.backend, config, **kwargs) + + +def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: + """ + Run a single benchmark with proper backend selection. + + Args: + config: BenchmarkConfig with backend, batch_spec, and model params + **kwargs: Additional arguments passed to MLA benchmarks + + Returns: + BenchmarkResult (may have error field set on failure) + """ + try: + if is_mla_backend(config.backend): + return run_mla_benchmark(config, **kwargs) + else: + return run_standard_attention_benchmark(config) + except Exception as e: + return BenchmarkResult( + config=config, + mean_time=float("inf"), + std_time=0, + min_time=float("inf"), + max_time=float("inf"), + error=str(e), + ) + + +def run_model_parameter_sweep( + backends: list[str], + batch_specs: list[str], + base_config_args: dict, + sweep: ModelParameterSweep, + console: Console, +) -> list[BenchmarkResult]: + """ + Run model parameter sweep for given backends and batch specs. + + Args: + backends: List of backend names + batch_specs: List of batch specifications + base_config_args: Base configuration arguments (num_layers, head_dim, etc.) 
+ sweep: ModelParameterSweep configuration + console: Rich console for output + + Returns: + List of BenchmarkResult objects + """ + all_results = [] + + console.print( + f"[yellow]Model sweep mode: testing {sweep.param_name} = {sweep.values}[/]" + ) + + total = len(backends) * len(batch_specs) * len(sweep.values) + + with tqdm(total=total, desc="Benchmarking") as pbar: + for backend in backends: + for spec in batch_specs: + for value in sweep.values: + # Create config with modified model parameter + config_args = base_config_args.copy() + config_args[sweep.param_name] = value + + # Create config with original backend for running + clean_config = BenchmarkConfig( + backend=backend, batch_spec=spec, **config_args + ) + + # Run benchmark + result = run_benchmark(clean_config) + + # Replace backend with labeled version for display + backend_label = sweep.get_label(backend, value) + labeled_config = replace(result.config, backend=backend_label) + result = replace(result, config=labeled_config) + all_results.append(result) + + if not result.success: + console.print( + f"[red]Error {backend} {spec} {sweep.param_name}=" + f"{value}: {result.error}[/]" + ) + + pbar.update(1) + + # Display sweep results - create separate table for each parameter value + console.print("\n[bold green]Model Parameter Sweep Results:[/]") + formatter = ResultsFormatter(console) + + # Group results by parameter value and extract backend mapping + by_param_value = {} + backend_mapping = {} # Maps labeled backend -> original backend + + for r in all_results: + # Extract original backend and param value from labeled backend + # The label format is: {backend}_{param_name}_{value} + # We need to reverse engineer this + labeled_backend = r.config.backend + + # Try each backend to find which one this result belongs to + for backend in backends: + for value in sweep.values: + expected_label = sweep.get_label(backend, value) + if labeled_backend == expected_label: + backend_mapping[labeled_backend] = backend + param_value = str(value) + + if param_value not in by_param_value: + by_param_value[param_value] = [] + by_param_value[param_value].append(r) + break + + # Create a table for each parameter value + sorted_param_values = sorted( + by_param_value.keys(), key=lambda x: int(x) if x.isdigit() else x + ) + + for param_value in sorted_param_values: + console.print(f"\n[bold cyan]{sweep.param_name} = {param_value}[/]") + param_results = by_param_value[param_value] + + # Create modified results with original backend names + modified_results = [] + for r in param_results: + # Get the original backend name from our mapping + original_backend = backend_mapping[r.config.backend] + modified_config = replace(r.config, backend=original_backend) + modified_result = replace(r, config=modified_config) + modified_results.append(modified_result) + + # Print table with original backend names + formatter.print_table(modified_results, backends, compare_to_fastest=True) + + # Show optimal backend for each (param_value, batch_spec) combination + console.print( + f"\n[bold cyan]Optimal backend for each ({sweep.param_name}, batch_spec):[/]" + ) + + # Group by (param_value, batch_spec) + by_param_and_spec = {} + for r in all_results: + if r.success: + # Find which (backend, value) this result corresponds to + labeled_backend = r.config.backend + for backend in backends: + for value in sweep.values: + expected_label = sweep.get_label(backend, value) + if labeled_backend == expected_label: + param_value = str(value) + spec = r.config.batch_spec + key = 
(param_value, spec) + + if key not in by_param_and_spec: + by_param_and_spec[key] = [] + by_param_and_spec[key].append(r) + break + + # Sort by param value then spec (batch_size, q_len, kv_len) + sorted_keys = sorted( + by_param_and_spec.keys(), + key=lambda x: ( + int(x[0]) if x[0].isdigit() else x[0], + batch_spec_sort_key(x[1]), + ), + ) + + current_param_value = None + for param_value, spec in sorted_keys: + # Print header when param value changes + if param_value != current_param_value: + console.print(f"\n [bold]{sweep.param_name}={param_value}:[/]") + current_param_value = param_value + + results = by_param_and_spec[(param_value, spec)] + best = min(results, key=lambda r: r.mean_time) + + # Extract original backend name using the mapping + backend_name = backend_mapping[best.config.backend] + + # Show all backends' times for comparison + times_str = " | ".join( + [ + f"{backend_mapping[r.config.backend]}: {r.mean_time:.6f}s" + for r in sorted(results, key=lambda r: r.mean_time) + ] + ) + + console.print( + f" {spec:12s} -> [bold green]{backend_name:15s}[/] ({times_str})" + ) + + return all_results + + +def run_parameter_sweep( + backends: list[str], + batch_specs: list[str], + base_config_args: dict, + sweep: ParameterSweep, + console: Console, +) -> list[BenchmarkResult]: + """ + Run parameter sweep for given backends and batch specs. + + Args: + backends: List of backend names + batch_specs: List of batch specifications + base_config_args: Base configuration arguments (num_layers, head_dim, etc.) + sweep: ParameterSweep configuration + console: Rich console for output + + Returns: + List of BenchmarkResult objects + """ + all_results = [] + + # Build list of values to sweep (including auto if requested) + sweep_values = list(sweep.values) + if sweep.include_auto: + sweep_values.append("auto") + + console.print(f"[yellow]Sweep mode: testing {sweep.param_name} = {sweep_values}[/]") + + total = len(backends) * len(batch_specs) * len(sweep_values) + + with tqdm(total=total, desc="Benchmarking") as pbar: + for backend in backends: + for spec in batch_specs: + for value in sweep_values: + # Create config with original backend for running + config = BenchmarkConfig( + backend=backend, batch_spec=spec, **base_config_args + ) + + # Prepare kwargs for benchmark runner + kwargs = {} + if value != "auto": + kwargs[sweep.param_name] = value + + # Run benchmark + result = run_benchmark(config, **kwargs) + + # Replace backend with labeled version for display + backend_label = sweep.get_label(backend, value) + labeled_config = replace(result.config, backend=backend_label) + result = replace(result, config=labeled_config) + all_results.append(result) + + if not result.success: + console.print( + f"[red]Error {backend} {spec} {sweep.param_name}=" + f"{value}: {result.error}[/]" + ) + + pbar.update(1) + + # Display sweep results + console.print("\n[bold green]Sweep Results:[/]") + backend_labels = [sweep.get_label(b, v) for b in backends for v in sweep_values] + formatter = ResultsFormatter(console) + formatter.print_table(all_results, backend_labels) + + # Show optimal values + console.print(f"\n[bold cyan]Optimal {sweep.param_name} per batch spec:[/]") + by_spec = {} + for r in all_results: + if r.success: + spec = r.config.batch_spec + if spec not in by_spec: + by_spec[spec] = [] + by_spec[spec].append(r) + + for spec in sorted(by_spec.keys(), key=batch_spec_sort_key): + results = by_spec[spec] + best = min(results, key=lambda r: r.mean_time) + console.print( + f" {spec}: [bold 
green]{best.config.backend}[/] ({best.mean_time:.6f}s)" + ) + + return all_results + + +def load_config_from_yaml(config_path: str) -> dict: + """Load configuration from YAML file.""" + with open(config_path) as f: + return yaml.safe_load(f) + + +def generate_batch_specs_from_ranges(ranges: list[dict]) -> list[str]: + """ + Generate batch specs from range specifications. + + Args: + ranges: List of range specifications, each containing: + - template: Batch spec template (e.g., "q{q_len}kv1k") + - q_len: Dict with start, stop, step, end_inclusive (optional) + - Other parameters can also be ranges + + Returns: + List of generated batch spec strings + + Example: + ranges = [ + { + "template": "q{q_len}kv1k", + "q_len": { + "start": 1, + "stop": 16, + "step": 1, + "end_inclusive": true # Optional, defaults to true + } + } + ] + Returns: ["q1kv1k", "q2kv1k", ..., "q16kv1k"] + """ + all_specs = [] + + for range_spec in ranges: + template = range_spec.get("template") + if not template: + raise ValueError("Range specification must include 'template'") + + # Extract all range parameters from the spec + range_params = {} + for key, value in range_spec.items(): + if key == "template": + continue + if isinstance(value, dict) and "start" in value: + # This is a range specification + start = value["start"] + stop = value["stop"] + step = value.get("step", 1) + # Check if end should be inclusive (default: True) + end_inclusive = value.get("end_inclusive", True) + + # Adjust stop based on end_inclusive + if end_inclusive: + range_params[key] = list(range(start, stop + 1, step)) + else: + range_params[key] = list(range(start, stop, step)) + else: + # This is a fixed value + range_params[key] = [value] + + # Generate all combinations (Cartesian product) + if range_params: + import itertools + + param_names = list(range_params.keys()) + param_values = [range_params[name] for name in param_names] + + for values in itertools.product(*param_values): + params = dict(zip(param_names, values)) + spec = template.format(**params) + all_specs.append(spec) + else: + # No parameters, just use template as-is + all_specs.append(template) + + return all_specs + + +def main(): + parser = argparse.ArgumentParser( + description="Universal vLLM attention benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Config file + parser.add_argument( + "--config", + help="Path to YAML config file (overrides other args)", + ) + + # Backend selection + parser.add_argument( + "--backends", + nargs="+", + help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, " + "flashinfer_mla, flashattn_mla, flashmla)", + ) + parser.add_argument( + "--backend", + help="Single backend (alternative to --backends)", + ) + + # Batch specifications + parser.add_argument( + "--batch-specs", + nargs="+", + default=["q2k", "8q1s1k"], + help="Batch specifications using extended grammar", + ) + + # Model config + parser.add_argument("--num-layers", type=int, default=10, help="Number of layers") + parser.add_argument("--head-dim", type=int, default=128, help="Head dimension") + parser.add_argument("--num-q-heads", type=int, default=32, help="Query heads") + parser.add_argument("--num-kv-heads", type=int, default=8, help="KV heads") + parser.add_argument("--block-size", type=int, default=16, help="Block size") + + # Benchmark settings + parser.add_argument("--device", default="cuda:0", help="Device") + parser.add_argument("--repeats", type=int, default=1, help="Repetitions") + 
parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations") + parser.add_argument("--profile-memory", action="store_true", help="Profile memory") + + # Parameter sweep (use YAML config for advanced sweeps) + parser.add_argument( + "--sweep-param", + help="Parameter name to sweep (e.g., num_kv_splits, reorder_batch_threshold)", + ) + parser.add_argument( + "--sweep-values", + type=int, + nargs="+", + help="Values to sweep for the parameter", + ) + + # Output + parser.add_argument("--output-csv", help="Save to CSV") + parser.add_argument("--output-json", help="Save to JSON") + + args = parser.parse_args() + + console = Console() + console.print("[bold cyan]vLLM Attention Benchmark[/]") + + # Load config from YAML if provided + if args.config: + console.print(f"[yellow]Loading config from: {args.config}[/]") + yaml_config = load_config_from_yaml(args.config) + + # Show description if available + if "description" in yaml_config: + console.print(f"[dim]{yaml_config['description']}[/]") + + # Override args with YAML values, but CLI args take precedence + # Check if CLI provided backends (they would be non-None and not default) + cli_backends_provided = args.backends is not None or args.backend is not None + + # Backend(s) - only use YAML if CLI didn't specify + if not cli_backends_provided: + if "backend" in yaml_config: + args.backend = yaml_config["backend"] + args.backends = None + elif "backends" in yaml_config: + args.backends = yaml_config["backends"] + args.backend = None + + # Check for special modes + if "mode" in yaml_config: + args.mode = yaml_config["mode"] + else: + args.mode = None + + # Batch specs and sizes + # Support both explicit batch_specs and generated batch_spec_ranges + if "batch_spec_ranges" in yaml_config: + # Generate batch specs from ranges + generated_specs = generate_batch_specs_from_ranges( + yaml_config["batch_spec_ranges"] + ) + # Combine with any explicit batch_specs + if "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] + generated_specs + else: + args.batch_specs = generated_specs + console.print( + f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" + ) + elif "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] + + if "batch_sizes" in yaml_config: + args.batch_sizes = yaml_config["batch_sizes"] + else: + args.batch_sizes = None + + # Model config + if "model" in yaml_config: + model = yaml_config["model"] + args.num_layers = model.get("num_layers", args.num_layers) + args.head_dim = model.get("head_dim", args.head_dim) + args.num_q_heads = model.get("num_q_heads", args.num_q_heads) + args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads) + args.block_size = model.get("block_size", args.block_size) + + # Benchmark settings (top-level keys) + if "device" in yaml_config: + args.device = yaml_config["device"] + if "repeats" in yaml_config: + args.repeats = yaml_config["repeats"] + if "warmup_iters" in yaml_config: + args.warmup_iters = yaml_config["warmup_iters"] + if "profile_memory" in yaml_config: + args.profile_memory = yaml_config["profile_memory"] + + # Parameter sweep configuration + if "parameter_sweep" in yaml_config: + sweep_config = yaml_config["parameter_sweep"] + args.parameter_sweep = ParameterSweep( + param_name=sweep_config["param_name"], + values=sweep_config["values"], + include_auto=sweep_config.get("include_auto", False), + label_format=sweep_config.get( + "label_format", "{backend}_{param_name}_{value}" + ), + ) + else: + 
args.parameter_sweep = None + + # Model parameter sweep configuration + if "model_parameter_sweep" in yaml_config: + sweep_config = yaml_config["model_parameter_sweep"] + args.model_parameter_sweep = ModelParameterSweep( + param_name=sweep_config["param_name"], + values=sweep_config["values"], + label_format=sweep_config.get( + "label_format", "{backend}_{param_name}_{value}" + ), + ) + else: + args.model_parameter_sweep = None + + # Output + if "output" in yaml_config: + output = yaml_config["output"] + if "csv" in output and not args.output_csv: + args.output_csv = output["csv"] + if "json" in output and not args.output_json: + args.output_json = output["json"] + + console.print() + + # Handle CLI-based parameter sweep (if not from YAML) + if ( + (not hasattr(args, "parameter_sweep") or args.parameter_sweep is None) + and args.sweep_param + and args.sweep_values + ): + args.parameter_sweep = ParameterSweep( + param_name=args.sweep_param, + values=args.sweep_values, + include_auto=False, + label_format="{backend}_{param_name}_{value}", + ) + + # Determine backends + backends = args.backends or ([args.backend] if args.backend else ["flash"]) + console.print(f"Backends: {', '.join(backends)}") + console.print(f"Batch specs: {', '.join(args.batch_specs)}") + console.print() + + # Run benchmarks + all_results = [] + + # Handle special mode: decode_vs_prefill comparison + if hasattr(args, "mode") and args.mode == "decode_vs_prefill": + console.print("[yellow]Mode: Decode vs Prefill pipeline comparison[/]") + console.print( + "[dim]For each query length, testing both decode and prefill pipelines[/]" + ) + console.print("[dim]Using batched execution for optimal performance[/]") + + # Extract batch sizes from config + batch_sizes = getattr(args, "batch_sizes", [1]) + backend = backends[0] # Use first backend (should only be one) + + # Calculate total benchmarks + total = len(batch_sizes) + + with tqdm(total=total, desc="Benchmarking") as pbar: + for batch_size in batch_sizes: + # Prepare all configs for this batch size + configs_with_thresholds = [] + + for spec in args.batch_specs: + # Parse the batch spec to get query length + requests = parse_batch_spec(spec) + if not requests: + console.print( + f"[red]Error: Could not parse batch spec '{spec}'[/]" + ) + continue + + # Get query length from first request + query_length = requests[0].q_len + + # Create batch spec for this batch size + # For batch_size > 1, we need to prepend the count + batch_spec = f"{batch_size}{spec}" if batch_size > 1 else spec + + # Create base config (without backend name) + base_config = BenchmarkConfig( + backend=backend, # Will be overridden later + batch_spec=batch_spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + ) + + # Add decode pipeline config + decode_threshold = query_length + config_decode = replace( + base_config, + backend=f"{backend}_decode_qlen{query_length}_bs{batch_size}", + ) + configs_with_thresholds.append((config_decode, decode_threshold)) + + # Add prefill pipeline config if query_length > 1 + if query_length > 1: + prefill_threshold = query_length - 1 + config_prefill = replace( + base_config, + backend=f"{backend}_prefill_qlen{query_length}" + f"_bs{batch_size}", + ) + configs_with_thresholds.append( + (config_prefill, prefill_threshold) + ) + + # Run all benchmarks for 
this batch size in one go (batched mode) + try: + from mla_runner import run_mla_benchmark as run_mla + + # Use batched API: pass list of (config, threshold) tuples + timing_results = run_mla(backend, configs_with_thresholds) + + # Create BenchmarkResult objects from timing results + for (config, _), timing in zip( + configs_with_thresholds, timing_results + ): + result = BenchmarkResult( + config=config, + mean_time=timing["mean"], + std_time=timing["std"], + min_time=timing["min"], + max_time=timing["max"], + throughput_tokens_per_sec=timing.get("throughput", None), + ) + all_results.append(result) + + except Exception as e: + import traceback + + console.print( + f"[red]Error running batched benchmarks for " + f"batch_size={batch_size}: {e}[/]" + ) + console.print("[red]Traceback:[/]") + traceback.print_exc() + # Add error results for all configs + for config, _ in configs_with_thresholds: + result = BenchmarkResult( + config=config, + mean_time=float("inf"), + std_time=0, + min_time=float("inf"), + max_time=float("inf"), + error=str(e), + ) + all_results.append(result) + + pbar.update(1) + + # Display decode vs prefill results + console.print("\n[bold green]Decode vs Prefill Results:[/]") + + # Group by batch size + by_batch_size = {} + for r in all_results: + if r.success: + # Extract batch size from backend name + parts = r.config.backend.split("_") + bs_part = [p for p in parts if p.startswith("bs")] + if bs_part: + bs = int(bs_part[0][2:]) + if bs not in by_batch_size: + by_batch_size[bs] = [] + by_batch_size[bs].append(r) + + # For each batch size, analyze crossover point + for bs in sorted(by_batch_size.keys()): + console.print(f"\n[bold cyan]Batch size: {bs}[/]") + results = by_batch_size[bs] + + # Group by query length + by_qlen = {} + for r in results: + parts = r.config.backend.split("_") + qlen_part = [p for p in parts if p.startswith("qlen")] + if qlen_part: + qlen = int(qlen_part[0][4:]) + if qlen not in by_qlen: + by_qlen[qlen] = {} + + pipeline = "decode" if "decode" in r.config.backend else "prefill" + by_qlen[qlen][pipeline] = r + + # Find crossover point + last_decode_faster = None + for qlen in sorted(by_qlen.keys()): + pipelines = by_qlen[qlen] + if "decode" in pipelines and "prefill" in pipelines: + decode_time = pipelines["decode"].mean_time + prefill_time = pipelines["prefill"].mean_time + faster = "decode" if decode_time < prefill_time else "prefill" + + speedup = ( + prefill_time / decode_time + if decode_time < prefill_time + else decode_time / prefill_time + ) + + console.print( + f" qlen={qlen:3d}: decode={decode_time:.6f}s, " + f"prefill={prefill_time:.6f}s -> " + f"[bold]{faster}[/] ({speedup:.2f}x)" + ) + + if faster == "decode": + last_decode_faster = qlen + + if last_decode_faster is not None: + optimal_threshold = last_decode_faster + console.print( + f"\n [bold green]Optimal threshold for batch_size={bs}: " + f"{optimal_threshold}[/]" + ) + console.print( + f" [dim](Use decode pipeline for query_length <= " + f"{optimal_threshold})[/]" + ) + else: + console.print( + f"\n [yellow]Prefill always faster for batch_size={bs}[/]" + ) + + # Handle model parameter sweep mode + elif hasattr(args, "model_parameter_sweep") and args.model_parameter_sweep: + # Model parameter sweep + base_config_args = { + "num_layers": args.num_layers, + "head_dim": args.head_dim, + "num_q_heads": args.num_q_heads, + "num_kv_heads": args.num_kv_heads, + "block_size": args.block_size, + "device": args.device, + "repeats": args.repeats, + "warmup_iters": args.warmup_iters, + 
"profile_memory": args.profile_memory, + } + all_results = run_model_parameter_sweep( + backends, + args.batch_specs, + base_config_args, + args.model_parameter_sweep, + console, + ) + + # Handle parameter sweep mode (unified) + elif hasattr(args, "parameter_sweep") and args.parameter_sweep: + # Unified parameter sweep + base_config_args = { + "num_layers": args.num_layers, + "head_dim": args.head_dim, + "num_q_heads": args.num_q_heads, + "num_kv_heads": args.num_kv_heads, + "block_size": args.block_size, + "device": args.device, + "repeats": args.repeats, + "warmup_iters": args.warmup_iters, + "profile_memory": args.profile_memory, + } + all_results = run_parameter_sweep( + backends, args.batch_specs, base_config_args, args.parameter_sweep, console + ) + + else: + # Normal mode: compare backends + total = len(backends) * len(args.batch_specs) + + with tqdm(total=total, desc="Benchmarking") as pbar: + for spec in args.batch_specs: + for backend in backends: + config = BenchmarkConfig( + backend=backend, + batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + ) + + result = run_benchmark(config) + all_results.append(result) + + if not result.success: + console.print(f"[red]Error {backend} {spec}: {result.error}[/]") + + pbar.update(1) + + # Display results + console.print("\n[bold green]Results:[/]") + formatter = ResultsFormatter(console) + formatter.print_table(all_results, backends) + + # Save results + if all_results: + formatter = ResultsFormatter(console) + if args.output_csv: + formatter.save_csv(all_results, args.output_csv) + if args.output_json: + formatter.save_json(all_results, args.output_json) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py new file mode 100644 index 0000000000000000000000000000000000000000..6bba93e502388a6d93f5bc1890db1b77b2f63bd2 --- /dev/null +++ b/benchmarks/attention_benchmarks/common.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Common utilities for attention benchmarking.""" + +import csv +import json +import math +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import torch +from batch_spec import get_batch_type, parse_batch_spec +from rich.console import Console +from rich.table import Table + + +def batch_spec_sort_key(spec: str) -> tuple[int, int, int]: + """ + Extract sorting key from batch spec: (batch_size, max_q_len, max_kv_len). + + This ensures results are sorted by batch size first, then query length, + then sequence length, rather than alphabetically. 
+ """ + try: + requests = parse_batch_spec(spec) + batch_size = len(requests) + max_q_len = max(r.q_len for r in requests) if requests else 0 + max_kv_len = max(r.kv_len for r in requests) if requests else 0 + return (batch_size, max_q_len, max_kv_len) + except Exception: + # Fallback for unparseable specs + return (0, 0, 0) + + +# Mock classes for vLLM attention infrastructure + + +class MockHfConfig: + """Mock HuggingFace config that satisfies vLLM's requirements.""" + + def __init__(self, mla_dims: dict, index_topk: int | None = None): + self.num_attention_heads = mla_dims["num_q_heads"] + self.num_key_value_heads = mla_dims["num_kv_heads"] + self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"] + self.model_type = "deepseek_v2" + self.is_encoder_decoder = False + self.kv_lora_rank = mla_dims["kv_lora_rank"] + self.qk_nope_head_dim = mla_dims["qk_nope_head_dim"] + self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"] + self.v_head_dim = mla_dims["v_head_dim"] + self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"] + if index_topk is not None: + self.index_topk = index_topk + + def get_text_config(self): + return self + + +# Import AttentionLayerBase at module level to avoid circular dependencies +try: + from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +except ImportError: + AttentionLayerBase = object # Fallback + + +class MockKVBProj: + """Mock KV projection layer for MLA prefill mode. + + Mimics ColumnParallelLinear behavior for kv_b_proj in MLA backends. + Projects kv_c_normed to [qk_nope_head_dim + v_head_dim] per head. + """ + + def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int): + self.num_heads = num_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.v_head_dim = v_head_dim + self.out_dim = qk_nope_head_dim + v_head_dim + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]: + """ + Project kv_c_normed to output space. + + Args: + x: Input tensor [num_tokens, kv_lora_rank] + + Returns: + Tuple containing output tensor + [num_tokens, num_heads, qk_nope_head_dim + v_head_dim] + """ + num_tokens = x.shape[0] + result = torch.randn( + num_tokens, + self.num_heads, + self.out_dim, + device=x.device, + dtype=x.dtype, + ) + return (result,) # Return as tuple to match ColumnParallelLinear API + + +class MockIndexer: + """Mock Indexer for sparse MLA backends. + + Provides topk_indices_buffer that sparse MLA backends use to determine + which KV cache slots to attend to for each token. + """ + + def __init__( + self, + max_num_tokens: int, + topk_tokens: int, + device: torch.device, + ): + self.topk_tokens = topk_tokens + self.topk_indices_buffer = torch.zeros( + (max_num_tokens, topk_tokens), + dtype=torch.int32, + device=device, + ) + + def fill_random_indices(self, num_tokens: int, max_kv_len: int): + """Fill topk_indices_buffer with random valid indices for benchmarking.""" + indices = torch.randint( + 0, + max_kv_len, + (num_tokens, self.topk_tokens), + dtype=torch.int32, + device=self.topk_indices_buffer.device, + ) + self.topk_indices_buffer[:num_tokens] = indices + + +class MockLayer(AttentionLayerBase): + """Mock attention layer with scale parameters and impl. + + Inherits from AttentionLayerBase so it passes isinstance checks + in get_layers_from_vllm_config when FlashInfer prefill is enabled. 
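+
+    Example (illustrative sketch; mirrors how _create_backend_impl in
+    mla_runner.py wires a layer up for benchmarking):
+        layer = MockLayer(torch.device("cuda:0"), impl=impl, kv_cache_spec=spec)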
+ """ + + def __init__(self, device: torch.device, impl=None, kv_cache_spec=None): + # Don't call super().__init__() as AttentionLayerBase doesn't have __init__ + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + self._q_scale = torch.tensor(1.0, device=device) + # Scalar floats for kernels that need them + self._k_scale_float = float(self._k_scale.item()) + self._v_scale_float = float(self._v_scale.item()) + self._q_scale_float = float(self._q_scale.item()) + # AttentionImpl for metadata builders to query + self.impl = impl + # KV cache spec for get_kv_cache_spec + self._kv_cache_spec = kv_cache_spec + + def get_attn_backend(self): + """Get the attention backend class (required by AttentionLayerBase).""" + # Return None as this is just a mock layer for benchmarking + return None + + def get_kv_cache_spec(self): + """Get the KV cache spec (required by AttentionLayerBase).""" + return self._kv_cache_spec + + +@dataclass +class ParameterSweep: + """Configuration for sweeping a backend parameter.""" + + param_name: str # Name of the backend parameter to sweep + values: list[Any] # List of values to test + include_auto: bool = False # Also test with param unset (auto mode) + label_format: str = "{backend}_{param_name}_{value}" # Result label template + + def get_label(self, backend: str, value: Any) -> str: + """Generate a label for a specific parameter value.""" + return self.label_format.format( + backend=backend, param_name=self.param_name, value=value + ) + + +@dataclass +class ModelParameterSweep: + """Configuration for sweeping a model configuration parameter.""" + + param_name: str # Name of the model config parameter to sweep (e.g., "num_q_heads") + values: list[Any] # List of values to test + label_format: str = "{backend}_{param_name}_{value}" # Result label template + + def get_label(self, backend: str, value: Any) -> str: + """Generate a label for a specific parameter value.""" + return self.label_format.format( + backend=backend, param_name=self.param_name, value=value + ) + + +@dataclass +class BenchmarkConfig: + """Configuration for a single benchmark run.""" + + backend: str + batch_spec: str + num_layers: int + head_dim: int + num_q_heads: int + num_kv_heads: int + block_size: int + device: str + dtype: torch.dtype = torch.float16 + repeats: int = 1 + warmup_iters: int = 3 + profile_memory: bool = False + use_cuda_graphs: bool = False + + # MLA-specific + kv_lora_rank: int | None = None + qk_nope_head_dim: int | None = None + qk_rope_head_dim: int | None = None + v_head_dim: int | None = None + + # Backend-specific tuning + num_kv_splits: int | None = None # CUTLASS MLA + reorder_batch_threshold: int | None = None # FlashAttn MLA, FlashMLA + + +@dataclass +class BenchmarkResult: + """Results from a single benchmark run.""" + + config: BenchmarkConfig + mean_time: float # seconds + std_time: float # seconds + min_time: float # seconds + max_time: float # seconds + throughput_tokens_per_sec: float | None = None + memory_allocated_mb: float | None = None + memory_reserved_mb: float | None = None + error: str | None = None + + @property + def success(self) -> bool: + """Whether benchmark completed successfully.""" + return self.error is None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "config": asdict(self.config), + "mean_time": self.mean_time, + "std_time": self.std_time, + "min_time": self.min_time, + "max_time": self.max_time, + "throughput_tokens_per_sec": 
self.throughput_tokens_per_sec, + "memory_allocated_mb": self.memory_allocated_mb, + "memory_reserved_mb": self.memory_reserved_mb, + "error": self.error, + } + + +class ResultsFormatter: + """Format and display benchmark results.""" + + def __init__(self, console: Console | None = None): + self.console = console or Console() + + def print_table( + self, + results: list[BenchmarkResult], + backends: list[str], + compare_to_fastest: bool = True, + ): + """ + Print results as a rich table. + + Args: + results: List of BenchmarkResult + backends: List of backend names being compared + compare_to_fastest: Show percentage comparison to fastest + """ + # Group by batch spec, preserving first-occurrence order + by_spec = {} + specs_order = [] + for r in results: + spec = r.config.batch_spec + if spec not in by_spec: + by_spec[spec] = {} + specs_order.append(spec) + by_spec[spec][r.config.backend] = r + + # Sort specs by (batch_size, q_len, kv_len) instead of alphabetically + specs_order = sorted(by_spec.keys(), key=batch_spec_sort_key) + + # Create shortened backend names for display + def shorten_backend_name(name: str) -> str: + """Shorten long backend names for table display.""" + # Remove common prefixes + name = name.replace("flashattn_mla", "famla") + name = name.replace("flashinfer_mla", "fimla") + name = name.replace("flashmla", "fmla") + name = name.replace("cutlass_mla", "cmla") + name = name.replace("numsplits", "ns") + return name + + table = Table(title="Attention Benchmark Results") + table.add_column("Batch\nSpec", no_wrap=True) + table.add_column("Type", no_wrap=True) + table.add_column("Batch\nSize", justify="right", no_wrap=True) + + multi = len(backends) > 1 + for backend in backends: + short_name = shorten_backend_name(backend) + # Time column + col_time = f"{short_name}\nTime (s)" + table.add_column(col_time, justify="right", no_wrap=False) + if multi and compare_to_fastest: + # Relative performance column + col_rel = f"{short_name}\nvs Best" + table.add_column(col_rel, justify="right", no_wrap=False) + + # Add rows + for spec in specs_order: + spec_results = by_spec[spec] + times = {b: r.mean_time for b, r in spec_results.items() if r.success} + best_time = min(times.values()) if times else 0.0 + + batch_type = get_batch_type(spec) + batch_size = len(parse_batch_spec(spec)) + row = [spec, batch_type, str(batch_size)] + for backend in backends: + if backend in spec_results: + r = spec_results[backend] + if r.success: + row.append(f"{r.mean_time:.6f}") + if multi and compare_to_fastest: + pct = ( + (r.mean_time / best_time * 100) if best_time > 0 else 0 + ) + pct_str = f"{pct:.1f}%" + if r.mean_time == best_time: + pct_str = f"[bold green]{pct_str}[/]" + row.append(pct_str) + else: + row.append("[red]ERROR[/]") + if multi and compare_to_fastest: + row.append("-") + else: + row.append("-") + if multi and compare_to_fastest: + row.append("-") + + table.add_row(*row) + + self.console.print(table) + + def save_csv(self, results: list[BenchmarkResult], path: str): + """Save results to CSV file.""" + if not results: + return + + path_obj = Path(path) + path_obj.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "w", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "backend", + "batch_spec", + "num_layers", + "mean_time", + "std_time", + "throughput", + "memory_mb", + ], + ) + writer.writeheader() + for r in results: + writer.writerow( + { + "backend": r.config.backend, + "batch_spec": r.config.batch_spec, + "num_layers": r.config.num_layers, + 
"mean_time": r.mean_time, + "std_time": r.std_time, + "throughput": r.throughput_tokens_per_sec or 0, + "memory_mb": r.memory_allocated_mb or 0, + } + ) + + self.console.print(f"[green]Saved CSV results to {path}[/]") + + def save_json(self, results: list[BenchmarkResult], path: str): + """Save results to JSON file.""" + path_obj = Path(path) + path_obj.parent.mkdir(parents=True, exist_ok=True) + + data = [r.to_dict() for r in results] + with open(path, "w") as f: + json.dump(data, f, indent=2, default=str) + + self.console.print(f"[green]Saved JSON results to {path}[/]") + + +def setup_mla_dims(model_name: str = "deepseek-v3") -> dict: + """ + Get MLA dimensions for known models. + + Args: + model_name: Model identifier + + Returns: + Dict with MLA dimension configuration + """ + configs = { + "deepseek-v2": { + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "num_q_heads": 128, + "num_kv_heads": 1, + "head_dim": 576, + }, + "deepseek-v3": { + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "num_q_heads": 128, + "num_kv_heads": 1, + "head_dim": 576, + }, + "deepseek-v2-lite": { + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "num_q_heads": 16, + "num_kv_heads": 1, + "head_dim": 576, + }, + } + + if model_name not in configs: + raise ValueError( + f"Unknown model '{model_name}'. Known models: {list(configs.keys())}" + ) + + return configs[model_name] + + +def get_attention_scale(head_dim: int) -> float: + """Compute attention scale factor (1/sqrt(d)).""" + return 1.0 / math.sqrt(head_dim) + + +def is_mla_backend(backend: str) -> bool: + """ + Check if backend is an MLA backend using the AttentionBackendEnum. 
+ + Args: + backend: Backend name matching AttentionBackendEnum exactly + (e.g., "FLASHMLA_SPARSE") + + Returns: + True if the backend is an MLA backend, False otherwise + """ + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + try: + backend_enum = AttentionBackendEnum[backend] + backend_class = backend_enum.get_class() + return backend_class.is_mla() + except (KeyError, ValueError, ImportError, AttributeError): + return False diff --git a/benchmarks/attention_benchmarks/configs/mla_decode.yaml b/benchmarks/attention_benchmarks/configs/mla_decode.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d758654dbe802e391f5c84f9b067fab40f035564 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_decode.yaml @@ -0,0 +1,70 @@ +# MLA decode-only benchmark configuration + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 # Base value, can be swept for TP simulation + num_kv_heads: 1 # MLA uses single latent KV + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Small batches, varying sequence lengths + - "16q1s512" # 16 requests, 512 KV cache + - "16q1s1k" # 16 requests, 1k KV cache + - "16q1s2k" # 16 requests, 2k KV cache + - "16q1s4k" # 16 requests, 4k KV cache + + # Medium batches + - "32q1s1k" # 32 requests, 1k KV cache + - "32q1s2k" # 32 requests, 2k KV cache + - "32q1s4k" # 32 requests, 4k KV cache + - "32q1s8k" # 32 requests, 8k KV cache + + # Large batches + - "64q1s1k" # 64 requests, 1k KV cache + - "64q1s2k" # 64 requests, 2k KV cache + - "64q1s4k" # 64 requests, 4k KV cache + - "64q1s8k" # 64 requests, 8k KV cache + + # Very large batches + - "128q1s1k" # 128 requests, 1k KV cache + - "128q1s2k" # 128 requests, 2k KV cache + - "128q1s4k" # 128 requests, 4k KV cache + - "128q1s8k" # 128 requests, 8k KV cache + + # Long context + - "32q1s16k" # 32 requests, 16k KV cache + - "32q1s32k" # 32 requests, 32k KV cache + +backends: + - CUTLASS_MLA + - FLASHINFER_MLA + - FLASH_ATTN_MLA # Hopper only + - FLASHMLA # Hopper only + +device: "cuda:0" +repeats: 100 +warmup_iters: 10 +profile_memory: true + +# Backend-specific tuning +CUTLASS_MLA: + num_kv_splits: auto # or specific value like 4, 8, 16 + +FLASH_ATTN_MLA: + reorder_batch_threshold: 512 + +FLASHMLA: + reorder_batch_threshold: 1 diff --git a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b555d90cbf6296f376118f4c7499b01925d2c2bf --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml @@ -0,0 +1,60 @@ +# MLA mixed batch benchmark (prefill + decode) +# Tests chunked prefill performance + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +batch_specs: + # Small prefill + decode + - "1q1k_8q1s1k" # 1 prefill + 8 decode + - "2q2k_16q1s1k" # 2 prefill + 16 decode + - "4q1k_32q1s2k" # 4 prefill + 32 decode + + # Medium prefill + decode + - "2q4k_32q1s2k" # 2 medium prefill + 
32 decode + - "4q4k_64q1s2k" # 4 medium prefill + 64 decode + - "8q2k_64q1s4k" # 8 prefill + 64 decode + + # Large prefill + decode (chunked prefill stress test) + - "2q8k_32q1s1k" # 2 large prefill + 32 decode + - "1q16k_16q1s2k" # 1 very large prefill + 16 decode + - "2q16k_32q1s4k" # 2 very large prefill + 32 decode + + # Context extension + decode + - "2q1kkv2k_16q1s1k" # 2 extend + 16 decode + - "4q2kkv4k_32q1s2k" # 4 extend + 32 decode + - "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode + + # Explicitly chunked prefill + - "q8k" # 8k prefill with chunking hint + - "q16k" # 16k prefill with chunking hint + - "2q8k_32q1s2k" # 2 chunked prefill + 32 decode + + # High decode ratio (realistic serving) + - "1q2k_63q1s1k" # 1 prefill + 63 decode + - "2q2k_62q1s2k" # 2 prefill + 62 decode + - "4q4k_60q1s4k" # 4 prefill + 60 decode + +backends: + - CUTLASS_MLA + - FLASHINFER_MLA + - FLASH_ATTN_MLA # Hopper only + - FLASHMLA # Hopper only + +device: "cuda:0" +repeats: 5 +warmup_iters: 3 +profile_memory: true + +# Analyze chunked prefill workspace size impact +chunked_prefill: + test_workspace_sizes: [4096, 8192, 16384, 32768, 65536] diff --git a/benchmarks/attention_benchmarks/configs/mla_prefill.yaml b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef6b2cb07dc70192ff428adaa0b18e32f0941e7e --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml @@ -0,0 +1,62 @@ +# MLA prefill-only benchmark configuration for sparse backends + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Pure prefill + - "1q512" + - "1q1k" + - "1q2k" + - "1q4k" + - "1q8k" + + # Batched pure prefill + - "2q512" + - "2q1k" + - "2q2k" + - "2q4k" + - "2q8k" + - "4q512" + - "4q1k" + - "4q2k" + - "4q4k" + - "4q8k" + - "8q512" + - "8q1k" + - "8q2k" + - "8q4k" + - "8q8k" + + # Extend + - "1q512s4k" + - "1q512s8k" + - "1q1ks8k" + - "1q2ks8k" + - "1q2ks16k" + - "1q4ks16k" + +backends: + - FLASHMLA_SPARSE + - FLASHINFER_MLA_SPARSE + +device: "cuda:0" +repeats: 10 +warmup_iters: 3 +profile_memory: true diff --git a/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml b/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d76ef0a358ca7584676cd3cfedf8982cd0b7b46 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml @@ -0,0 +1,87 @@ +# Study 4: What is optimal reorder_batch_threshold for MLA backends supporting query length > 1? +# Question: At what query length does prefill pipeline become faster than decode pipeline? 
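+# Run (illustrative; mirrors the --config usage shown in benchmark.py's epilog):
+#   python benchmark.py --config configs/reorder_threshold.yaml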
+# Methodology: For each query length, compare decode vs prefill performance to find crossover point +# Applies to: FlashAttn MLA, FlashMLA + +description: "Decode vs Prefill pipeline crossover analysis" + +# Test FlashAttn MLA +backend: FLASH_ATTN_MLA + +# Mode: decode_vs_prefill comparison (special sweep mode) +# For each batch spec, we'll test both decode and prefill pipelines +mode: "decode_vs_prefill" + +# Query lengths to test (from old benchmark_mla_threshold.py methodology) +# Each query length will be tested with BOTH decode and prefill pipelines: +# - decode: threshold >= query_length (forces decode pipeline) +# - prefill: threshold < query_length (forces prefill pipeline) +# +# We use qs1k format which creates q_len=N, seq_len=1024 requests +# This tests different query lengths with fixed sequence length context +# +# Using batch_spec_ranges for automatic generation: +batch_spec_ranges: + - template: "q{q_len}s1k" + q_len: + start: 1 + stop: 16 + step: 1 + end_inclusive: false + - template: "q{q_len}s1k" + q_len: + start: 16 + stop: 64 + step: 2 + end_inclusive: false + - template: "q{q_len}s1k" + q_len: + start: 64 + stop: 1024 + step: 4 + end_inclusive: true + +# Batch sizes to test (from old script) +batch_sizes: + - 1 + - 2 + - 4 + - 8 + - 16 + - 32 + - 64 + - 128 + - 256 + +# Model configuration (DeepSeek V2/V3 defaults) +model: + num_layers: 10 + head_dim: 576 + num_q_heads: 128 + num_kv_heads: 1 + block_size: 128 + +# Benchmark settings +device: "cuda:0" +repeats: 15 # More repeats for spec decode variance +warmup_iters: 5 +profile_memory: false + +# Output +output: + csv: "reorder_threshold_results.csv" + json: "reorder_threshold_results.json" + +# Expected outcome (reproduces old benchmark_mla_threshold.py study): +# - For each batch size, find the crossover point where prefill becomes faster than decode +# - Show decode vs prefill performance across all query lengths +# - Determine optimal reorder_batch_threshold based on last query length where decode is faster +# - Understand how crossover point varies with batch size +# - Provide data-driven guidance for default threshold value +# +# Methodology (from old script): +# - Each query length tested with BOTH pipelines: +# * decode: threshold >= query_length (forces decode pipeline) +# * prefill: threshold < query_length (forces prefill pipeline) +# - Compare which is faster to find crossover point +# diff --git a/benchmarks/attention_benchmarks/configs/speculative_decode.yaml b/benchmarks/attention_benchmarks/configs/speculative_decode.yaml new file mode 100644 index 0000000000000000000000000000000000000000..47b6d3604d1d256dcbfd9181cb6a8a2817f8dded --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/speculative_decode.yaml @@ -0,0 +1,61 @@ +# Speculative decoding benchmark configuration +# Tests reorder_batch_threshold optimization + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + +batch_specs: + # Pure speculative decode (K-token verification) + - "q2s1k" # 2-token spec, 1k KV + - "q4s1k" # 4-token spec, 1k KV + - "q8s1k" # 8-token spec, 1k KV + - "q16s1k" # 16-token spec, 1k KV + + # Speculative with different context lengths + - "q4s2k" # 4-token spec, 2k KV + - "q4s4k" # 4-token spec, 4k KV + - "q8s2k" # 8-token spec, 2k KV + - "q8s4k" # 8-token spec, 4k KV + + # Mixed: speculative + regular decode + - "32q4s1k" # 32 spec requests + - "16q4s1k_16q1s1k" # 16 spec + 16 
regular + - "8q8s2k_24q1s2k" # 8 spec (8-tok) + 24 regular + + # Mixed: speculative + prefill + decode + - "2q1k_16q4s1k_16q1s1k" # 2 prefill + 16 spec + 16 decode + - "4q2k_32q4s2k_32q1s2k" # 4 prefill + 32 spec + 32 decode + + # Large batches with speculation + - "64q4s1k" # 64 spec requests + - "32q8s2k" # 32 spec (8-token) + - "16q16s4k" # 16 spec (16-token) + +# Backends that support query length > 1 +backends: + - FLASH_ATTN_MLA # reorder_batch_threshold = 512 + - FLASHMLA # reorder_batch_threshold = 1 (tunable) + +# FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism +# - FLASHINFER_MLA + +# Benchmark settings +device: "cuda:0" +repeats: 10 # More repeats for statistical significance +warmup_iters: 5 +profile_memory: false + +# Test these threshold values for optimization +parameter_sweep: + param_name: "reorder_batch_threshold" + values: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + include_auto: false + label_format: "{backend}_threshold_{value}" diff --git a/benchmarks/attention_benchmarks/configs/standard_attention.yaml b/benchmarks/attention_benchmarks/configs/standard_attention.yaml new file mode 100644 index 0000000000000000000000000000000000000000..deb5a4b27ff3fc4362de880b65372e3814abbf5d --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/standard_attention.yaml @@ -0,0 +1,48 @@ +# Standard attention backend benchmark configuration + +model: + num_layers: 32 + num_q_heads: 32 + num_kv_heads: 8 # GQA with 4:1 ratio + head_dim: 128 + block_size: 16 + +batch_specs: + # Pure prefill + - "q512" # Small prefill (512 tokens) + - "q2k" # Medium prefill (2048 tokens) + - "q4k" # Large prefill (4096 tokens) + - "q8k" # Very large prefill (8192 tokens) + + # Pure decode + - "8q1s1k" # 8 requests, 1k KV cache each + - "16q1s2k" # 16 requests, 2k KV cache each + - "32q1s1k" # 32 requests, 1k KV cache each + - "64q1s4k" # 64 requests, 4k KV cache each + + # Mixed prefill/decode + - "2q2k_8q1s1k" # 2 prefill + 8 decode + - "4q1k_16q1s2k" # 4 prefill + 16 decode + - "2q4k_32q1s1k" # 2 large prefill + 32 decode + + # Speculative decode (q <= 8) + - "16q2s1k" # 16 requests, 2 spec tokens, 1k KV cache + - "16q4s1k" # 16 requests, 4 spec tokens, 1k KV cache + - "16q8s1k" # 16 requests, 8 spec tokens, 1k KV cache + - "32q4s2k" # 32 requests, 4 spec tokens, 2k KV cache + - "8q8s4k" # 8 requests, 8 spec tokens, 4k KV cache + + # Context extension (chunked prefill) + - "q1ks2k" # 1k query, 2k sequence + - "2q1ks4k" # 2 requests: 1k query, 4k sequence + +# Available backends: FLASH_ATTN, TRITON_ATTN, FLASHINFER +backends: + - FLASH_ATTN + - TRITON_ATTN + - FLASHINFER + +device: "cuda:0" +repeats: 5 +warmup_iters: 3 +profile_memory: false diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..ffcfa457217a4fbce0ac698218157bc2336769de --- /dev/null +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -0,0 +1,891 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MLA benchmark runner - shared utilities for MLA benchmarks. + +This module provides helpers for running MLA backends without +needing full VllmConfig integration. 
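+
+Typical usage (illustrative sketch; matches how benchmark.py calls into this
+module, where run_mla_benchmark(backend, config, **kwargs) returns a
+BenchmarkResult):
+
+    from mla_runner import run_mla_benchmark
+    result = run_mla_benchmark(config.backend, config)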
+""" + +import numpy as np +import torch +from batch_spec import parse_batch_spec +from common import ( + BenchmarkResult, + MockHfConfig, + MockIndexer, + MockKVBProj, + MockLayer, + setup_mla_dims, +) + +from vllm.config import ( + CacheConfig, + CompilationConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) + +# ============================================================================ +# VllmConfig Creation +# ============================================================================ + + +def _add_mock_methods_to_model_config(model_config: ModelConfig) -> None: + """ + Add mock methods for layer-specific queries to ModelConfig. + + These methods are needed by metadata builders but aren't normally + present on ModelConfig when used in benchmark contexts. + """ + import types + + model_config.get_num_layers = types.MethodType(lambda self: 1, model_config) + model_config.get_sliding_window_for_layer = types.MethodType( + lambda self, _i: None, model_config + ) + model_config.get_logits_soft_cap_for_layer = types.MethodType( + lambda self, _i: None, model_config + ) + model_config.get_sm_scale_for_layer = types.MethodType( + lambda self, _i: 1.0 / model_config.get_head_size() ** 0.5, model_config + ) + + +def create_minimal_vllm_config( + model_name: str = "deepseek-v3", + block_size: int = 128, + max_num_seqs: int = 256, + mla_dims: dict | None = None, + index_topk: int | None = None, +) -> VllmConfig: + """ + Create minimal VllmConfig for MLA benchmarks. + + Args: + model_name: Model name (deepseek-v2, deepseek-v3, etc.) - used if mla_dims not + provided + block_size: KV cache block size + max_num_seqs: Maximum number of sequences + mla_dims: Optional custom MLA dimensions dict. If not provided, uses + setup_mla_dims(model_name) + index_topk: Optional topk value for sparse MLA backends. If provided, + the config will include index_topk for sparse attention. 
+ + Returns: + VllmConfig for benchmarking + """ + # Get MLA dimensions - use provided or load from model name + if mla_dims is None: + mla_dims = setup_mla_dims(model_name) + + # Create mock HF config first (avoids downloading from HuggingFace) + mock_hf_config = MockHfConfig(mla_dims, index_topk=index_topk) + + # Create a temporary minimal config.json to avoid HF downloads + # This ensures consistent ModelConfig construction without network access + import json + import os + import shutil + import tempfile + + minimal_config = { + "architectures": ["DeepseekV2ForCausalLM"], + "model_type": "deepseek_v2", + "num_attention_heads": mla_dims["num_q_heads"], + "num_key_value_heads": mla_dims["num_kv_heads"], + "hidden_size": mla_dims["head_dim"] * mla_dims["num_q_heads"], + "torch_dtype": "bfloat16", + "max_position_embeddings": 163840, # DeepSeek V3 default + "rope_theta": 10000.0, + "vocab_size": 128256, + } + + # Create temporary directory with config.json + temp_dir = tempfile.mkdtemp(prefix="vllm_bench_") + config_path = os.path.join(temp_dir, "config.json") + with open(config_path, "w") as f: + json.dump(minimal_config, f) + + try: + # Create model config using local path - no HF downloads + model_config = ModelConfig( + model=temp_dir, # Use local temp directory + tokenizer=None, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + max_model_len=32768, + quantization=None, + enforce_eager=False, + max_logprobs=20, + disable_sliding_window=False, + skip_tokenizer_init=True, + served_model_name=None, + limit_mm_per_prompt=None, + config_format="auto", + ) + finally: + # Clean up temporary directory + shutil.rmtree(temp_dir, ignore_errors=True) + + # Override with our mock config + model_config.hf_config = mock_hf_config + model_config.hf_text_config = mock_hf_config + + # Add mock methods for layer-specific queries + _add_mock_methods_to_model_config(model_config) + + # Create sub-configs + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=False, + ) + + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=8192, + max_model_len=32768, + is_encoder_decoder=False, + enable_chunked_prefill=True, + ) + + parallel_config = ParallelConfig( + tensor_parallel_size=1, + ) + + compilation_config = CompilationConfig() + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + compilation_config=compilation_config, + ) + + +# ============================================================================ +# Backend Configuration +# ============================================================================ + + +# Backend-specific properties that can't be inferred from the backend class +# Keys are AttentionBackendEnum names (uppercase) +_BACKEND_PROPERTIES = { + "FLASHMLA": { + "query_format": "concat", # Single concatenated tensor (vs tuple) + }, + "FLASHMLA_SPARSE": { + "query_format": "concat", # Single concatenated tensor (vs tuple) + }, +} + + +def _get_backend_config(backend: str) -> dict: + """ + Get backend configuration from AttentionBackendEnum. + + Uses the registry to get the backend class and extract configuration + from its methods (get_impl_cls, get_builder_cls, is_sparse, etc.). 
+ + Args: + backend: Backend name matching AttentionBackendEnum exactly + (e.g., "FLASHMLA_SPARSE") + + Returns: + Dict with backend configuration + """ + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + try: + backend_enum = AttentionBackendEnum[backend] + backend_class = backend_enum.get_class() + except (KeyError, ValueError) as e: + valid_backends = [e.name for e in AttentionBackendEnum if e.name != "CUSTOM"] + raise ValueError( + f"Unknown backend: {backend}. " + f"Valid MLA backends: {[b for b in valid_backends if 'MLA' in b]}" + ) from e + + # Get block size from backend class + block_sizes = backend_class.get_supported_kernel_block_sizes() + # Use first supported block size (backends typically support one for MLA) + block_size = block_sizes[0] if block_sizes else None + if hasattr(block_size, "value"): + # Handle MultipleOf enum + block_size = None + + # Check if sparse via class method if available + is_sparse = getattr(backend_class, "is_sparse", lambda: False)() + + # Get properties that can't be inferred + props = _BACKEND_PROPERTIES.get(backend, {}) + + return { + "backend_class": backend_class, + "impl_class": backend_class.get_impl_cls(), + "builder_class": backend_class.get_builder_cls(), + "query_format": props.get("query_format", "tuple"), + "block_size": block_size, + "is_sparse": is_sparse, + } + + +# ============================================================================ +# Metadata Building Helpers +# ============================================================================ + + +def _build_attention_metadata( + requests: list, + block_size: int, + device: torch.device, + builder_instance, +) -> tuple: + """ + Build attention metadata from batch requests. + + Args: + requests: List of BatchRequest objects + block_size: KV cache block size + device: Target device + builder_instance: Metadata builder instance + + Returns: + Tuple of (metadata, kv_cache_num_blocks) + """ + q_lens = [r.q_len for r in requests] + kv_lens = [r.kv_len for r in requests] + total_q = sum(q_lens) + max_kv = max(kv_lens) + + # Build query start locations + q_start_cpu = torch.tensor( + [0] + [sum(q_lens[: i + 1]) for i in range(len(q_lens))], + dtype=torch.int32, + ) + q_start_gpu = q_start_cpu.to(device) + + # Build sequence lengths + seq_lens_cpu = torch.tensor(kv_lens, dtype=torch.int32) + seq_lens_gpu = seq_lens_cpu.to(device) + + # Build num_computed_tokens (context length for each request) + context_lens = [kv_len - q_len for q_len, kv_len in zip(q_lens, kv_lens)] + num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + + # Build block table + num_blocks_per_req = [(kv + block_size - 1) // block_size for kv in kv_lens] + max_num_blocks = max(num_blocks_per_req) + + block_table_cpu = np.zeros((len(requests), max_num_blocks), dtype=np.int32) + current_block = 0 + for i, num_blocks in enumerate(num_blocks_per_req): + for j in range(num_blocks): + block_table_cpu[i, j] = current_block + current_block += 1 + + block_table_gpu = torch.from_numpy(block_table_cpu).to(device) + + # Build slot mapping + slot_mapping_list = [] + for i, (q_len, kv_len, num_blocks) in enumerate( + zip(q_lens, kv_lens, num_blocks_per_req) + ): + context_len = kv_len - q_len + for j in range(q_len): + token_kv_idx = context_len + j + block_idx = token_kv_idx // block_size + offset_in_block = token_kv_idx % block_size + global_block_id = block_table_cpu[i, block_idx] + slot_id = global_block_id * block_size + offset_in_block + slot_mapping_list.append(slot_id) + + 
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64, device=device) + + # Create CommonAttentionMetadata + from vllm.v1.attention.backends.utils import CommonAttentionMetadata + + common_attn_metadata = CommonAttentionMetadata( + num_reqs=len(requests), + max_query_len=max(q_lens), + max_seq_len=max_kv, + num_actual_tokens=total_q, + query_start_loc=q_start_gpu, + query_start_loc_cpu=q_start_cpu, + seq_lens=seq_lens_gpu, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + slot_mapping=slot_mapping, + block_table_tensor=block_table_gpu, + dcp_local_seq_lens=None, + ) + + # Use the production build() method + metadata = builder_instance.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=False, + ) + + return metadata, current_block + + +def _create_input_tensors( + total_q: int, + mla_dims: dict, + query_format: str, + device: torch.device, + dtype: torch.dtype, +): + """ + Create input tensors for both decode and prefill modes. + + MLA requires different tensor formats for decode vs prefill: + - Decode: Uses kv_lora_rank (512) dimension + - Prefill: Uses qk_nope_head_dim (128) to stay under FlashAttention's 256 limit + + Args: + total_q: Total number of query tokens + mla_dims: MLA dimension configuration + query_format: Either "tuple" or "concat" + device: Target device + dtype: Tensor dtype + + Returns: + Tuple of (decode_inputs, prefill_inputs) + - decode_inputs: Query tensor(s) for decode mode + - prefill_inputs: Dict with 'q', 'k_c_normed', 'k_pe', 'k_scale' for prefill + """ + if query_format == "tuple": + # Decode mode format: (q_nope, q_pe) where q_nope has kv_lora_rank dim + q_nope_decode = torch.randn( + total_q, + mla_dims["num_q_heads"], + mla_dims["kv_lora_rank"], + device=device, + dtype=dtype, + ) + q_pe = torch.randn( + total_q, + mla_dims["num_q_heads"], + mla_dims["qk_rope_head_dim"], + device=device, + dtype=dtype, + ) + decode_inputs = (q_nope_decode, q_pe) + + # For prefill, we need q with qk_nope_head_dim instead of kv_lora_rank + q_nope_prefill = torch.randn( + total_q, + mla_dims["num_q_heads"], + mla_dims["qk_nope_head_dim"], + device=device, + dtype=dtype, + ) + prefill_q = torch.cat([q_nope_prefill, q_pe], dim=-1) + else: # concat + decode_inputs = torch.randn( + total_q, + mla_dims["num_q_heads"], + mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"], + device=device, + dtype=dtype, + ) + # For prefill with concat format + prefill_q = torch.randn( + total_q, + mla_dims["num_q_heads"], + mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"], + device=device, + dtype=dtype, + ) + + # Create additional inputs needed for prefill forward + k_c_normed = torch.randn( + total_q, + mla_dims["kv_lora_rank"], + device=device, + dtype=dtype, + ) + k_pe = torch.randn( + total_q, + 1, # Single head for MLA + mla_dims["qk_rope_head_dim"], + device=device, + dtype=dtype, + ) + k_scale = torch.ones(1, device=device, dtype=torch.float32) + + output = torch.zeros( + total_q, + mla_dims["num_q_heads"] * mla_dims["v_head_dim"], + device=device, + dtype=dtype, + ) + + prefill_inputs = { + "q": prefill_q, + "k_c_normed": k_c_normed, + "k_pe": k_pe, + "k_scale": k_scale, + "output": output, + } + + return decode_inputs, prefill_inputs + + +# ============================================================================ +# Backend Initialization +# ============================================================================ + + +def _create_backend_impl( + backend_cfg: dict, + mla_dims: dict, + 
vllm_config: VllmConfig, + device: torch.device, + max_num_tokens: int = 8192, + index_topk: int | None = None, +): + """ + Create backend implementation instance. + + Args: + backend_cfg: Backend configuration dict from _get_backend_config() + mla_dims: MLA dimension configuration + vllm_config: VllmConfig instance + device: Target device + max_num_tokens: Maximum number of tokens for sparse indexer buffer + index_topk: Topk value for sparse MLA backends + + Returns: + Tuple of (impl, layer, builder_instance, indexer) + """ + # Get classes from backend config (already resolved by _get_backend_config) + impl_class = backend_cfg["impl_class"] + builder_class = backend_cfg["builder_class"] + + # Calculate scale + scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]) + + # Create mock kv_b_proj layer for prefill mode + mock_kv_b_proj = MockKVBProj( + num_heads=mla_dims["num_q_heads"], + qk_nope_head_dim=mla_dims["qk_nope_head_dim"], + v_head_dim=mla_dims["v_head_dim"], + ) + + # Create indexer for sparse backends + indexer = None + if backend_cfg.get("is_sparse", False): + if index_topk is None: + index_topk = 2048 # Default topk for sparse MLA + indexer = MockIndexer( + max_num_tokens=max_num_tokens, + topk_tokens=index_topk, + device=device, + ) + + # Build impl kwargs + impl_kwargs = { + "num_heads": mla_dims["num_q_heads"], + "head_size": mla_dims["head_dim"], + "scale": scale, + "num_kv_heads": mla_dims["num_kv_heads"], + "alibi_slopes": None, + "sliding_window": None, + "kv_cache_dtype": "auto", + "logits_soft_cap": None, + "attn_type": "decoder", + "kv_sharing_target_layer_name": None, + "q_lora_rank": None, + "kv_lora_rank": mla_dims["kv_lora_rank"], + "qk_nope_head_dim": mla_dims["qk_nope_head_dim"], + "qk_rope_head_dim": mla_dims["qk_rope_head_dim"], + "qk_head_dim": mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"], + "v_head_dim": mla_dims["v_head_dim"], + "kv_b_proj": mock_kv_b_proj, + } + + # Add indexer for sparse backends + if indexer is not None: + impl_kwargs["indexer"] = indexer + + # Create impl + impl = impl_class(**impl_kwargs) + + # Initialize DCP attributes + if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1): + impl.dcp_world_size = 1 + impl.dcp_rank = 0 + + # Create KV cache spec for MockLayer + from vllm.v1.kv_cache_interface import FullAttentionSpec + + kv_cache_spec = FullAttentionSpec( + block_size=backend_cfg["block_size"] or vllm_config.cache_config.block_size, + num_kv_heads=1, # MLA uses 1 KV head + head_size=576, # MLA head dim + dtype=torch.bfloat16, + ) + + # Create mock layer + layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec) + + # Create builder instance if needed + builder_instance = None + if builder_class: + # Populate static_forward_context so builder can find the layer + # MockLayer inherits from AttentionLayerBase, so isinstance checks pass + vllm_config.compilation_config.static_forward_context = {"placeholder": layer} + + builder_instance = builder_class( + kv_cache_spec=kv_cache_spec, + layer_names=["placeholder"], + vllm_config=vllm_config, + device=device, + ) + + return impl, layer, builder_instance, indexer + + +# ============================================================================ +# Config Helpers +# ============================================================================ + + +def _extract_mla_dims_from_config(config) -> dict | None: + """ + Extract MLA dimensions from BenchmarkConfig if all required fields are present. 
+ + Args: + config: BenchmarkConfig instance + + Returns: + Dict with MLA dimensions if all fields are provided, None otherwise + """ + # Check if all MLA-specific fields are provided + if all( + [ + config.kv_lora_rank is not None, + config.qk_nope_head_dim is not None, + config.qk_rope_head_dim is not None, + config.v_head_dim is not None, + ] + ): + return { + "kv_lora_rank": config.kv_lora_rank, + "qk_nope_head_dim": config.qk_nope_head_dim, + "qk_rope_head_dim": config.qk_rope_head_dim, + "v_head_dim": config.v_head_dim, + "num_q_heads": config.num_q_heads, + "num_kv_heads": config.num_kv_heads, + "head_dim": config.head_dim, + } + # Fallback: if MLA fields not fully specified, try to construct from basic fields + elif config.head_dim == 576: + # This looks like a DeepSeek MLA config, use standard dimensions with custom + # head count + return { + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "num_q_heads": config.num_q_heads, + "num_kv_heads": config.num_kv_heads, + "head_dim": config.head_dim, + } + return None + + +# ============================================================================ +# Benchmark Execution +# ============================================================================ + + +def _run_single_benchmark( + config, + impl, + layer, + builder_instance, + backend_cfg: dict, + mla_dims: dict, + device: torch.device, + indexer=None, +) -> BenchmarkResult: + """ + Run a single benchmark iteration. + + Args: + config: BenchmarkConfig instance + impl: Backend implementation instance + layer: MockLayer instance + builder_instance: Metadata builder instance + backend_cfg: Backend configuration dict + mla_dims: MLA dimension configuration + device: Target device + indexer: Optional MockIndexer for sparse backends + + Returns: + BenchmarkResult with timing statistics + """ + # Parse batch spec + requests = parse_batch_spec(config.batch_spec) + q_lens = [r.q_len for r in requests] + kv_lens = [r.kv_len for r in requests] + total_q = sum(q_lens) + max_kv_len = max(kv_lens) + + # Determine block size + block_size = backend_cfg["block_size"] or config.block_size + + # Build metadata + metadata, num_blocks = _build_attention_metadata( + requests, block_size, device, builder_instance + ) + + # Create KV cache + kv_cache = torch.zeros( + num_blocks, + block_size, + mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"], + device=device, + dtype=torch.bfloat16, + ) + + # Create input tensors for both decode and prefill modes + decode_inputs, prefill_inputs = _create_input_tensors( + total_q, + mla_dims, + backend_cfg["query_format"], + device, + torch.bfloat16, + ) + + # Fill indexer with random indices for sparse backends + is_sparse = backend_cfg.get("is_sparse", False) + if is_sparse and indexer is not None: + indexer.fill_random_indices(total_q, max_kv_len) + + # Determine which forward method to use + if is_sparse: + # Sparse backends use forward_mqa + forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) + elif metadata.decode is not None: + forward_fn = lambda: impl._forward_decode( + decode_inputs, kv_cache, metadata, layer + ) + elif metadata.prefill is not None: + forward_fn = lambda: impl._forward_prefill( + prefill_inputs["q"], + prefill_inputs["k_c_normed"], + prefill_inputs["k_pe"], + kv_cache, + metadata, + prefill_inputs["k_scale"], + prefill_inputs["output"], + ) + else: + raise RuntimeError("Metadata has neither decode nor prefill metadata") + + # Warmup + for _ in 
range(config.warmup_iters): + forward_fn() + torch.cuda.synchronize() + + # Benchmark + times = [] + for _ in range(config.repeats): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(config.num_layers): + forward_fn() + end.record() + + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) + times.append(elapsed_ms / 1000.0 / config.num_layers) + + mean_time = float(np.mean(times)) + return BenchmarkResult( + config=config, + mean_time=mean_time, + std_time=float(np.std(times)), + min_time=float(np.min(times)), + max_time=float(np.max(times)), + throughput_tokens_per_sec=total_q / mean_time if mean_time > 0 else 0, + ) + + +def _run_mla_benchmark_batched( + backend: str, + configs_with_params: list[tuple], # [(config, threshold, num_splits), ...] + index_topk: int = 2048, +) -> list[BenchmarkResult]: + """ + Unified batched MLA benchmark runner for all backends. + + Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla, + flashinfer_mla_sparse, flashmla_sparse + + This function reuses backend initialization across multiple benchmarks + to avoid setup/teardown overhead. + + Args: + backend: Backend name + configs_with_params: List of (config, threshold, num_splits) tuples + - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only) + - num_splits: num_kv_splits (CUTLASS only) + index_topk: Topk value for sparse MLA backends (default 2048) + + Returns: + List of BenchmarkResult objects + """ + if not configs_with_params: + return [] + + backend_cfg = _get_backend_config(backend) + device = torch.device(configs_with_params[0][0].device) + torch.cuda.set_device(device) + + # Determine block size + config_block_size = configs_with_params[0][0].block_size + block_size = backend_cfg["block_size"] or config_block_size + + # Extract MLA dimensions from the first config + first_config = configs_with_params[0][0] + mla_dims = _extract_mla_dims_from_config(first_config) + + # If config didn't provide MLA dims, fall back to default model + if mla_dims is None: + mla_dims = setup_mla_dims("deepseek-v3") + + # Determine if this is a sparse backend + is_sparse = backend_cfg.get("is_sparse", False) + + # Create and set vLLM config for MLA (reused across all benchmarks) + vllm_config = create_minimal_vllm_config( + model_name="deepseek-v3", # Used only for model path + block_size=block_size, + mla_dims=mla_dims, # Use custom dims from config or default + index_topk=index_topk if is_sparse else None, + ) + + results = [] + + with set_current_vllm_config(vllm_config): + # Create backend impl, layer, builder, and indexer (reused across benchmarks) + impl, layer, builder_instance, indexer = _create_backend_impl( + backend_cfg, + mla_dims, + vllm_config, + device, + index_topk=index_topk if is_sparse else None, + ) + + # Run each benchmark with the shared impl + for config, threshold, num_splits in configs_with_params: + # Set threshold for this benchmark (FlashAttn/FlashMLA only) + original_threshold = None + if threshold is not None and builder_instance: + original_threshold = builder_instance.reorder_batch_threshold + builder_instance.reorder_batch_threshold = threshold + + # Set num_splits for CUTLASS + original_num_splits = None + if num_splits is not None and hasattr(impl, "_num_kv_splits"): + original_num_splits = impl._num_kv_splits + impl._num_kv_splits = num_splits + + try: + result = _run_single_benchmark( + config, + impl, + layer, + builder_instance, + backend_cfg, + mla_dims, + device, + 
indexer=indexer, + ) + results.append(result) + + finally: + # Restore original threshold + if original_threshold is not None: + builder_instance.reorder_batch_threshold = original_threshold + + # Restore original num_splits + if original_num_splits is not None: + impl._num_kv_splits = original_num_splits + + return results + + +# ============================================================================ +# Public API +# ============================================================================ + + +def run_mla_benchmark( + backend: str, + config, + reorder_batch_threshold: int | None = None, + num_kv_splits: int | None = None, + index_topk: int = 2048, +) -> BenchmarkResult | list[BenchmarkResult]: + """ + Unified MLA benchmark runner for all backends. + + Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla, + flashinfer_mla_sparse, flashmla_sparse + + Always uses batched execution internally for optimal performance. + + Args: + backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla, + flashinfer_mla_sparse, flashmla_sparse) + config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples + reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA + (single config mode only) + num_kv_splits: Number of KV splits for CUTLASS (single config mode only) + index_topk: Topk value for sparse MLA backends (default 2048) + + Returns: + BenchmarkResult (single mode) or list of BenchmarkResult (batched mode) + """ + # Normalize to batched mode: (config, threshold, num_splits) + if isinstance(config, list): + # Already in batched format + if len(config) > 0 and isinstance(config[0], tuple): + # Format: [(cfg, param), ...] where param is threshold or num_splits + if backend in ("flashattn_mla", "flashmla", "flashmla_sparse"): + configs_with_params = [(cfg, param, None) for cfg, param in config] + else: # cutlass_mla, flashinfer_mla, or sparse backends + configs_with_params = [(cfg, None, param) for cfg, param in config] + else: + # Format: [cfg, ...] - just configs + configs_with_params = [(cfg, None, None) for cfg in config] + return_single = False + else: + # Single config: convert to batched format + configs_with_params = [(config, reorder_batch_threshold, num_kv_splits)] + return_single = True + + # Use unified batched execution + results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk) + + # Return single result or list based on input + return results[0] if return_single else results diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..6457a599ab9182dc046fe5da7939473cb2629032 --- /dev/null +++ b/benchmarks/attention_benchmarks/runner.py @@ -0,0 +1,539 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Standard attention benchmark runner - shared utilities for non-MLA benchmarks. + +This module provides helpers for running standard attention backends +(FlashAttention, Triton, FlashInfer) with real vLLM integration. 
+""" + +import logging +import types +from contextlib import contextmanager + +import numpy as np +import torch +from batch_spec import parse_batch_spec, reorder_for_flashinfer +from common import BenchmarkConfig, BenchmarkResult, MockLayer, get_attention_scale + +from vllm.config import ( + CacheConfig, + CompilationConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + get_kv_cache_layout, + set_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec + +# ============================================================================ +# Backend Configuration +# ============================================================================ + + +def _get_backend_config(backend: str) -> dict: + """ + Get backend configuration from AttentionBackendEnum. + + Args: + backend: Backend name matching AttentionBackendEnum exactly + (e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER") + + Returns: + Dict with backend_class + """ + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + try: + backend_enum = AttentionBackendEnum[backend] + backend_class = backend_enum.get_class() + except (KeyError, ValueError) as e: + valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"] + raise ValueError( + f"Unknown backend: {backend}. Valid backends: {valid_backends}" + ) from e + + return {"backend_class": backend_class} + + +@contextmanager +def log_warnings_and_errors_only(): + """Temporarily set vLLM logger to WARNING level.""" + logger = logging.getLogger("vllm") + old_level = logger.level + logger.setLevel(logging.WARNING) + try: + yield + finally: + logger.setLevel(old_level) + + +# ============================================================================ +# Metadata Building Helpers +# ============================================================================ + + +def _build_common_attn_metadata( + q_lens: list[int], + kv_lens: list[int], + block_size: int, + device: torch.device, +) -> CommonAttentionMetadata: + """Build CommonAttentionMetadata from query/kv lengths.""" + batch_size = len(q_lens) + total_tokens = sum(q_lens) + + query_start_loc = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + query_start_loc[1:] = torch.tensor(q_lens, dtype=torch.int32, device=device).cumsum( + 0 + ) + query_start_loc_cpu = query_start_loc.cpu() + + seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device) + max_seq_len = int(seq_lens.max().item()) + + max_blocks = (max(kv_lens) + block_size - 1) // block_size + num_blocks = batch_size * max_blocks + block_table_tensor = torch.arange( + num_blocks, dtype=torch.int32, device=device + ).view(batch_size, max_blocks) + slot_mapping = torch.arange(total_tokens, dtype=torch.int64, device=device) + + max_query_len = max(q_lens) + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + num_reqs=batch_size, + num_actual_tokens=total_tokens, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + causal=True, + ) + + +def _create_vllm_config( + config: BenchmarkConfig, + max_num_blocks: int, +) -> VllmConfig: + """Create a VllmConfig for benchmarking with mock model methods.""" + model_config = ModelConfig( + model="meta-llama/Meta-Llama-3-8B", + tokenizer="meta-llama/Meta-Llama-3-8B", + 
trust_remote_code=False, + dtype="auto", # Use model's native dtype + seed=0, + max_model_len=1024, + ) + + cache_config = CacheConfig( + block_size=config.block_size, + cache_dtype="auto", + swap_space=0, + ) + cache_config.num_gpu_blocks = max_num_blocks + cache_config.num_cpu_blocks = 0 + + parallel_config = ParallelConfig(tensor_parallel_size=1) + scheduler_config = SchedulerConfig( + max_num_seqs=256, + max_num_batched_tokens=8192, + max_model_len=8192, + is_encoder_decoder=False, + enable_chunked_prefill=True, + ) + device_config = DeviceConfig() + load_config = LoadConfig() + compilation_config = CompilationConfig() + + # Add mock methods for benchmark config values + model_config.get_num_layers = types.MethodType( + lambda self: config.num_layers, model_config + ) + model_config.get_sliding_window_for_layer = types.MethodType( + lambda self, i: None, model_config + ) + model_config.get_logits_soft_cap_for_layer = types.MethodType( + lambda self, i: 0.0, model_config + ) + model_config.get_sm_scale_for_layer = types.MethodType( + lambda self, i: 1.0 / config.head_dim**0.5, model_config + ) + model_config.get_num_attention_heads = types.MethodType( + lambda self, parallel_config=None: config.num_q_heads, model_config + ) + model_config.get_num_kv_heads = types.MethodType( + lambda self, parallel_config=None: config.num_kv_heads, model_config + ) + model_config.get_head_size = types.MethodType( + lambda self: config.head_dim, model_config + ) + model_config.get_sliding_window = types.MethodType(lambda self: None, model_config) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + compilation_config=compilation_config, + ) + + +# ============================================================================ +# Backend Initialization +# ============================================================================ + + +def _create_backend_impl( + backend_cfg: dict, + config: BenchmarkConfig, + device: torch.device, + dtype: torch.dtype, +): + """Create backend implementation instance.""" + backend_class = backend_cfg["backend_class"] + + scale = get_attention_scale(config.head_dim) + + impl = backend_class.get_impl_cls()( + num_heads=config.num_q_heads, + head_size=config.head_dim, + scale=scale, + num_kv_heads=config.num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + kv_cache_spec = FullAttentionSpec( + block_size=config.block_size, + num_kv_heads=config.num_kv_heads, + head_size=config.head_dim, + dtype=dtype, + ) + + layer = MockLayer(device, kv_cache_spec=kv_cache_spec) + + return backend_class, impl, layer + + +def _create_metadata_builder( + backend_class, + kv_cache_spec: FullAttentionSpec, + vllm_config: VllmConfig, + device: torch.device, + backend_name: str = "", +): + """Create metadata builder instance.""" + layer_names = ["layer_0"] + builder_cls = backend_class.get_builder_cls() + + # Flashinfer needs get_per_layer_parameters mocked since we don't have + # real model layers registered + if backend_name == "FLASHINFER": + import unittest.mock + + from vllm.v1.attention.backends.utils import PerLayerParameters + + def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): + head_size = vllm_config.model_config.get_head_size() + return { + layer_name: PerLayerParameters( + window_left=-1, # No sliding window + logits_soft_cap=0.0, # No soft cap + sm_scale=1.0 / 
(head_size**0.5), # Standard scale + ) + for layer_name in layer_names + } + + with unittest.mock.patch( + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + return builder_cls( + kv_cache_spec=kv_cache_spec, + layer_names=layer_names, + vllm_config=vllm_config, + device=device, + ) + + return builder_cls( + kv_cache_spec=kv_cache_spec, + layer_names=layer_names, + vllm_config=vllm_config, + device=device, + ) + + +# ============================================================================ +# Tensor Creation Helpers +# ============================================================================ + + +def _create_input_tensors( + config: BenchmarkConfig, + total_q: int, + device: torch.device, + dtype: torch.dtype, +) -> tuple: + """Create Q, K, V input tensors for all layers.""" + q_list = [ + torch.randn( + total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype + ) + for _ in range(config.num_layers) + ] + k_list = [ + torch.randn( + total_q, config.num_kv_heads, config.head_dim, device=device, dtype=dtype + ) + for _ in range(config.num_layers) + ] + v_list = [ + torch.randn( + total_q, config.num_kv_heads, config.head_dim, device=device, dtype=dtype + ) + for _ in range(config.num_layers) + ] + return q_list, k_list, v_list + + +def _create_kv_cache( + config: BenchmarkConfig, + max_num_blocks: int, + backend_class, + device: torch.device, + dtype: torch.dtype, +) -> list: + """Create KV cache tensors for all layers using the backend's methods. + + Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order() + to create the cache with the correct shape and memory layout. + """ + # Get the logical shape from the backend + cache_shape = backend_class.get_kv_cache_shape( + num_blocks=max_num_blocks, + block_size=config.block_size, + num_kv_heads=config.num_kv_heads, + head_size=config.head_dim, + ) + + # Get the stride order for custom memory layout + try: + stride_order = backend_class.get_kv_cache_stride_order() + assert len(stride_order) == len(cache_shape) + except (AttributeError, NotImplementedError): + stride_order = tuple(range(len(cache_shape))) + + # Permute shape to physical layout order + physical_shape = tuple(cache_shape[i] for i in stride_order) + + # Compute inverse permutation to get back to logical view + inv_order = [stride_order.index(i) for i in range(len(stride_order))] + + cache_list = [] + for _ in range(config.num_layers): + # Allocate in physical layout order (contiguous in memory) + cache = torch.zeros(*physical_shape, device=device, dtype=dtype) + # Permute to logical view + cache = cache.permute(*inv_order) + cache_list.append(cache) + + return cache_list + + +# ============================================================================ +# Benchmark Execution +# ============================================================================ + + +def _run_single_benchmark( + config: BenchmarkConfig, + impl, + layer, + q_list: list, + k_list: list, + v_list: list, + cache_list: list, + attn_metadata, + device: torch.device, + dtype: torch.dtype, +) -> tuple: + """Run single benchmark iteration with warmup and timing loop.""" + total_q = q_list[0].shape[0] + out = torch.empty( + total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype + ) + + # Warmup + for _ in range(config.warmup_iters): + for i in range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + torch.cuda.synchronize() 
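+    # Note: CUDA events are recorded asynchronously on the current stream, so
+    # start.elapsed_time(end) is only valid after the synchronize() call in the
+    # timing loop below; each measurement is divided by config.num_layers to
+    # report a per-layer latency.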
+ + # Benchmark + times = [] + for _ in range(config.repeats): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for i in range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + end.record() + + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) + times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer + + mem_stats = {} + if config.profile_memory: + mem_stats = { + "allocated_mb": torch.cuda.memory_allocated(device) / 1024**2, + "reserved_mb": torch.cuda.memory_reserved(device) / 1024**2, + } + + return times, mem_stats + + +# ============================================================================ +# Public API +# ============================================================================ + + +def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: + """ + Run standard attention benchmark with real kernels. + + Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER + + Args: + config: Benchmark configuration + + Returns: + BenchmarkResult with timing and memory statistics + """ + device = torch.device(config.device) + torch.cuda.set_device(device) + + backend_cfg = _get_backend_config(config.backend) + + requests = parse_batch_spec(config.batch_spec) + + if config.backend == "FLASHINFER": + requests = reorder_for_flashinfer(requests) + + q_lens = [r.q_len for r in requests] + kv_lens = [r.kv_len for r in requests] + total_q = sum(q_lens) + max_kv = max(kv_lens) + batch_size = len(q_lens) + + # Calculate total blocks needed: batch_size * max_blocks_per_request + max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size + max_num_blocks = batch_size * max_blocks_per_request + + # Suppress vLLM logs during setup to reduce spam + with log_warnings_and_errors_only(): + # Create vllm_config first - uses model's native dtype via "auto" + vllm_config = _create_vllm_config(config, max_num_blocks) + dtype = vllm_config.model_config.dtype + + # Wrap everything in set_current_vllm_config context + # This is required for backends like flashinfer that need global config + with set_current_vllm_config(vllm_config): + backend_class, impl, layer = _create_backend_impl( + backend_cfg, config, device, dtype + ) + + # Set KV cache layout if the backend requires a specific one + # (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention) + required_layout = backend_class.get_required_kv_cache_layout() + if required_layout is not None: + set_kv_cache_layout(required_layout) + get_kv_cache_layout.cache_clear() + + common_metadata = _build_common_attn_metadata( + q_lens, kv_lens, config.block_size, device + ) + + kv_cache_spec = FullAttentionSpec( + block_size=config.block_size, + num_kv_heads=config.num_kv_heads, + head_size=config.head_dim, + dtype=dtype, + ) + + builder = _create_metadata_builder( + backend_class, kv_cache_spec, vllm_config, device, config.backend + ) + + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_metadata, + ) + + q_list, k_list, v_list = _create_input_tensors( + config, total_q, device, dtype + ) + + cache_list = _create_kv_cache( + config, max_num_blocks, backend_class, device, dtype + ) + + times, mem_stats = _run_single_benchmark( + config, + impl, + layer, + q_list, + k_list, + v_list, + cache_list, + attn_metadata, + device, + dtype, + ) + + mean_time = np.mean(times) + throughput = total_q / mean_time if mean_time > 0 
else 0 + + return BenchmarkResult( + config=config, + mean_time=mean_time, + std_time=np.std(times), + min_time=np.min(times), + max_time=np.max(times), + throughput_tokens_per_sec=throughput, + memory_allocated_mb=mem_stats.get("allocated_mb"), + memory_reserved_mb=mem_stats.get("reserved_mb"), + ) diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9a9600e08dafeccbfeff11ae3450c83b22c9f999 --- /dev/null +++ b/benchmarks/auto_tune/README.md @@ -0,0 +1,218 @@ +# Automated vLLM Server Parameter Tuning + +This script automates the process of finding the optimal server parameter combination (`max-num-seqs` and `max-num-batched-tokens`) to maximize throughput for a vLLM server. It also supports additional constraints such as E2E latency and prefix cache hit rate. + +## Table of Contents + +- [Prerequisites](#prerequisites) +- [Configuration](#configuration) +- [How to Run](#how-to-run) +- [Example Use Cases](#example-use-cases) +- [Output](#output) +- [How It Works](#how-it-works) + +## Prerequisites + +Before running the script, please ensure the following steps are completed: + +1. **Clone vLLM & Set Up Branch**: Clone the vLLM repository and check out to your desired branch. + +```bash +git clone https://github.com/vllm-project/vllm.git +cd vllm +# git checkout +``` + +1. **Install Environment**: Install or update the correct running environment. For TPU usage, activate your `conda` environment and install the corresponding `torch` and `torch_xla` versions. + +2. **Model Configuration**: If you are using a customized model, ensure its configuration files are correctly placed and accessible. + +## Configuration + +You must set the following variables at the top of the script before execution. + + Note: You can also override the default values below via environment variables when running the script. + +```bash +MODEL=meta-llama/Llama-3.3-70B-Instruct SYSTEM=TPU TP=8 DOWNLOAD_DIR='' INPUT_LEN=128 OUTPUT_LEN=2048 MAX_MODEL_LEN=2300 MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 NUM_SEQS_LIST="128 256" NUM_BATCHED_TOKENS_LIST="1024 2048 4096" VLLM_LOGGING_LEVEL=DEBUG bash auto_tune.sh +``` + +| Variable | Description | Example Value | +| --- | --- | --- | +| `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` | +| `MODEL` | **Required.** The Hugging Face model identifier to be served by vllm. | `"meta-llama/Llama-3.1-8B-Instruct"` | +| `SYSTEM`| **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (For other systems, it might not support saving profiles) | `"TPU"` | +| `TP` | **Required.** The tensor-parallelism size. | `1` | +| `DOWNLOAD_DIR` | **Required.** Directory to download and load model weights from. | `""` (default download path) | +| `INPUT_LEN` | **Required.** Request input length. | `4000` | +| `OUTPUT_LEN` | **Required.** Request output length. | `16` | +| `MAX_MODEL_LEN` | **Required.** Max model length. | `4096` | +| `MIN_CACHE_HIT_PCT` | Prefix cache hit rate in percentage (0-100). Set to `0` to disable. | `60` | +| `MAX_LATENCY_ALLOWED_MS` | The maximum allowed P99 end-to-end latency in milliseconds. Set to a very large number (e.g., `100000000000`) to effectively ignore the latency constraint. | `500` | +| `NUM_SEQS_LIST` | A space-separated string of `max-num-seqs` values to test. | `"128 256"` | +| `NUM_BATCHED_TOKENS_LIST` | A space-separated string of `max-num-batched-tokens` values to test. 
| `"1024 2048 4096"` |
+
+**Note**: The default `NUM_SEQS_LIST` and `NUM_BATCHED_TOKENS_LIST` are set for medium-sized inputs/outputs. For very short contexts (e.g., 20 input, 20 output tokens), you may need to test larger values for `max-num-seqs`.
+
+## How to Run
+
+1. **Configure**: Edit the script and set the variables in the [Configuration](#configuration) section.
+2. **Execute**: Run the script. Since the process can take a long time, it is highly recommended to use a terminal multiplexer like `tmux` or `screen` to prevent the script from stopping if your connection is lost.
+
+```bash
+cd <path-to-your-vllm-repo>/benchmarks/auto_tune
+bash auto_tune.sh
+```
+
+    Note that the full or partial path used to run `bash auto_tune.sh` must not contain the keyword `vllm`; otherwise the `pkill -f vllm` cleanup command will also kill this script itself.
+
+## Example Use Cases
+
+Here are a few examples of how to configure the script for different goals:
+
+### 1. Maximize Throughput (No Latency Constraint)
+
+- **Goal**: Find the best `max-num-seqs` and `max-num-batched-tokens` to get the highest possible throughput for 1800 input tokens and 20 output tokens.
+- **Configuration**:
+
+```bash
+INPUT_LEN=1800
+OUTPUT_LEN=20
+MAX_MODEL_LEN=2048
+MIN_CACHE_HIT_PCT=0
+MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number
+```
+
+### 2. Maximize Throughput with a Latency Requirement
+
+- **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms.
+- **Configuration**:
+
+```bash
+INPUT_LEN=1800
+OUTPUT_LEN=20
+MAX_MODEL_LEN=2048
+MIN_CACHE_HIT_PCT=0
+MAX_LATENCY_ALLOWED_MS=500
+```
+
+### 3. Maximize Throughput with Prefix Caching and Latency Requirements
+
+- **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
+- **Configuration**:
+
+```bash
+INPUT_LEN=1800
+OUTPUT_LEN=20
+MAX_MODEL_LEN=2048
+MIN_CACHE_HIT_PCT=60
+MAX_LATENCY_ALLOWED_MS=500
+```
+
+## Output
+
+After the script finishes, you will find the results in a new, timestamped directory created inside `$BASE/auto-benchmark/`.
+
+- **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run:
+    - `vllm_log_...txt`: The log output from the vLLM server for each parameter combination.
+    - `bm_log_...txt`: The log output from the `vllm bench serve` command for each benchmark run.
+
+- **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found.
+
+```text
+# Example result.txt content
+hash:a1b2c3d4...
+max_num_seqs: 128, max_num_batched_tokens: 2048, request_rate: 10.0, e2el: 450.5, throughput: 9.8, goodput: 9.8
+max_num_seqs: 128, max_num_batched_tokens: 4096 does not meet latency requirement 500
+...
+best_max_num_seqs: 256, best_num_batched_tokens: 2048, best_throughput: 12.5, profile saved in: /home/user/vllm/auto-benchmark/2024_08_01_10_30/profile
+```
+
+    If it cannot find the best parameters, the final row will be `best_max_num_seqs: 0, best_num_batched_tokens: 0, best_throughput: 0`. This can be due to either the server not starting properly, or the latency requirement being too strict.
+
+- **Profiler Trace**: A directory named `profile` is created inside the log directory. It contains the profiler trace file (e.g., `.xplane.pb` for TPU or a `.json` trace for GPU) from the single best-performing run.
+
+## How It Works
+
+The script follows a systematic process to find the optimal parameters:
+
+1. **Find Max GPU Memory Utilization**: The script first determines the highest safe `gpu-memory-utilization` (starting from 0.98 and decreasing) that does not cause an Out-Of-Memory (OOM) error when launching the server. This ensures the benchmark runs use the maximum available memory without crashing.
+
+2. **Iterate and Benchmark**: It then enters a nested loop, iterating through every combination of `max-num-seqs` and `max-num-batched-tokens` provided in the configuration lists.
+
+3. **Latency-Aware Throughput Search**: For each parameter combination:
+    - The vLLM server is started.
+    - A benchmark is first run with an infinite request rate (`--request-rate inf`).
+    - If the resulting P99 E2E latency is within the `MAX_LATENCY_ALLOWED_MS` limit, this throughput is considered the maximum for this configuration.
+    - If the latency is too high, the script performs a search by iteratively decreasing the request rate until the latency constraint is met. This finds the highest sustainable throughput for the given parameters and latency requirement.
+
+4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far.
+
+5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard.
+
+## Batched `auto_tune`
+
+The `batch_auto_tune.sh` script allows you to run multiple `auto_tune.sh` experiments sequentially from a single configuration file. It iterates through a list of parameter sets, executes `auto_tune.sh` for each, and records the results back into the input file.
+
+### Prerequisites
+
+- **jq**: This script requires `jq` to parse the JSON configuration file.
+- **gcloud**: If you plan to upload results to Google Cloud Storage, the `gcloud` CLI must be installed and authenticated.
+
+### How to Run
+
+1. **Create a JSON configuration file**: Create a file (e.g., `runs_config.json`) containing an array of JSON objects. Each object defines the parameters for a single `auto_tune.sh` run.
+
+2. **Execute the script**:
+
+    ```bash
+    bash batch_auto_tune.sh <config_file> [gcs_upload_path]
+    ```
+
+    - `<config_file>`: **Required.** Path to your JSON configuration file.
+    - `[gcs_upload_path]`: **Optional.** A GCS path (e.g., `gs://my-bucket/benchmark-results`) where the detailed results and profiles for each run will be uploaded. If this is empty, the results will be available on the local filesystem (see the log for `RESULT_FILE=/path/to/results/file.txt`).
+
+### Configuration File
+
+The JSON configuration file should contain an array of objects. Each object's keys correspond to the configuration variables for `auto_tune.sh` (see the [Configuration table above](#configuration)). These keys will be converted to uppercase environment variables for each run.
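+
+For example, a key such as `max_latency_allowed_ms` becomes the environment variable `MAX_LATENCY_ALLOWED_MS`. A minimal sketch of that conversion (mirroring the `jq`/`tr` logic used by `batch_auto_tune.sh`; the key shown is only an illustration):
+
+```bash
+key="max_latency_allowed_ms"
+# Upper-case the key and drop any character not allowed in an env var name.
+var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_')
+echo "$var_name"   # MAX_LATENCY_ALLOWED_MS
+```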
+
+Here is an example `runs_config.json` with two benchmark configurations (the `system` field accepts `TPU` or `GPU`):
+
+```json
+[
+  {
+    "base": "/home/user",
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "system": "TPU",
+    "tp": 8,
+    "input_len": 128,
+    "output_len": 2048,
+    "max_model_len": 2300,
+    "num_seqs_list": "128 256",
+    "num_batched_tokens_list": "8192 16384"
+  },
+  {
+    "base": "/home/user",
+    "model": "meta-llama/Llama-3.1-70B-Instruct",
+    "system": "TPU",
+    "tp": 8,
+    "input_len": 4000,
+    "output_len": 16,
+    "max_model_len": 4096,
+    "num_seqs_list": "64 128",
+    "num_batched_tokens_list": "4096 8192",
+    "max_latency_allowed_ms": 500
+  }
+]
+```
+
+### Output
+
+The script modifies the input JSON file in place, adding the results of each run to the corresponding object. The following fields are added:
+
+- `run_id`: A unique identifier for the run, derived from the timestamp.
+- `status`: The outcome of the run (`SUCCESS`, `FAILURE`, or `WARNING_NO_RESULT_FILE`).
+- `results`: The content of the `result.txt` file from the `auto_tune.sh` run.
+- `gcs_results`: The GCS URL where the run's artifacts are stored (if a GCS path was provided).
+
+A summary of successful and failed runs is also printed to the console upon completion.
diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c06b76be5ee68166939c560de7453ec4cfe0506f
--- /dev/null
+++ b/benchmarks/auto_tune/auto_tune.sh
@@ -0,0 +1,322 @@
+#!/bin/bash
+
+# This script aims to tune the best server parameter combinations to maximize throughput for a given requirement.
+# See details in README (benchmarks/auto_tune/README.md).
+
+TAG=$(date +"%Y_%m_%d_%H_%M")
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+VLLM_LOGGING_LEVEL=${VLLM_LOGGING_LEVEL:-INFO}
+BASE=${BASE:-"$SCRIPT_DIR/../../.."}
+MODEL=${MODEL:-"meta-llama/Llama-3.1-8B-Instruct"}
+SYSTEM=${SYSTEM:-"TPU"}
+TP=${TP:-1}
+DOWNLOAD_DIR=${DOWNLOAD_DIR:-""}
+INPUT_LEN=${INPUT_LEN:-4000}
+OUTPUT_LEN=${OUTPUT_LEN:-16}
+MAX_MODEL_LEN=${MAX_MODEL_LEN:-4096}
+MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0}
+MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000}
+NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"}
+NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"}
+HOSTNAME=$(hostname)
+if [[ -z "$HOSTNAME" ]]; then
+    echo "Error: Failed to determine hostname."
>&2 + exit 1 +fi + +LOG_FOLDER="$BASE/auto-benchmark/$TAG" +RESULT="$LOG_FOLDER/result.txt" +PROFILE_PATH="$LOG_FOLDER/profile" + +echo "====================== AUTO TUNE PARAMETERS ====================" +echo "SCRIPT_DIR=$SCRIPT_DIR" +echo "BASE=$BASE" +echo "MODEL=$MODEL" +echo "SYSTEM=$SYSTEM" +echo "TP=$TP" +echo "DOWNLOAD_DIR=$DOWNLOAD_DIR" +echo "INPUT_LEN=$INPUT_LEN" +echo "OUTPUT_LEN=$OUTPUT_LEN" +echo "MAX_MODEL_LEN=$MAX_MODEL_LEN" +echo "MIN_CACHE_HIT_PCT=$MIN_CACHE_HIT_PCT" +echo "MAX_LATENCY_ALLOWED_MS=$MAX_LATENCY_ALLOWED_MS" +echo "NUM_SEQS_LIST=$NUM_SEQS_LIST" +echo "NUM_BATCHED_TOKENS_LIST=$NUM_BATCHED_TOKENS_LIST" +echo "VLLM_LOGGING_LEVEL=$VLLM_LOGGING_LEVEL" +echo "RESULT_FILE=$RESULT" +echo "====================== AUTO TUNEPARAMETERS ====================" + +rm -rf "$LOG_FOLDER" +rm -rf "$PROFILE_PATH" +mkdir -p "$LOG_FOLDER" +mkdir -p "$PROFILE_PATH" + +cd "$BASE/vllm" + +pip install -q datasets + +current_hash=$(git rev-parse HEAD) +echo "hash:$current_hash" >> "$RESULT" +echo "current_hash: $current_hash" + +TOTAL_LEN=$((INPUT_LEN + OUTPUT_LEN)) +RED='\033[0;31m' +if (( TOTAL_LEN > MAX_MODEL_LEN )); then + echo -e "${RED}FAILED: INPUT_LEN($INPUT_LEN) + OUTPUT_LEN($OUTPUT_LEN) = $TOTAL_LEN, which is > MAX_MODEL_LEN = $MAX_MODEL_LEN.\033[0m" >&2 + exit 1 +fi + +best_throughput=0 +best_max_num_seqs=0 +best_num_batched_tokens=0 +best_goodput=0 +best_request_rate=0 + +start_server() { + local gpu_memory_utilization=$1 + local max_num_seqs=$2 + local max_num_batched_tokens=$3 + local vllm_log=$4 + local profile_dir=$5 + + pkill -if "vllm serve" || true + + # Define the common arguments as a bash array. + # Each argument and its value are separate elements. + local common_args_array=( + "$MODEL" + "--port" "8004" + "--host" "$HOSTNAME" + "--gpu-memory-utilization" "$gpu_memory_utilization" + "--max-num-seqs" "$max_num_seqs" + "--max-num-batched-tokens" "$max_num_batched_tokens" + "--tensor-parallel-size" "$TP" + "--enable-prefix-caching" + "--load-format" "dummy" + "--download-dir" "$DOWNLOAD_DIR" + "--max-model-len" "$MAX_MODEL_LEN" + ) + + # Use the array expansion "${common_args_array[@]}" + # This correctly passes each element as a separate argument. + if [[ -n "$profile_dir" ]]; then + # Start server with profiling enabled + local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}" + VLLM_SERVER_DEV_MODE=1 \ + vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 & + else + # Start server without profiling + VLLM_SERVER_DEV_MODE=1 \ + vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & + fi + local server_pid=$! + + # wait for 10 minutes... + server_started=0 + for _ in {1..60}; do + # This line checks whether the server is still alive or not, + # since that we should always have permission to send signal to the server process. + kill -0 $server_pid 2> /dev/null || break + + RESPONSE=$(curl -s -X GET "http://${HOSTNAME}:8004/health" -w "%{http_code}" -o /dev/stdout) + STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) + if [[ "$STATUS_CODE" -eq 200 ]]; then + server_started=1 + break + else + sleep 10 + fi + done + + if (( ! server_started )); then + echo "server did not start within 10 minutes or crashed. Please check server log at $vllm_log". 
+ return 1 + else + return 0 + fi +} + +run_benchmark() { + local max_num_seqs=$1 + local max_num_batched_tokens=$2 + local gpu_memory_utilization=$3 + echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" + local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" + echo "vllm_log: $vllm_log" + echo + rm -f "$vllm_log" + pkill -if "vllm serve" || true + + echo "starting server..." + # Call start_server without a profile_dir to avoid profiling overhead + start_server "$gpu_memory_utilization" "$max_num_seqs" "$max_num_batched_tokens" "$vllm_log" "" + result=$? + if [[ "$result" -eq 1 ]]; then + echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" + else + echo "server started." + fi + echo + + echo "run benchmark test..." + meet_latency_requirement=0 + # get a basic qps by using request-rate inf + bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt" + prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 )) + adjusted_input_len=$(( INPUT_LEN - prefix_len )) + # --profile flag is removed from this call + vllm bench serve \ + --backend vllm \ + --model "$MODEL" \ + --dataset-name random \ + --random-input-len $adjusted_input_len \ + --random-output-len "$OUTPUT_LEN" \ + --ignore-eos \ + --disable-tqdm \ + --request-rate inf \ + --percentile-metrics ttft,tpot,itl,e2el \ + --goodput e2el:"$MAX_LATENCY_ALLOWED_MS" \ + --num-prompts 1000 \ + --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ + --port 8004 &> "$bm_log" + throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') + goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + + if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then + meet_latency_requirement=1 + request_rate=inf + fi + + if (( ! meet_latency_requirement )); then + # start from request-rate as int(throughput) + 1 + request_rate=$((${throughput%.*} + 1)) + while ((request_rate > 0)); do + # clear prefix cache + curl -X POST http://"${HOSTNAME}":8004/reset_prefix_cache + sleep 5 + bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt" + vllm bench serve \ + --backend vllm \ + --model "$MODEL" \ + --dataset-name random \ + --random-input-len $adjusted_input_len \ + --random-output-len "$OUTPUT_LEN" \ + --ignore-eos \ + --disable-tqdm \ + --request-rate $request_rate \ + --percentile-metrics ttft,tpot,itl,e2el \ + --goodput e2el:"$MAX_LATENCY_ALLOWED_MS" \ + --num-prompts 100 \ + --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ + --port 8004 &> "$bm_log" + throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') + goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then + meet_latency_requirement=1 + break + fi + request_rate=$((request_rate-1)) + done + fi + # write the results and update the best result. 
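+    # Lines appended to $RESULT below follow the format documented in the README:
+    #   max_num_seqs: <n>, max_num_batched_tokens: <n>, request_rate: <r>, e2el: <ms>, throughput: <req/s>, goodput: <req/s>
+    # or, when the latency target is missed:
+    #   max_num_seqs: <n>, max_num_batched_tokens: <n> does not meet latency requirement <ms>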
+ if ((meet_latency_requirement)); then + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, throughput: $throughput, goodput: $goodput" + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, throughput: $throughput, goodput: $goodput" >> "$RESULT" + if (( $(echo "$throughput > $best_throughput" | bc -l) )); then + best_throughput=$throughput + best_max_num_seqs=$max_num_seqs + best_num_batched_tokens=$max_num_batched_tokens + best_goodput=$goodput + best_request_rate=$request_rate + fi + else + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" >> "$RESULT" + fi + + echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" + + pkill -if "vllm serve" || true + sleep 10 + echo "====================" + return 0 +} + +read -r -a num_seqs_list <<< "$NUM_SEQS_LIST" +read -r -a num_batched_tokens_list <<< "$NUM_BATCHED_TOKENS_LIST" + +# first find out the max gpu-memory-utilization without HBM OOM. +gpu_memory_utilization=0.98 +find_gpu_memory_utilization=0 +while (( $(echo "$gpu_memory_utilization >= 0.9" | bc -l) )); do + # Pass empty string for profile_dir argument + start_server "$gpu_memory_utilization" "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" "" + result=$? + if [[ "$result" -eq 0 ]]; then + find_gpu_memory_utilization=1 + break + else + gpu_memory_utilization=$(echo "$gpu_memory_utilization - 0.01" | bc) + fi +done + +if [[ "$find_gpu_memory_utilization" -eq 1 ]]; then + echo "Using gpu_memory_utilization=$gpu_memory_utilization to serve model." +else + echo "Cannot find a proper gpu_memory_utilization over 0.9 to serve the model, please check logs in $LOG_FOLDER." + exit 1 +fi + +for num_seqs in "${num_seqs_list[@]}"; do + for num_batched_tokens in "${num_batched_tokens_list[@]}"; do + run_benchmark "$num_seqs" "$num_batched_tokens" "$gpu_memory_utilization" + done +done +echo "finish permutations" + +# ================================================================================= +# FINAL PROFILING RUN FOR THE BEST CONFIGURATION +# ================================================================================= +if (( $(echo "$best_throughput > 0" | bc -l) )); then + echo + echo "Benchmark tuning finished. Now running profiling on the best configuration found..." + echo "Best config: max_num_seqs: $best_max_num_seqs, max_num_batched_tokens: $best_num_batched_tokens, throughput: $best_throughput, goodput: $best_goodput" + echo + + vllm_log="$LOG_FOLDER/vllm_log_BEST_PROFILE.txt" + bm_log="$LOG_FOLDER/bm_log_BEST_PROFILE.txt" + + # Start server with the best params and profiling ENABLED + echo "Starting server for profiling..." + start_server "$gpu_memory_utilization" "$best_max_num_seqs" "$best_num_batched_tokens" "$vllm_log" "$PROFILE_PATH" + + # Run benchmark with the best params and the --profile flag + echo "Running benchmark with profiling..." 
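+    # As in run_benchmark above, MIN_CACHE_HIT_PCT percent of the input is sent as
+    # a shared random prefix so that roughly that fraction of prompt tokens can hit
+    # the prefix cache. Example: INPUT_LEN=1800, MIN_CACHE_HIT_PCT=60 gives
+    # prefix_len=1080 and adjusted_input_len=720.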
+ prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 )) + adjusted_input_len=$(( INPUT_LEN - prefix_len )) + vllm bench serve \ + --backend vllm \ + --model "$MODEL" \ + --dataset-name random \ + --random-input-len $adjusted_input_len \ + --random-output-len "$OUTPUT_LEN" \ + --ignore-eos \ + --disable-tqdm \ + --request-rate "$best_request_rate" \ + --percentile-metrics ttft,tpot,itl,e2el \ + --goodput e2el:"$MAX_LATENCY_ALLOWED_MS" \ + --num-prompts 100 \ + --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ + --port 8004 \ + --profile &> "$bm_log" +else + echo "No configuration met the latency requirements. Skipping final profiling run." +fi +pkill -if "vllm serve" || true +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" diff --git a/benchmarks/auto_tune/batch_auto_tune.sh b/benchmarks/auto_tune/batch_auto_tune.sh new file mode 100644 index 0000000000000000000000000000000000000000..0f3ef0f0385d2e221b8720f3cfd5829c3154999f --- /dev/null +++ b/benchmarks/auto_tune/batch_auto_tune.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +INPUT_JSON="$1" +GCS_PATH="$2" # Optional GCS path for uploading results for each run + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) +AUTOTUNE_SCRIPT="$SCRIPT_DIR/auto_tune.sh" + +if [[ -z "$INPUT_JSON" ]]; then + echo "Error: Input JSON file not provided." + echo "Usage: $0 [gcs_upload_path]" + exit 1 +fi + +if [[ ! -f "$INPUT_JSON" ]]; then + echo "Error: File not found at '$INPUT_JSON'" + exit 1 +fi + +if ! command -v jq &> /dev/null; then + echo "Error: 'jq' command not found. Please install jq to process the JSON input." + exit 1 +fi + +if [[ -n "$GCS_PATH" ]] && ! command -v gcloud &> /dev/null; then + echo "Error: 'gcloud' command not found, but a GCS_PATH was provided." + exit 1 +fi + +SUCCESS_COUNT=0 +FAILURE_COUNT=0 +FAILED_RUNS=() +SCRIPT_START_TIME=$(date +%s) + +json_content=$(cat "$INPUT_JSON") +if ! num_runs=$(echo "$json_content" | jq 'length'); then + echo "Error: Invalid JSON in $INPUT_JSON. 'jq' failed to get array length." >&2 + exit 1 +fi + +echo "Found $num_runs benchmark configurations in $INPUT_JSON." +echo "Starting benchmark runs..." 
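+# For each configuration object: export its keys as upper-cased environment
+# variables, run auto_tune.sh with them, then write run_id, status, results,
+# and gcs_results back into $INPUT_JSON.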
+echo "--------------------------------------------------" + +for i in $(seq 0 $(($num_runs - 1))); do + run_object=$(echo "$json_content" | jq ".[$i]") + + RUN_START_TIME=$(date +%s) + ENV_VARS_ARRAY=() + # Dynamically create env vars from the JSON object's keys + for key in $(echo "$run_object" | jq -r 'keys_unsorted[]'); do + value=$(echo "$run_object" | jq -r ".$key") + var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_') + ENV_VARS_ARRAY+=("${var_name}=${value}") + done + + echo "Executing run #$((i+1))/$num_runs with parameters: ${ENV_VARS_ARRAY[*]}" + + # Execute auto_tune.sh and capture output + RUN_OUTPUT_FILE=$(mktemp) + if env "${ENV_VARS_ARRAY[@]}" bash "$AUTOTUNE_SCRIPT" > >(tee -a "$RUN_OUTPUT_FILE") 2>&1; then + STATUS="SUCCESS" + ((SUCCESS_COUNT++)) + else + STATUS="FAILURE" + ((FAILURE_COUNT++)) + FAILED_RUNS+=("Run #$((i+1)): $(echo "$run_object" | jq -c .)") + fi + + RUN_OUTPUT=$(<"$RUN_OUTPUT_FILE") + rm "$RUN_OUTPUT_FILE" + + # Parse results and optionally upload them to GCS + RUN_ID="" + RESULTS="" + GCS_RESULTS_URL="" + if [[ "$STATUS" == "SUCCESS" ]]; then + RESULT_FILE_PATH=$(echo "$RUN_OUTPUT" | grep 'RESULT_FILE=' | tail -n 1 | cut -d'=' -f2 | tr -s '/' || true) + + if [[ -n "$RESULT_FILE_PATH" && -f "$RESULT_FILE_PATH" ]]; then + RUN_ID=$(basename "$(dirname "$RESULT_FILE_PATH")") + RESULT_DIR=$(dirname "$RESULT_FILE_PATH") + RESULTS=$(cat "$RESULT_FILE_PATH") + + if [[ -n "$GCS_PATH" ]]; then + GCS_RESULTS_URL="${GCS_PATH}/${RUN_ID}" + echo "Uploading results to GCS..." + if gcloud storage rsync --recursive "$RESULT_DIR/" "$GCS_RESULTS_URL"; then + echo "GCS upload successful." + else + echo "Warning: GCS upload failed for RUN_ID $RUN_ID." + fi + fi + else + echo "Warning: Could not find result file for a successful run." + STATUS="WARNING_NO_RESULT_FILE" + fi + fi + + # Add the results back into the JSON object for this run + json_content=$(echo "$json_content" | jq --argjson i "$i" --arg run_id "$RUN_ID" --arg status "$STATUS" --arg results "$RESULTS" --arg gcs_results "$GCS_RESULTS_URL" \ + '.[$i] += {run_id: $run_id, status: $status, results: $results, gcs_results: $gcs_results}') + + RUN_END_TIME=$(date +%s) + echo "Run finished in $((RUN_END_TIME - RUN_START_TIME)) seconds. Status: $STATUS" + echo "--------------------------------------------------" + + # Save intermediate progress back to the file + echo "$json_content" > "$INPUT_JSON.tmp" && mv "$INPUT_JSON.tmp" "$INPUT_JSON" + +done + +SCRIPT_END_TIME=$(date +%s) +echo "All benchmark runs completed in $((SCRIPT_END_TIME - SCRIPT_START_TIME)) seconds." +echo +echo "====================== SUMMARY ======================" +echo "Successful runs: $SUCCESS_COUNT" +echo "Failed runs: $FAILURE_COUNT" +echo "===================================================" + +if [[ $FAILURE_COUNT -gt 0 ]]; then + echo "Details of failed runs (see JSON file for full parameters):" + for failed in "${FAILED_RUNS[@]}"; do + echo " - $failed" + done +fi + +echo "Updated results have been saved to '$INPUT_JSON'." 
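+
+# Example invocation (hypothetical paths):
+#   bash batch_auto_tune.sh runs_config.json gs://my-bucket/benchmark-results
+# Per-run status can then be inspected with, e.g.:
+#   jq -r '.[] | "\(.run_id)\t\(.status)"' runs_config.json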
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py new file mode 100644 index 0000000000000000000000000000000000000000..a69637bfc437dd10079774a4943ca603dc9a2e20 --- /dev/null +++ b/benchmarks/backend_request_func.py @@ -0,0 +1,651 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import io +import json +import os +import sys +import time +import traceback +from dataclasses import dataclass, field + +import aiohttp +import huggingface_hub.constants +from tqdm.asyncio import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +# NOTE(simon): do not import vLLM here so the benchmark script +# can run without vLLM installed. + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + model_name: str | None = None + logprobs: int | None = None + extra_body: dict | None = None + multi_modal_content: dict | list[dict] | None = None + ignore_eos: bool = False + language: str | None = None + request_id: str | None = None + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0.0 + output_tokens: int = 0 + ttft: float = 0.0 # Time to first token + itl: list[float] = field(default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + error: str = "" + + +async def async_request_tgi( + request_func_input: RequestFuncInput, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + params = { + "max_new_tokens": request_func_input.output_len, + "do_sample": True, + "temperature": 0.01, # TGI does not accept 0.0 temperature. + "top_p": 0.99, # TGI does not accept 1.0 top_p. + "truncate": request_func_input.prompt_len, + "ignore_eos_token": request_func_input.ignore_eos, + } + payload = { + "inputs": request_func_input.prompt, + "parameters": params, + } + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + if request_func_input.ignore_eos: + output.output_tokens = request_func_input.output_len + else: + output.output_tokens = None + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk_bytes = chunk_bytes.decode("utf-8") + + # NOTE: Sometimes TGI returns a ping response without + # any data, we should skip it. 
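+                        # (In the SSE stream, lines starting with ":" are
+                        # protocol comments / keep-alives that carry no JSON
+                        # payload; real chunks arrive as "data:{...}" and are
+                        # parsed below once the "data:" prefix is stripped.)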
+ if chunk_bytes.startswith(":"): + continue + chunk = chunk_bytes.removeprefix("data:") + + data = json.loads(chunk) + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + } + if request_func_input.ignore_eos: + payload["min_length"] = request_func_input.output_len + headers = None + if request_func_input.request_id: + headers = {"x-request-id": request_func_input.request_id} + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix("data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_deepspeed_mii( + request_func_input: RequestFuncInput, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "max_tokens": request_func_input.output_len, + "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. + "top_p": 1.0, + } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024, + # will use 0 as placeholder. 
+ # See https://github.com/microsoft/DeepSpeed-MII/pull/311 + output.ttft = 0 + + st = time.perf_counter() + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + parsed_resp = await response.json() + output.latency = time.perf_counter() - st + if "choices" in parsed_resp: + output.generated_text = parsed_resp["choices"][0]["text"] + elif "text" in parsed_resp: + output.generated_text = parsed_resp["text"][0] + else: + output.error = ( + "Unexpected response format: " + "neither 'choices' nor 'text' found" + ) + output.success = False + output.success = True + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "repetition_penalty": 1.0, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk != "[DONE]": + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + if usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!" 
+ ) + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'." + ) + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + mm_content = request_func_input.multi_modal_content + if isinstance(mm_content, list): + content.extend(mm_content) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "multi_modal_content must be a dict or list[dict] for openai-chat" + ) + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk_bytes = chunk_bytes.decode("utf-8") + # NOTE: SSE comments (often used as pings) start with a colon. + # These are not JSON data payload and should be skipped. 
+ if chunk_bytes.startswith(":"): + continue + + chunk = chunk_bytes.removeprefix("data: ") + + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + + api_url = request_func_input.api_url + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' " + ) + "or `translations`." + + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", + # Flattened due to multipart/form-data + "stream_include_usage": True, + "stream_continuous_usage_stats": True, + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + mm_audio = request_func_input.multi_modal_content + if not isinstance(mm_audio, dict) or "audio" not in mm_audio: + raise TypeError("multi_modal_content must be a dict containing 'audio'") + with to_bytes(*mm_audio["audio"]) as f: + form = aiohttp.FormData() + form.add_field("file", f, content_type="audio/wav") + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp + ) + + generated_text += content or "" + 
elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens" + ) + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true": + from modelscope import snapshot_download + + from vllm.model_executor.model_loader.weight_utils import get_lock + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(pretrained_model_name_or_path): + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) + + return model_path + return pretrained_model_name_or_path + + +def get_tokenizer( + pretrained_model_name_or_path: str, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + **kwargs, +) -> PreTrainedTokenizer | PreTrainedTokenizerFast: + if pretrained_model_name_or_path is not None and not os.path.exists( + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + if tokenizer_mode == "mistral": + try: + from vllm.tokenizers.mistral import MistralTokenizer + except ImportError as e: + raise ImportError( + "MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode." + ) from e + return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path)) + else: + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + +ASYNC_REQUEST_FUNCS = { + "tgi": async_request_tgi, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "deepspeed-mii": async_request_deepspeed_mii, + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, + "tensorrt-llm": async_request_trt_llm, + "scalellm": async_request_openai_completions, + "sglang": async_request_openai_completions, + "llama.cpp": async_request_openai_completions, +} diff --git a/benchmarks/benchmark_batch_invariance.py b/benchmarks/benchmark_batch_invariance.py new file mode 100644 index 0000000000000000000000000000000000000000..7473a41e51406dcb5b3e1a9a1ccfce41f10573fb --- /dev/null +++ b/benchmarks/benchmark_batch_invariance.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode. + +This benchmark runs the same workload twice: +1. With VLLM_BATCH_INVARIANT=0 (baseline) +2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode) + +And reports the timing and throughput metrics for comparison. 
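+Reported metrics include engine initialization time, per-trial latency
+(average/min/max), total generated tokens, token throughput, and prompts/s.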
+ +Environment variables: + VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B") + VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek) + VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128) + VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5) + VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024) + VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048) + VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128) + VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0) + VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4) + VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120) + VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN) + +Example usage: + # Benchmark qwen3 (default) + python benchmarks/benchmark_batch_invariance.py + + # Benchmark deepseek with 8 GPUs + VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\ + python benchmarks/benchmark_batch_invariance.py + + # Quick test with fewer trials + VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\ + python benchmarks/benchmark_batch_invariance.py +""" + +import contextlib +import os +import random +import time + +from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + """Generate a random prompt for benchmarking.""" + prompt_templates = [ + "Question: What is the capital of France?\nAnswer: The capital of France is", + "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", + "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", + "Once upon a time in a distant galaxy, there lived", + "The old man walked slowly down the street, remembering", + "In the year 2157, humanity finally discovered", + "To implement a binary search tree in Python, first we need to", + "The algorithm works by iterating through the array and", + "Here's how to optimize database queries using indexing:", + "The Renaissance was a period in European history that", + "Climate change is caused by several factors including", + "The human brain contains approximately 86 billion neurons which", + "I've been thinking about getting a new laptop because", + "Yesterday I went to the store and bought", + "My favorite thing about summer is definitely", + ] + + base_prompt = random.choice(prompt_templates) + + if max_words < min_words: + max_words = min_words + target_words = random.randint(min_words, max_words) + + if target_words > 50: + padding_text = ( + " This is an interesting topic that deserves more explanation. " + * (target_words // 50) + ) + base_prompt = base_prompt + padding_text + + return base_prompt + + +def run_benchmark_with_batch_invariant( + model: str, + tp_size: int, + max_batch_size: int, + num_trials: int, + min_prompt: int, + max_prompt: int, + max_tokens: int, + temperature: float, + gpu_mem_util: float, + max_model_len: int, + backend: str, + batch_invariant: bool, + seed: int = 12345, +) -> dict: + """ + Run the benchmark with the specified configuration. + + Returns a dict with timing and throughput metrics. 
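+ Keys: init_time, avg_time, min_time, max_time, total_tokens, total_prompts,
+ throughput, prompts_per_sec, and trial_times.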
+ """ + random.seed(seed) + + # Set environment variables + if batch_invariant: + os.environ["VLLM_BATCH_INVARIANT"] = "1" + else: + os.environ["VLLM_BATCH_INVARIANT"] = "0" + + print(f"\n{'=' * 80}") + print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}") + print(f" Model: {model}") + print(f" TP Size: {tp_size}") + print(f" Backend: {backend}") + print(f" Max Batch Size: {max_batch_size}") + print(f" Trials: {num_trials}") + print(f" Max Tokens: {max_tokens}") + print(f"{'=' * 80}\n") + + sampling = SamplingParams( + temperature=temperature, + top_p=0.95, + max_tokens=max_tokens, + seed=20240919, + ) + + needle_prompt = "There once was a " + + llm = None + try: + # Create LLM engine + start_init = time.perf_counter() + llm = LLM( + model=model, + max_num_seqs=max_batch_size, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + dtype="bfloat16", + tensor_parallel_size=tp_size, + attention_config={"backend": backend}, + enable_prefix_caching=False, + ) + init_time = time.perf_counter() - start_init + print(f"Engine initialization time: {init_time:.2f}s\n") + + # Generate baseline + print("Generating baseline (warmup)...") + baseline_out = llm.generate([needle_prompt], sampling) + assert len(baseline_out) == 1 + baseline_text = baseline_out[0].outputs[0].text + print(f"Baseline output: '{baseline_text[:50]}...'\n") + + # Run trials and measure timing + trial_times: list[float] = [] + total_tokens = 0 + total_prompts = 0 + + for trial in range(num_trials): + # Create a batch + prompts: list[str] = [] + batch_size = random.randint(max_batch_size // 2, max_batch_size) + needle_pos = random.randint(0, batch_size - 1) + for i in range(batch_size): + if i == needle_pos: + prompts.append(needle_prompt) + else: + prompts.append(_random_prompt(min_prompt, max_prompt)) + + # Measure time for this trial + start_time = time.perf_counter() + outputs = llm.generate(prompts, sampling) + trial_time = time.perf_counter() - start_time + + trial_times.append(trial_time) + total_prompts += len(prompts) + + # Count tokens + for output in outputs: + if output.outputs: + total_tokens += len(output.outputs[0].token_ids) + + print( + f"Trial {trial + 1}/{num_trials}: " + f"batch_size={batch_size}, " + f"time={trial_time:.2f}s" + ) + + # Verify needle output still matches + needle_output = outputs[needle_pos] + assert needle_output.prompt == needle_prompt + + # Compute statistics + avg_time = sum(trial_times) / len(trial_times) + min_time = min(trial_times) + max_time = max(trial_times) + throughput = total_tokens / sum(trial_times) + prompts_per_sec = total_prompts / sum(trial_times) + + print(f"\n{'=' * 80}") + print("RESULTS:") + print(f" Average time per trial: {avg_time:.2f}s") + print(f" Min time: {min_time:.2f}s") + print(f" Max time: {max_time:.2f}s") + print(f" Total tokens generated: {total_tokens}") + print(f" Total prompts processed: {total_prompts}") + print(f" Throughput: {throughput:.2f} tokens/s") + print(f" Prompts/s: {prompts_per_sec:.2f}") + print(f"{'=' * 80}\n") + + return { + "init_time": init_time, + "avg_time": avg_time, + "min_time": min_time, + "max_time": max_time, + "total_tokens": total_tokens, + "total_prompts": total_prompts, + "throughput": throughput, + "prompts_per_sec": prompts_per_sec, + "trial_times": trial_times, + } + + finally: + # Cleanup + if llm is not None: + with contextlib.suppress(Exception): + llm.shutdown() + + +def main(): + # Check platform support + if not (current_platform.is_cuda() and current_platform.has_device_capability(90)): + 
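+ # Bail out early: this benchmark requires CUDA and a Hopper-class
+ # (SM90 or newer) GPU.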
print("ERROR: Requires CUDA and >= Hopper (SM90)") + print(f"Current platform: {current_platform.device_type}") + if current_platform.is_cuda(): + print(f"Device capability: {current_platform.get_device_capability()}") + return 1 + + # Read configuration from environment + model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1")) + max_batch_size = int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128")) + num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5")) + min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024")) + max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048")) + max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128")) + temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0")) + gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4")) + max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120")) + backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN") + + print("\n" + "=" * 80) + print("VLLM BATCH INVARIANCE BENCHMARK") + print("=" * 80) + print("\nConfiguration:") + print(f" Model: {model}") + print(f" Tensor Parallel Size: {tp_size}") + print(f" Attention Backend: {backend}") + print(f" Max Batch Size: {max_batch_size}") + print(f" Number of Trials: {num_trials}") + print(f" Prompt Length Range: {min_prompt}-{max_prompt} words") + print(f" Max Tokens to Generate: {max_tokens}") + print(f" Temperature: {temperature}") + print(f" GPU Memory Utilization: {gpu_mem_util}") + print(f" Max Model Length: {max_model_len}") + print("=" * 80) + + # Run benchmark WITHOUT batch invariance (baseline) + print("\n" + "=" * 80) + print("PHASE 1: Running WITHOUT batch invariance (baseline)") + print("=" * 80) + baseline_results = run_benchmark_with_batch_invariant( + model=model, + tp_size=tp_size, + max_batch_size=max_batch_size, + num_trials=num_trials, + min_prompt=min_prompt, + max_prompt=max_prompt, + max_tokens=max_tokens, + temperature=temperature, + gpu_mem_util=gpu_mem_util, + max_model_len=max_model_len, + backend=backend, + batch_invariant=False, + ) + + # Run benchmark WITH batch invariance + print("\n" + "=" * 80) + print("PHASE 2: Running WITH batch invariance") + print("=" * 80) + batch_inv_results = run_benchmark_with_batch_invariant( + model=model, + tp_size=tp_size, + max_batch_size=max_batch_size, + num_trials=num_trials, + min_prompt=min_prompt, + max_prompt=max_prompt, + max_tokens=max_tokens, + temperature=temperature, + gpu_mem_util=gpu_mem_util, + max_model_len=max_model_len, + backend=backend, + batch_invariant=True, + ) + + # Compare results + print("\n" + "=" * 80) + print("COMPARISON: Batch Invariance vs Baseline") + print("=" * 80) + + init_overhead_pct = ( + (batch_inv_results["init_time"] - baseline_results["init_time"]) + / baseline_results["init_time"] + * 100 + ) + time_overhead_pct = ( + (batch_inv_results["avg_time"] - baseline_results["avg_time"]) + / baseline_results["avg_time"] + * 100 + ) + throughput_change_pct = ( + (batch_inv_results["throughput"] - baseline_results["throughput"]) + / baseline_results["throughput"] + * 100 + ) + + print("\nInitialization Time:") + print(f" Baseline: {baseline_results['init_time']:.2f}s") + print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s") + print(f" Overhead: {init_overhead_pct:+.2f}%") + + print("\nAverage Trial Time:") + print(f" Baseline: {baseline_results['avg_time']:.2f}s") + print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s") + print(f" Overhead: {time_overhead_pct:+.2f}%") + + 
print("\nThroughput (tokens/s):") + print(f" Baseline: {baseline_results['throughput']:.2f}") + print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}") + print(f" Change: {throughput_change_pct:+.2f}%") + + print("\nPrompts/s:") + print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}") + print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}") + + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + if time_overhead_pct > 0: + print( + f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% " + "overhead" + ) + else: + print( + f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% " + "faster (unexpected!)" + ) + + if abs(throughput_change_pct) < 1.0: + print("Throughput difference is negligible (< 1%)") + elif throughput_change_pct < 0: + print( + f"Throughput decreased by {-throughput_change_pct:.1f}% " + "with batch invariance" + ) + else: + print( + f"Throughput increased by {throughput_change_pct:.1f}% " + "with batch invariance (unexpected!)" + ) + + print("=" * 80 + "\n") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..20cd26bdddf513c21f31853380d62583ac51980d --- /dev/null +++ b/benchmarks/benchmark_block_pool.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc + +from benchmark_utils import TimeCollector +from tabulate import tabulate + +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.core.block_pool import BlockPool + + +def main(args): + rows = [] + for allocate_block in args.allocate_blocks: + # Enforce a GC collect ahead to minimize the impact among runs + gc.collect() + block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True) + + get_blocks_times = TimeCollector(TimeCollector.US) + free_blocks_times = TimeCollector(TimeCollector.US) + for _ in range(args.num_iteration): + with get_blocks_times: + blocks = block_pool.get_new_blocks(allocate_block) + with free_blocks_times: + block_pool.free_blocks(blocks) + + rows.append( + [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block] + + get_blocks_times.dump_avg_max() + + free_blocks_times.dump_avg_max() + ) + + print( + tabulate( + rows, + headers=[ + "Iterations", + "Total\nBlocks", + "Allocated\nBlocks", + "Get Blocks\nAvg (us)", + "Get Blocks\nMax (us)", + "Free Blocks\nAvg (us)", + "Free Blocks\nMax (us)", + ], + tablefmt="grid", + floatfmt=".3f", + ) + ) + + +def invoke_main() -> None: + parser = FlexibleArgumentParser( + description="Benchmark the performance of BlockPool for KV Cache." 
+ ) + parser.add_argument("--num-gpu-blocks", type=int, default=100000) + parser.add_argument( + "--num-iteration", + type=int, + default=1000, + help="Number of iterations to run to stabilize final data readings", + ) + parser.add_argument( + "--allocate-blocks", + type=int, + nargs="*", + default=[10, 50, 100, 500, 1000], + help="Number of blocks to allocate", + ) + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/benchmark_hash.py b/benchmarks/benchmark_hash.py new file mode 100644 index 0000000000000000000000000000000000000000..08cdc012d6527aa454b0000b7c1bccdc414b384a --- /dev/null +++ b/benchmarks/benchmark_hash.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Micro benchmark comparing built-in hash(), SHA-256, and xxHash. + +This focuses on a single test payload shaped like the prefix-cache hash input: + (32-byte bytes object, 32-int tuple) + +Usage: + python benchmarks/hash_micro_benchmark.py --iterations 20000 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import time +from collections.abc import Callable, Iterable + +from vllm.utils.hashing import sha256, xxhash + + +def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]: + """Generate a deterministic test payload.""" + random.seed(seed) + bytes_data = bytes(random.getrandbits(8) for _ in range(32)) + int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32)) + return (bytes_data, int_tuple) + + +def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int): + """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times.""" + times: list[float] = [] + + # Warm-up to avoid first-run noise. + for _ in range(200): + func(data) + + for _ in range(iterations): + start = time.perf_counter() + func(data) + end = time.perf_counter() + times.append(end - start) + + avg = statistics.mean(times) + std = statistics.stdev(times) if len(times) > 1 else 0.0 + return avg, std + + +def _run_benchmarks( + benchmarks: Iterable[tuple[str, Callable[[tuple], object]]], + data: tuple, + iterations: int, +): + """Yield (name, avg, std) for each benchmark, skipping unavailable ones.""" + for name, func in benchmarks: + try: + avg, std = _benchmark_func(func, data, iterations) + except ModuleNotFoundError as exc: + print(f"Skipping {name}: {exc}") + continue + yield name, avg, std + + +def builtin_hash(data: tuple) -> int: + """Wrapper for Python's built-in hash().""" + return hash(data) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--iterations", + type=int, + default=10_000, + help="Number of measured iterations per hash function.", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for test payload." 
+ ) + args = parser.parse_args() + + data = _generate_test_data(args.seed) + benchmarks = ( + ("SHA256 (pickle)", sha256), + ("xxHash (pickle)", xxhash), + ("built-in hash()", builtin_hash), + ) + + print("=" * 60) + print("HASH FUNCTION MICRO BENCHMARK") + print("=" * 60) + print("Test data: (32-byte bytes object, 32-int tuple)") + print(f"Iterations: {args.iterations:,}") + print("=" * 60) + + results = list(_run_benchmarks(benchmarks, data, args.iterations)) + builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None) + + print("\nResults:") + for name, avg, std in results: + print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs") + + if builtin_entry: + _, builtin_avg, _ = builtin_entry + print("\n" + "=" * 60) + print("SUMMARY (relative to built-in hash())") + print("=" * 60) + for name, avg, _ in results: + if name == "built-in hash()": + continue + speed_ratio = avg / builtin_avg + print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()") + else: + print("\nBuilt-in hash() result missing; cannot compute speed ratios.") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py new file mode 100644 index 0000000000000000000000000000000000000000..a7892f3f71243755a9d2cf59c1ad562e1878fda8 --- /dev/null +++ b/benchmarks/benchmark_latency.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys + +if __name__ == "__main__": + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench latency + +For help with the new command, run: + vllm bench latency --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench latency --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..f64fd09bab9fa7d57dfe5a1312bdcc6eb0f9292f --- /dev/null +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Offline benchmark to test the long document QA throughput. + +Example usage: + # This workload samples 8 different prompts with a default input + # length of 20000 tokens, then replicates each prompt 2 times + # in random order. + python benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --repeat-count 2 + +Commandline arguments: + --num-documents: The number of documents to sample prompts from. + + --document-length: The length of each document in tokens. + (Optional, default: 20000) + + --output-len: The number of tokens to generate for each prompt. + (Optional, default: 10) + + --repeat-count: The number of times to repeat each prompt. + (Optional, default: 2) + + --repeat-mode: The mode to repeat prompts. The supported modes are: + - 'random': shuffle the prompts randomly. (Default) + - 'tile': the entire prompt list is repeated in sequence. (Potentially + lowest cache hit) + - 'interleave': each prompt is repeated consecutively before + moving to the next element. (Highest cache hit) + + --shuffle-seed: Random seed when the repeat mode is "random". 
+ (Optional, default: 0) + +In the meantime, it also supports all the vLLM engine args to initialize the +LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more +details. +""" + +import dataclasses +import random +import time + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def test_long_document_qa(llm=None, sampling_params=None, prompts=None): + """ + Test long document QA with the given prompts and sampling parameters. + Print the time spent in processing all the prompts. + + Args: + llm: The language model used for generating responses. + sampling_params: Sampling parameter used to generate the response. + prompts: A list of prompt strings to be processed by the LLM. + """ + start_time = time.time() + llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + print(f"Time to execute all requests: {end_time - start_time:.4f} secs") + + +def repeat_prompts(prompts, repeat_count, mode: str): + """ + Repeat each prompt in the list for a specified number of times. + The order of prompts in the output list depends on the mode. + + Args: + prompts: A list of prompts to be repeated. + repeat_count: The number of times each prompt is repeated. + mode: The mode of repetition. Supported modes are: + - 'random': Shuffle the prompts randomly after repetition. + - 'tile': Repeat the entire prompt list in sequence. + Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. + - 'interleave': Repeat each prompt consecutively before moving to + the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. + + Returns: + A list of repeated prompts in the specified order. + + Raises: + ValueError: If an invalid mode is provided. + """ + print("Repeat mode: ", mode) + if mode == "random": + repeated_prompts = prompts * repeat_count + random.shuffle(repeated_prompts) + return repeated_prompts + elif mode == "tile": + return prompts * repeat_count + elif mode == "interleave": + repeated_prompts = [] + for prompt in prompts: + repeated_prompts.extend([prompt] * repeat_count) + return repeated_prompts + else: + raise ValueError( + f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'" + ) + + +def main(args): + random.seed(args.shuffle_seed) + + # Prepare the prompts: + # we append the document id at the beginning to avoid any of the document + # being the prefix of other documents + prompts = [ + str(i) + " ".join(["hi"] * args.document_length) + for i in range(args.num_documents) + ] + + prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) + + warmup_prompts = [ + "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length) + for i in range(args.num_documents) + ] + + # Create the LLM engine + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + + print("------warm up------") + test_long_document_qa( + llm=llm, + prompts=warmup_prompts, + sampling_params=sampling_params, + ) + + print("------start generating------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + +def create_argument_parser(): + parser = FlexibleArgumentParser( + description="Benchmark the performance with or " + "without automatic prefix caching." 
+ ) + + parser.add_argument( + "--document-length", + type=int, + # Roughly the number of tokens for a system paper, + # excluding images + default=20000, + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument( + "--num-documents", + type=int, + default=8, + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument("--output-len", type=int, default=10) + + parser.add_argument( + "--repeat-count", + type=int, + default=2, + help="Number of times to repeat each prompt", + ) + + parser.add_argument( + "--repeat-mode", + type=str, + default="random", + help="The mode to repeat prompts. The supported " + 'modes are "random", "tile", and "interleave". ' + "See repeat_prompts() in the source code for details.", + ) + + parser.add_argument( + "--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"', + ) + + parser = EngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py new file mode 100644 index 0000000000000000000000000000000000000000..57a6c1aef5e78ee892a45d4267409c5d524ac4dd --- /dev/null +++ b/benchmarks/benchmark_ngram_proposer.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +import time +from unittest import mock + +import numpy as np +from benchmark_utils import TimeCollector +from tabulate import tabulate + +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.platforms import current_platform +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + +def benchmark_propose(args): + rows = [] + for max_ngram in args.max_ngram: + collector = TimeCollector(TimeCollector.US) + + model_config = ModelConfig( + model="facebook/opt-125m", + max_model_len=args.num_token + args.num_spec_token, + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + dtype="auto", + seed=0, + trust_remote_code=False, + ) + proposer = NgramProposer( + vllm_config=VllmConfig( + model_config=model_config, + speculative_config=SpeculativeConfig( + prompt_lookup_min=args.min_ngram, + prompt_lookup_max=max_ngram, + num_speculative_tokens=args.num_spec_token, + method="ngram", + ), + ) + ) + + # Warm up + proposer.propose(np.random.randint(0, 20, (args.num_token,))) + + gc.collect() + for _ in range(args.num_iteration): + tokens = np.random.randint(0, 20, (args.num_req, args.num_token)) + with collector: + for i in range(args.num_req): + proposer.propose(tokens[i, :]) + rows.append( + [args.num_req, args.num_token, args.min_ngram, max_ngram] + + collector.dump_avg_max() + ) + + print( + tabulate( + rows, + headers=[ + "# Request", + "# Token", + "Min Ngram", + "Max Ngram", + "Avg (us)", + "Max (us)", + ], + tablefmt="grid", + floatfmt=".3f", + ) + ) + + +def benchmark_batched_propose(args): + NUM_SPECULATIVE_TOKENS_NGRAM = 10 + PROMPT_LOOKUP_MIN = 5 + PROMPT_LOOKUP_MAX = 15 + MAX_MODEL_LEN = int(1e7) + DEVICE = current_platform.device_type + + model_config 
= ModelConfig(model="facebook/opt-125m", runner="generate") + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + method="ngram", + num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM, + prompt_lookup_max=PROMPT_LOOKUP_MAX, + prompt_lookup_min=PROMPT_LOOKUP_MIN, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=current_platform.device_type), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + ) + + # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + with mock.patch( + "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group + ): + runner = GPUModelRunner(vllm_config, DEVICE) + + # hack max model len + runner.max_model_len = MAX_MODEL_LEN + runner.drafter.max_model_len = MAX_MODEL_LEN + + dummy_input_batch = InputBatch( + max_num_reqs=args.num_req, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=args.num_req * args.num_token, + device=DEVICE, + pin_memory=False, + vocab_size=256000, + block_sizes=[16], + ) + dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req)) + dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req + dummy_input_batch.token_ids_cpu = np.random.randint( + 0, 20, (args.num_req, args.num_token) + ) + + runner.input_batch = dummy_input_batch + + sampled_token_ids = [[0]] * args.num_req + + print("Starting benchmark") + # first run is warmup so ignore it + for _ in range(args.num_iteration): + start = time.time() + runner.drafter.propose( + sampled_token_ids, + dummy_input_batch.num_tokens_no_spec, + dummy_input_batch.token_ids_cpu, + ) + end = time.time() + print(f"Iteration time (s): {end - start}") + + +def invoke_main() -> None: + parser = FlexibleArgumentParser( + description="Benchmark the performance of N-gram speculative decode drafting" + ) + parser.add_argument( + "--batched", action="store_true", help="consider time to prepare batch" + ) + parser.add_argument( + "--num-iteration", + type=int, + default=100, + help="Number of iterations to run to stabilize final data readings", + ) + parser.add_argument( + "--num-req", type=int, default=128, help="Number of requests in the batch" + ) + parser.add_argument( + "--num-token", type=int, default=1500, help="Number of tokens for each request" + ) + parser.add_argument( + "--min-ngram", + type=int, + default=3, + help="Minimum n-gram to match", + ) + parser.add_argument( + "--max-ngram", + type=int, + nargs="*", + default=[5, 7, 10, 15, 20], + help="Maximum n-gram to match", + ) + parser.add_argument( + "--num-spec-token", + type=int, + default=3, + help="Number of speculative tokens to generate", + ) + args = parser.parse_args() + + if not args.batched: + benchmark_propose(args) + else: + benchmark_batched_propose(args) + + +""" +# Example command lines: +# time python3 benchmarks/benchmark_ngram_proposer.py +# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128 +""" # noqa: E501 +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/benchmarks/benchmark_prefix_block_hash.py b/benchmarks/benchmark_prefix_block_hash.py new file mode 100644 index 
0000000000000000000000000000000000000000..8bcd8af0d31022140a9ea82fd72896b87acae3d4 --- /dev/null +++ b/benchmarks/benchmark_prefix_block_hash.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Simple benchmark to compare prefix-cache block hashing algorithms. + +Example: + python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import sys +import time +from collections.abc import Callable, Iterable, Sequence + +from vllm.utils.hashing import get_hash_fn_by_name +from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash + +SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor") + + +def _generate_blocks( + num_blocks: int, block_size: int, vocab_size: int, seed: int +) -> list[list[int]]: + rng = random.Random(seed) + return [ + [rng.randrange(vocab_size) for _ in range(block_size)] + for _ in range(num_blocks) + ] + + +def _hash_all_blocks( + hash_fn: Callable[[object], bytes], + blocks: Iterable[Sequence[int]], +) -> float: + parent_hash: BlockHash | None = None + start = time.perf_counter() + for block in blocks: + parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None) + end = time.perf_counter() + return end - start + + +def _benchmark( + hash_algo: str, + blocks: list[list[int]], + trials: int, +) -> tuple[float, float, float] | None: + try: + hash_fn = get_hash_fn_by_name(hash_algo) + init_none_hash(hash_fn) + timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)] + except ModuleNotFoundError as exc: + print(f"Skipping {hash_algo}: {exc}", file=sys.stderr) + return None + + avg = statistics.mean(timings) + best = min(timings) + # throughput: tokens / second + tokens_hashed = len(blocks) * len(blocks[0]) + throughput = tokens_hashed / best + return avg, best, throughput + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.") + parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.") + parser.add_argument( + "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)." + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument( + "--trials", type=int, default=5, help="Number of timed trials per algorithm." + ) + parser.add_argument( + "--algorithms", + nargs="+", + default=SUPPORTED_ALGOS, + choices=SUPPORTED_ALGOS, + help="Hash algorithms to benchmark.", + ) + args = parser.parse_args() + + blocks = _generate_blocks( + args.num_blocks, args.block_size, args.vocab_size, args.seed + ) + print( + f"Benchmarking {len(args.algorithms)} algorithms on " + f"{args.num_blocks} blocks (block size={args.block_size})." 
+ ) + + for algo in args.algorithms: + result = _benchmark(algo, blocks, args.trials) + if result is None: + continue + + avg, best, throughput = result + print( + f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s " + f"throughput: {throughput / 1e6:.2f}M tokens/s" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py new file mode 100644 index 0000000000000000000000000000000000000000..e6391134ff9322022644e81673addca2fed66930 --- /dev/null +++ b/benchmarks/benchmark_prefix_caching.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark the efficiency of prefix caching. + +This script allows you to benchmark the performance of +a model with and without prefix caching using either fixed prompts +or prompts sampled from the ShareGPT dataset. + +Fixed example usage: + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 128:256 + +ShareGPT example usage: + # This command samples 20 prompts with input lengths + # between 128 and 256 tokens from the ShareGPT dataset, + # then replicates each prompt 5 times. + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +""" + +import dataclasses +import json +import random +import time + +from transformers import PreTrainedTokenizerBase + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + +try: + from vllm.tokenizers import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + +PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. 
Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 + + +def test_prefix(llm=None, sampling_params=None, prompts=None): + start_time = time.time() + + llm.generate(prompts, sampling_params=sampling_params) + + end_time = time.time() + print(f"cost time {end_time - start_time}") + + +@dataclasses.dataclass +class Request: + prompt: str + prompt_len: int + output_len: int + + +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]: + vocab = tokenizer.get_vocab() + all_special_ids = set(tokenizer.all_special_ids) + + # Remove the special tokens. + return random.choices( + [v for v in vocab.values() if v not in all_special_ids], + k=length, + ) + + +def sample_requests_from_dataset( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: tuple[int, int], + fixed_output_len: int | None, +) -> list[Request]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + min_len, max_len = input_length_range + assert min_len >= 0 and max_len >= min_len, "input_length_range too small" + + # Filter out sequences that are too long or too short + filtered_requests: list[Request] = [] + + for i in range(len(dataset)): + if len(filtered_requests) == num_requests: + break + + # Tokenize the prompts and completions. 
+ prompt_token_ids = tokenizer(dataset[i][0]).input_ids + prompt = tokenizer.decode(prompt_token_ids) + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if min_len <= prompt_len <= max_len: + filtered_requests.append(Request(prompt, prompt_len, output_len)) + + return filtered_requests + + +def sample_requests_from_random( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: tuple[int, int], + fixed_output_len: int | None, + prefix_len: int, +) -> list[Request]: + requests = [] + prefix_token_ids = sample_tokens(tokenizer, prefix_len) + min_len, max_len = input_length_range + + for i in range(num_requests): + unique_part_token_ids = sample_tokens( + tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len) + ) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert min_len <= prompt_len <= max_len, ( + f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + ) + requests.append(Request(prompt, prompt_len, fixed_output_len)) + return requests + + +def repeat_and_sort_requests( + requests: list[Request], repeat_count: int, sort: bool = False +) -> list[str]: + repeated_requests = requests * repeat_count + if sort: + repeated_requests.sort(key=lambda x: x[1]) + else: + random.shuffle(repeated_requests) + return [req.prompt for req in repeated_requests] + + +def main(args): + tokenizer = get_tokenizer(args.model, trust_remote_code=True) + input_length_range = tuple(map(int, args.input_length_range.split(":"))) + random.seed(args.seed) + if args.dataset_path is not None: + if args.prefix_len > 0: + raise ValueError( + "prefix-len is not supported when dataset-path is provided." + ) + print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") + filtered_requests = sample_requests_from_dataset( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + ) + else: + print(f"Start to sample {args.num_prompts} prompts from random") + filtered_requests = sample_requests_from_random( + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + prefix_len=args.prefix_len, + ) + + # Print some helpful stats of the requests. 
+ print(f"Sampled {len(filtered_requests)} requests.") + prompt_lens = [req.prompt_len for req in filtered_requests] + print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}") + print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}") + print(f"Min Prompt Length: {min(prompt_lens)}") + print(f"Max Prompt Length: {max(prompt_lens)}") + + engine_args = EngineArgs.from_cli_args(args) + + llm = LLM(**dataclasses.asdict(engine_args)) + + sampling_params = SamplingParams( + temperature=0, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) + + print("Testing filtered requests") + prompts = repeat_and_sort_requests( + filtered_requests, repeat_count=args.repeat_count, sort=args.sort + ) + + print("------start generating------") + test_prefix( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + +def create_argument_parser(): + parser = FlexibleArgumentParser( + description="Benchmark the performance with or without " + "automatic prefix caching." + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--num-prompts", + type=int, + required=True, + help="Number of the prompts sampled from dataset", + ) + parser.add_argument( + "--repeat-count", + type=int, + default=1, + help="Number of times to repeat each prompt", + ) + parser.add_argument( + "--sort", action="store_true", help="Sort prompts by input length" + ) + parser.add_argument( + "--input-length-range", + type=str, + required=True, + help="Range of input lengths for sampling prompts," + 'specified as "min:max" (e.g., "128:256").', + ) + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Specifies the length of a common prefix to be " + "added to the input prompt. The input-length-range will " + "subtract this length when filtering prompts. Only used " + "when dataset-path is not provided.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), + ) + + parser = EngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 0000000000000000000000000000000000000000..a35db0063b0ae245f2022af198f80c673c700512 --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark offline prioritization.""" + +import argparse +import dataclasses +import json +import random +import time + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.engine.arg_utils import EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +# Select a equi-probable random priority +def get_random_flag(): + return 0 if random.random() < 0.5 else 1 + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: int | None, +) -> list[tuple[str, int, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. 
+ with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: list[tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + priority = get_random_flag() + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: list[tuple[str, int, int]], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> float: + from vllm import LLM, SamplingParams + + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " input_len and output_len for all requests." + ) + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + detokenize=not disable_detokenize, + ) + ) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code + ) + if args.dataset is None: + # Synthesize a prompt with the given input length. 
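+ # This assumes the repeated "hi" text tokenizes to roughly one token per
+ # repetition, so the synthetic prompt is approximately args.input_len tokens;
+ # the exact count depends on the tokenizer.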
+ prompt = "hi" * (args.input_len - 1) + requests = [ + (prompt, args.input_len, args.output_len, get_random_flag()) + for _ in range(args.num_prompts) + ] + else: + requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.output_len + ) + + if args.backend == "vllm": + elapsed_time = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum( + prompt_len + output_len for _, prompt_len, output_len, priority in requests + ) + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s" + ) + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +def create_argument_parser(): + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument( + "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" + ) + parser.add_argument( + "--dataset", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=200, help="Number of prompts to process." + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the throughput results in JSON format.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), + ) + + parser = EngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py new file mode 100644 index 0000000000000000000000000000000000000000..76cf51498020b2581157527acc38987e75e242aa --- /dev/null +++ b/benchmarks/benchmark_serving.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys + +if __name__ == "__main__": + print("""DEPRECATED: This script has been moved to the vLLM CLI. 
+ +Please use the following command instead: + vllm bench serve + +For help with the new command, run: + vllm bench serve --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench serve --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py new file mode 100644 index 0000000000000000000000000000000000000000..33aca831883aac7dfca8cbd6baa128630b935679 --- /dev/null +++ b/benchmarks/benchmark_serving_structured_output.py @@ -0,0 +1,1040 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +r"""Benchmark online serving throughput with structured outputs. + +On the server side, run one of the following commands: + (vLLM OpenAI API server) + vllm serve + +On the client side, run: + python benchmarks/benchmark_serving_structured_output.py \ + --backend \ + --model \ + --dataset json \ + --structured-output-ratio 1.0 \ + --request-rate 10 \ + --num-prompts 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. +""" + +import argparse +import asyncio +import copy +import dataclasses +import json +import os +import random +import time +import uuid +import warnings +from collections.abc import AsyncGenerator +from contextlib import nullcontext +from dataclasses import dataclass + +import datasets +import numpy as np +import pandas as pd +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + RequestFuncInput, + RequestFuncOutput, +) +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +try: + from vllm.tokenizers import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + +try: + from vllm.utils.argparse_utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + +from vllm.v1.structured_output.backend_xgrammar import ( + has_xgrammar_unsupported_json_features, +) + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + + +@dataclasses.dataclass +class SampleRequest: + """A class representing a single inference request for benchmarking. + + Attributes: + prompt: The input text prompt for the model. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + prompt_len: The length of the prompt in tokens. + expected_output_len: The expected length of the output in tokens. 
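+ schema: The structured-output constraint for this request: a JSON schema
+ dict, a grammar string, a regex string, or a list of allowed choices,
+ depending on the selected dataset.
+ structure_type: Key under which the constraint is sent to the server
+ (e.g. "json", "grammar", "regex", "choice").
+ completion: Reference completion text, when the dataset provides one
+ (only the xgrammar_bench dataset does).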
+ """ + + prompt: str + prompt_len: int + expected_output_len: int + schema: dict + structure_type: str + completion: str = None + + +def sample_requests( + tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace +) -> list[SampleRequest]: + if args.dataset == "json" or args.dataset == "json-unique": + if args.json_schema_path is None: + dir_path = os.path.dirname(os.path.realpath(__file__)) + args.json_schema_path = os.path.join( + dir_path, "structured_schemas", "structured_schema_1.json" + ) + json_schemas = [] + with open(args.json_schema_path) as f: + schema = json.load(f) + + if args.dataset == "json-unique": + json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)] + for i in range(len(json_schemas)): + if "properties" not in json_schemas[i]: + json_schemas[i]["properties"] = {} + json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = { + "type": "string", + "description": "An unique optional field to avoid cached schemas", + } + else: + json_schemas = [schema] * args.num_prompts + + def gen_prompt(index: int): + return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501 + + def get_schema(index: int): + return json_schemas[index % len(json_schemas)] + + requests = [ + SampleRequest( + prompt=gen_prompt(i), + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), + expected_output_len=args.output_len, + schema=get_schema(i), + structure_type=args.structure_type, + ) + for i in range(args.num_prompts) + ] + + elif args.dataset == "grammar": + schema = """ + root ::= select_statement + + select_statement ::= "SELECT " column " from " table " where " condition + + column ::= "col_1 " | "col_2 " + + table ::= "table_1 " | "table_2 " + + condition ::= column "= " number + + number ::= "1 " | "2 " + """ + prompt = "Generate an SQL query to show the 'username' \ + and 'email' from the 'users' table." + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + ) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "regex": + regex = r"\w+@\w+\.com\n" + args.regex = regex + prompt = "Generate an email address for Alan Turing, \ + who works in Enigma. End in .com and new line. \ + Example result: alan.turing@enigma.com\n" + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type, + ) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "choice": + choice = ["Positive", "Negative"] + args.choice = choice + prompt = "Classify this sentiment: vLLM is wonderful!" 
+ input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type, + ) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "xgrammar_bench": + requests: list[SampleRequest] = [] + dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") + full_dataset_len = len(dataset) + + def _filter_func(item): + import json + + schema = json.loads(item["schema"]) + return not has_xgrammar_unsupported_json_features(schema) + + dataset = dataset.filter(_filter_func) + num_filtered_out = full_dataset_len - len(dataset) + print( + f"dataset has {len(dataset)} entries after filtering " + f"out {num_filtered_out} entries with unsupported features" + ) + len_dataset = len(dataset) + for data_point_idx in range(args.num_prompts): + idx = data_point_idx + while idx >= len_dataset: + idx -= len_dataset + schema = dataset["schema"][idx] + prompt = tokenizer.apply_chat_template( + dataset["prompt"][idx], tokenize=False, add_generation_prompt=True + ) + input_len = len(tokenizer(prompt).input_ids) + completion = dataset["completion"][idx] + + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + completion=completion, + ) + ) + + return requests + + +async def get_request( + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[tuple[int, SampleRequest], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a tuple. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ + input_requests = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}." + ) + theta = 1.0 / (request_rate * burstiness) + + for i, request in enumerate(input_requests): + yield i, request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. 
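+ # Illustrative numbers (not taken from the code): with request_rate=10 and
+ # burstiness=1.0, theta=0.1 and intervals are exponential with mean 0.1 s
+ # (a Poisson process). With burstiness=0.5 the gamma shape halves but the
+ # mean interval stays 0.1 s, so intervals have higher variance and arrivals
+ # are burstier.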
+ await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: list[tuple[str, int, int]], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + goodput_config_dict: dict[str, float] | None = None, +) -> tuple[BenchmarkMetrics, list[int]]: + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + # We use the tokenizer to count the number of output tokens for all + # serving backends instead of looking at len(outputs[i].itl) since + # multiple output tokens may be bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) + actual_output_lens.append(output_len) + total_input += input_requests[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + outputs[i].tpot = tpot + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], + ) + + return metrics, actual_output_lens + + +async def benchmark( + backend: str, + api_url: str, + base_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[str], + ignore_eos: bool, + max_concurrency: int | None, + structured_output_ratio: float, + goodput_config_dict: dict[str, float] | None = None, +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + def prepare_extra_body(request) -> dict: + extra_body = {} + # Add the schema to the extra_body + extra_body["structured_outputs"] = {} + extra_body["structured_outputs"][request.structure_type] = request.schema + return extra_body + + print("Starting initial single prompt test run...") + structured_output_req_idx = random.sample( + range(len(input_requests)), int(len(input_requests) * structured_output_ratio) + ) + + test_request = input_requests[0] + test_req_extra_body = ( + prepare_extra_body(test_request) if 0 in structured_output_req_idx else None + ) + test_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=api_url, + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=test_req_extra_body, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}" + ) + else: + print("Initial test run completed. 
Starting main benchmark run...") + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=base_url + "/start_profile", + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=test_req_extra_body, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else nullcontext() + + async def limited_request_func(request_func_input, pbar): + async with semaphore: + return await request_func(request_func_input=request_func_input, pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + expected: list[str] = [] + async for i, request in get_request(input_requests, request_rate, burstiness): + extra_body = ( + prepare_extra_body(request) if i in structured_output_req_idx else None + ) + request_func_input = RequestFuncInput( + model=model_id, + prompt=request.prompt, + api_url=api_url, + prompt_len=request.prompt_len, + output_len=request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) + expected.append(request.completion) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + if max_concurrency is not None: + print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) + if request_rate != float("inf"): + print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) + if goodput_config_dict: + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total token throughput (tok/s):", metrics.total_token_throughput + ) + ) + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "output_throughput": 
metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "ttft_description": pd.Series([output.ttft for output in outputs]) + .describe() + .to_dict(), + "tpot_description": pd.Series([output.tpot for output in outputs]) + .describe() + .to_dict(), + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "errors": [output.error for output in outputs], + } + + ret = [ + {"generated": output.generated_text, "expected": gt} + for output, gt in zip(outputs, expected) + ] + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms" + ) + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms" + ) + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 
1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + extra_body={test_request.structure_type: test_request.schema}, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + return result, ret + + +def evaluate(ret, args): + def _eval_correctness_json(expected, actual): + # extract json string from string using regex + import regex as re + + actual = actual.replace("\n", "").replace(" ", "").strip() + try: + actual = re.search(r"\{.*\}", actual).group() + actual = json.loads(actual) + except Exception: + return False + + return True + + def _eval_correctness_choice(expected, actual): + return actual in args.choice + + def _eval_correctness_regex(expected, actual): + import regex as re + + return re.match(args.regex, actual) is not None + + def _eval_correctness(expected, actual): + if args.structure_type == "json": + return _eval_correctness_json(expected, actual) + elif args.structure_type == "regex": + return _eval_correctness_regex(expected, actual) + elif args.structure_type == "choice": + return _eval_correctness_choice(expected, actual) + else: + return None + + scores = [] + for res in ret: + score = _eval_correctness(res["expected"], res["generated"]) + res["correctness"] = score + scores.append(score) + + not_none_scores = [score for score in scores if score is not None] + + return ( + (sum(not_none_scores) / len(not_none_scores) * 100) + if len(not_none_scores) > 0 + else None + ) + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + 'Specify service level objectives for goodput as "KEY:VALUE" ' + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds." + ) from err + return goodput_config_dict + + +def check_goodput_args(args): + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. " + ) + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. " + "The service level objective value should be " + "non-negative." 
+ ) + return goodput_config_dict + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer( + tokenizer_id, + trust_remote_code=args.trust_remote_code, + tokenizer_mode=args.tokenizer_mode, + ) + + if args.dataset == "grammar": + args.structure_type = "grammar" + elif args.dataset == "regex": + args.structure_type = "regex" + elif args.dataset == "choice": + args.structure_type = "choice" + else: + args.structure_type = "json" + + if args.no_structured_output: + args.structured_output_ratio = 0 + if args.save_results: + result_file_name = f"{args.structured_output_ratio}so" + result_file_name += f"_{backend}" + result_file_name += f"_{args.request_rate}qps" + result_file_name += f"_{args.model.split('/')[-1]}" + result_file_name += f"_{args.dataset}" + result_file_name += f"_{args.num_prompts}" + result_file_name += f"_out{args.output_len}" + result_file_name += ".txt" + else: + result_file_name = None + + input_requests = sample_requests(tokenizer, args) + + goodput_config_dict = check_goodput_args(args) + + benchmark_result, ret = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], + ignore_eos=args.ignore_eos, + max_concurrency=args.max_concurrency, + structured_output_ratio=args.structured_output_ratio, + goodput_config_dict=goodput_config_dict, + ) + ) + + # Save config and results to json + score = evaluate(ret, args) + print("correct_rate(%)", score, "\n") + if args.save_results: + results = { + "backend": backend, + "model_id": model_id, + "tokenizer_id": tokenizer_id, + "num_prompts": args.num_prompts, + "request_rate": args.request_rate + if args.request_rate < float("inf") + else "inf", + "burstiness": args.burstiness, + "max_concurrency": args.max_concurrency, + "correct_rate(%)": score, + } + results = {"outputs": ret, **results, **benchmark_result} + + # Save to file + if args.result_filename: + result_file_name = args.result_filename + if args.result_dir: + result_file_name = os.path.join(args.result_dir, result_file_name) + with open(result_file_name, "w", encoding="utf-8") as outfile: + json.dump(results, outfile, indent=4) + + +def create_argument_parser(): + parser = FlexibleArgumentParser( + description="Benchmark the online serving throughput." 
+ ) + parser.add_argument( + "--backend", + type=str, + default="vllm", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--dataset", + default="json", + choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"], + ) + parser.add_argument( + "--json-schema-path", type=str, default=None, help="Path to json schema." + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.", + ) + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer, if not using the default tokenizer.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default="auto", + help="Name or path of the tokenizer, if not using the default tokenizer.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--output-len", + type=int, + default=128, + help="Number of output tokens.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--save-results", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use vLLM Profiling. --profiler-config must be provided on the server.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." 
+ "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-separated list of selected metrics to report percentiles. " + "This argument specifies the metrics to report percentiles. " + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-separated list of percentiles for selected metrics. " + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". ' + 'Use "--percentile-metrics" to select metrics.', + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help='Specify service level objectives for goodput as "KEY:VALUE" ' + "pairs, where the key is a metric name, and the value is in " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' + "separated by spaces. Allowed request level metric names are " + '"ttft", "tpot", "e2el". For more context on the definition of ' + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + parser.add_argument( + "--no-structured-output", + action="store_true", + default=False, + help="Whether to disable JSON decoding or not.", + ) + parser.add_argument( + "--structured-output-ratio", + type=float, + default=1.0, + help="Ratio of Structured Outputs requests", + ) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..b6dc0918fd4d1a3001241e84048344568aa15e16 --- /dev/null +++ b/benchmarks/benchmark_throughput.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys + +if __name__ == "__main__": + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench throughput + +For help with the new command, run: + vllm bench throughput --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench throughput --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_topk_topp.py b/benchmarks/benchmark_topk_topp.py new file mode 100644 index 0000000000000000000000000000000000000000..cac332a099d8b4a1ac0755efd0c87035b81e1893 --- /dev/null +++ b/benchmarks/benchmark_topk_topp.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations. 
+ +Compares: +- apply_top_k_top_p_triton (Triton binary search) +- apply_top_k_top_p (PyTorch sort-based) + +Scenarios: +- top_k only (whole batch, partial batch) +- top_p only (whole batch, partial batch) +- mix of top_k and top_p +""" + +import argparse +import gc +from dataclasses import dataclass + +import torch + +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch +from vllm.v1.sample.ops.topk_topp_triton import ( + apply_top_k_top_p_triton, + reset_buffer_cache, +) + + +@dataclass +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + name: str + batch_size: int + vocab_size: int + # k and p can be tensors or None + k_values: torch.Tensor | None # [batch_size] or None + p_values: torch.Tensor | None # [batch_size] or None + description: str + ops_pct: float = 0.0 # Percentage of ops relative to batch size + + +def calculate_ops_pct( + k_values: torch.Tensor | None, + p_values: torch.Tensor | None, + vocab_size: int, + batch_size: int, +) -> float: + """ + Calculate the percentage of active top-k and top-p operations. + + Returns percentage where 100% = batch_size ops. + E.g., if all rows have both top-k and top-p active, returns 200%. + """ + active_ops = 0 + + if k_values is not None: + # Count rows where k < vocab_size (active top-k filtering) + active_ops += (k_values < vocab_size).sum().item() + + if p_values is not None: + # Count rows where p < 1.0 (active top-p filtering) + active_ops += (p_values < 1.0).sum().item() + + return (active_ops / batch_size) * 100 if batch_size > 0 else 0.0 + + +def create_logits( + batch_size: int, vocab_size: int, device: str = "cuda" +) -> torch.Tensor: + """Create random logits mimicking a realistic LLM distribution. + + Uses a Zipf-like probability distribution (rank^-1.1) converted to logits + via log, then randomly permuted per row. This produces a peaked distribution + where a small number of tokens capture most probability mass, similar to + real model outputs. + """ + # Create Zipf-like probabilities: p(rank) ~ rank^(-alpha) + ranks = torch.arange(1, vocab_size + 1, dtype=torch.float32, device=device) + probs = ranks.pow(-1.1) + probs = probs / probs.sum() + + # Convert to logits (log-probabilities, unnormalized is fine) + base_logits = probs.log() + + # Broadcast to batch and randomly permute each row + logits = base_logits.unsqueeze(0).expand(batch_size, -1).clone() + for i in range(batch_size): + logits[i] = logits[i, torch.randperm(vocab_size, device=device)] + + return logits + + +def measure_memory() -> tuple[int, int]: + """Return (allocated, reserved) memory in bytes.""" + torch.cuda.synchronize() + return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated() + + +def reset_memory_stats(): + """Reset peak memory statistics.""" + reset_buffer_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + +def benchmark_function( + func, + logits: torch.Tensor, + k: torch.Tensor | None, + p: torch.Tensor | None, + warmup_iters: int = 5, + benchmark_iters: int = 20, +) -> tuple[float, int]: + """ + Benchmark a function and return (avg_time_ms, peak_memory_bytes). + + Returns average time in milliseconds and peak memory usage. 
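+ Timing uses one pair of CUDA events per benchmark iteration; the reported
+ memory is the peak allocation observed during the timed iterations
+ (peak-memory statistics and the cached kernel buffers are reset first).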
+ """ + # Warmup + for _ in range(warmup_iters): + logits_copy = logits.clone() + func(logits_copy, k, p) + torch.cuda.synchronize() + + # Reset memory stats before benchmark + reset_memory_stats() + + # Benchmark + start_events = [ + torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters) + ] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)] + + for i in range(benchmark_iters): + logits_copy = logits.clone() + start_events[i].record() + func(logits_copy, k, p) + end_events[i].record() + + torch.cuda.synchronize() + + # Calculate timing + times = [ + start_events[i].elapsed_time(end_events[i]) for i in range(benchmark_iters) + ] + avg_time = sum(times) / len(times) + + # Get peak memory + _, peak_memory = measure_memory() + + return avg_time, peak_memory + + +def create_benchmark_configs( + batch_sizes: list[int], + vocab_sizes: list[int], + device: str = "cuda", +) -> list[BenchmarkConfig]: + """Create all benchmark configurations.""" + configs = [] + + for vocab_size in vocab_sizes: + for batch_size in batch_sizes: + # 1. Top-k only - whole batch (all rows have k < vocab_size) + k_all = torch.full((batch_size,), 50, dtype=torch.int32, device=device) + configs.append( + BenchmarkConfig( + name=f"topk_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_all, + p_values=None, + description=f"Top-k only (whole batch, k=50), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_all, None, vocab_size, batch_size), + ) + ) + + # 2. Top-k only - partial batch (half have k=50, half have k=vocab_size) + k_partial = torch.full((batch_size,), 50, dtype=torch.int32, device=device) + k_partial[batch_size // 2 :] = vocab_size # No filtering for second half + configs.append( + BenchmarkConfig( + name=f"topk_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_partial, + p_values=None, + description=f"Top-k only (partial batch, 50% k=50, 50% k=vocab), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_partial, None, vocab_size, batch_size), + ) + ) + + # 3. Top-p only - whole batch (all rows have p < 1.0) + p_all = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device) + configs.append( + BenchmarkConfig( + name=f"topp_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=None, + p_values=p_all, + description=f"Top-p only (whole batch, p=0.9), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(None, p_all, vocab_size, batch_size), + ) + ) + + # 4. Top-p only - partial batch (half have p=0.9, half have p=1.0) + p_partial = torch.full( + (batch_size,), 0.9, dtype=torch.float32, device=device + ) + p_partial[batch_size // 2 :] = 1.0 # No filtering for second half + configs.append( + BenchmarkConfig( + name=f"topp_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=None, + p_values=p_partial, + description=f"Top-p only (partial batch, 50% p=0.9, 50% p=1.0), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(None, p_partial, vocab_size, batch_size), + ) + ) + + # 5. 
Mix of top-k and top-p (both applied to whole batch) + k_mix = torch.full((batch_size,), 100, dtype=torch.int32, device=device) + p_mix = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device) + configs.append( + BenchmarkConfig( + name=f"topk_topp_whole_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_mix, + p_values=p_mix, + description=f"Top-k + Top-p (whole batch, k=100, p=0.9), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_mix, p_mix, vocab_size, batch_size), + ) + ) + + # 6. Mix with partial application (some rows k only, some p only, some both) + k_mixed = torch.full( + (batch_size,), vocab_size, dtype=torch.int32, device=device + ) + p_mixed = torch.full((batch_size,), 1.0, dtype=torch.float32, device=device) + # First third: k only + third = batch_size // 3 + k_mixed[:third] = 50 + # Second third: p only + p_mixed[third : 2 * third] = 0.5 + # Last third: both k and p + k_mixed[2 * third :] = 100 + p_mixed[2 * third :] = 0.9 + configs.append( + BenchmarkConfig( + name=f"mixed_partial_b{batch_size}_v{vocab_size // 1000}k", + batch_size=batch_size, + vocab_size=vocab_size, + k_values=k_mixed, + p_values=p_mixed, + description=f"Mixed partial (1/3 k=50, 1/3 p=0.9, 1/3 both), " + f"batch={batch_size}, vocab={vocab_size}", + ops_pct=calculate_ops_pct(k_mixed, p_mixed, vocab_size, batch_size), + ) + ) + + return configs + + +def format_memory(bytes_val: int) -> str: + """Format memory in human-readable form.""" + if bytes_val >= 1024**3: + return f"{bytes_val / (1024**3):.2f} GB" + elif bytes_val >= 1024**2: + return f"{bytes_val / (1024**2):.2f} MB" + elif bytes_val >= 1024: + return f"{bytes_val / 1024:.2f} KB" + return f"{bytes_val} B" + + +def run_benchmark( + configs: list[BenchmarkConfig], + warmup_iters: int = 5, + benchmark_iters: int = 20, + verbose: bool = True, +): + """Run all benchmarks and print results.""" + results = [] + + print("=" * 100) + print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based") + print("=" * 100) + print() + + for config in configs: + if verbose: + print(f"Running: {config.description}") + + # Create fresh logits for this config + logits = create_logits(config.batch_size, config.vocab_size) + + # Benchmark Triton + reset_memory_stats() + triton_time, triton_mem = benchmark_function( + apply_top_k_top_p_triton, + logits, + config.k_values, + config.p_values, + warmup_iters, + benchmark_iters, + ) + + # Benchmark PyTorch + reset_memory_stats() + pytorch_time, pytorch_mem = benchmark_function( + apply_top_k_top_p_pytorch, + logits, + config.k_values, + config.p_values, + warmup_iters, + benchmark_iters, + ) + + speedup = pytorch_time / triton_time if triton_time > 0 else float("inf") + mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf") + + result = { + "config": config, + "triton_time_ms": triton_time, + "pytorch_time_ms": pytorch_time, + "triton_mem": triton_mem, + "pytorch_mem": pytorch_mem, + "speedup": speedup, + "mem_ratio": mem_ratio, + } + results.append(result) + + if verbose: + print(f" Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}") + print(f" PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}") + print(f" Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x") + print() + + # Clean up + del logits + reset_memory_stats() + + return results + + +def print_summary_table(results: list[dict]): + """Print a summary table of results.""" + print() + print("=" * 130) + print("SUMMARY TABLE") + print("=" * 
130) + print() + + # Header + header = ( + f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} " + f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} " + f"{'Tri Mem':>10} {'Pyt Mem':>10}" + ) + print(header) + print("-" * 130) + + # Group by scenario type + current_vocab = None + for result in results: + config = result["config"] + + # Add separator between vocab sizes + if current_vocab != config.vocab_size: + if current_vocab is not None: + print("-" * 130) + current_vocab = config.vocab_size + + scenario = config.name.split("_b")[0] # Extract scenario name + print( + f"{scenario:<40} {config.batch_size:>6} {config.vocab_size:>7} " + f"{config.ops_pct:>5.0f}% " + f"{result['triton_time_ms']:>12.3f} {result['pytorch_time_ms']:>13.3f} " + f"{result['speedup']:>7.2f}x " + f"{format_memory(result['triton_mem']):>10} " + f"{format_memory(result['pytorch_mem']):>10}" + ) + + print("=" * 130) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations" + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=[1, 4, 16, 64, 128, 512, 1024, 2048], + help="Batch sizes to test (default: 1 4 16 64)", + ) + parser.add_argument( + "--vocab-sizes", + type=int, + nargs="+", + default=[32768, 131072], # 32k, 128k + help="Vocabulary sizes to test (default: 32768 131072)", + ) + parser.add_argument( + "--warmup-iters", + type=int, + default=5, + help="Number of warmup iterations (default: 5)", + ) + parser.add_argument( + "--benchmark-iters", + type=int, + default=20, + help="Number of benchmark iterations (default: 20)", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Only print summary table", + ) + + args = parser.parse_args() + + # Print configuration + print(f"Batch sizes: {args.batch_sizes}") + print(f"Vocab sizes: {args.vocab_sizes}") + print(f"Warmup iterations: {args.warmup_iters}") + print(f"Benchmark iterations: {args.benchmark_iters}") + print() + + # Check CUDA + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available. This benchmark requires a GPU.") + return + + device_name = torch.cuda.get_device_name(0) + print(f"GPU: {device_name}") + print() + + # Create configs + configs = create_benchmark_configs( + args.batch_sizes, + args.vocab_sizes, + ) + + # Run benchmarks + results = run_benchmark( + configs, + warmup_iters=args.warmup_iters, + benchmark_iters=args.benchmark_iters, + verbose=not args.quiet, + ) + + # Print summary + print_summary_table(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5865473e95426bcc89ab4c4130de76ca81e34d49 --- /dev/null +++ b/benchmarks/benchmark_utils.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from types import TracebackType + + +# Collect time and generate time metrics +# +# Example Usage: +# collector = TimeCollector(TimeCollector.US) +# for _ in range(total_iteration): +# with collector: +# ... 
+# collector.dump_avg_max() +class TimeCollector: + NS: int = 1 + US: int = NS * 1000 + MS: int = US * 1000 + S: int = MS * 1000 + + def __init__(self, scale: int) -> None: + self.cnt: int = 0 + self._sum: int = 0 + self._max: int | None = None + self.scale = scale + self.start_time: int = time.monotonic_ns() + + def collect(self, v: int) -> None: + self.cnt += 1 + self._sum += v + if self._max is None: + self._max = v + else: + self._max = max(self._max, v) + + def avg(self) -> float | str: + return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A" + + def max(self) -> float | str: + return self._max / self.scale if self._max else "N/A" + + def dump_avg_max(self) -> list[float | str]: + return [self.avg(), self.max()] + + def __enter__(self) -> None: + self.start_time = time.monotonic_ns() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + self.collect(time.monotonic_ns() - self.start_time) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..7720f15e45cc1535e3c195faf2752d618c42ee9d --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,517 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import itertools +import pickle as pkl +import time +from collections.abc import Callable, Iterable + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils.argparse_utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16), + ) + ) + + # pytorch impl - float16 + timers.append( + bench_fn( + label, + sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.float16), + 
b.to(dtype=torch.float16), + ) + ) + + # cutlass impl + timers.append( + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) + + # cutlass with bias + timers.append( + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) + + # cutlass sparse impl + timers.append( + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) + + # cutlass sparse with bias + timers.append( + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) + + return timers + + +def bench_fp8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), + ) + ) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + ) + ) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + ) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + ) + ) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True, + ) + ) + + # cutlass impl: bf16 output + timers.append( + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) + + # cutlass impl: bf16 output + timers.append( + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) + + # cutlass impl: fp16 output + timers.append( + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + 
b_compressed, + e, + scale_a, + scale_b, + torch.float16, + ) + ) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + bias.to(dtype=torch.float16), + ) + ) + + return timers + + +def bench( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError( + f"Unsupported dtype {dtype}: should be one of torch.int8, torch.float8_e4m3fn." + ) + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run( + dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]] +) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + 
all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6cbcf6b68c89fc9e2719ccce8ab948276558fa2f --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Cutlass bench utils + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return 
tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..f7325ddd2cbbfef35bf1ffeb7e9ebb678c7e355b --- /dev/null +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -0,0 +1,372 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import itertools +import pickle as pkl +import time +from collections.abc import Callable, Iterable + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + w8a8_triton_block_scaled_mm, +) +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.math_utils import cdiv + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: list[str] | None = None, +) -> Iterable[TMeasurement]: + """Benchmark INT8-based kernels.""" + assert dtype == torch.int8 + a, b = 
make_rand_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m,), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32) + + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias + ), + } + + timers = [] + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. + if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) + + return timers + + +def bench_fp8( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: list[str] | None = None, +) -> Iterable[TMeasurement]: + """Benchmark FP8-based kernels.""" + assert dtype == torch.float8_e4m3fn + a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + a_cont = a.contiguous() + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32) + block_scale_b = torch.rand( + cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32 + ) + block_scale_a_M_major = block_scale_a.t().contiguous().t() + block_scale_b_K_major = block_scale_b.t().contiguous().t() + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + + print(m, k, n) + + bench_fns = { + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16 + ), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True + ), + "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16 + ), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True + ), + "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, 
scale_b, torch.float16 + ), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) + ), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( + a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) + ), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( + a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16 + ), + } + + timers = [] + for name, fn in bench_fns.items(): + # If bench_kernels is None, run all. Otherwise, run only exact matches. + if bench_kernels is None or name in bench_kernels: + print(f"Running {name}") + timers.append(bench_fn(label, sub_label, name, fn)) + + return timers + + +def bench( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: list[str] | None = None, +) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run( + dtype: torch.dtype, + MKNs: Iterable[tuple[int, int, int]], + bench_kernels: list[str] | None = None, +) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench( + dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels, + ) + print_timers(timers) + results.extend(timers) + return results + + +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +def run_square_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs, bench_kernels=args.kernels) + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs, bench_kernels=args.kernels) + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, 
n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs, bench_kernels=args.kernels) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) + parser.add_argument( + "--kernels", + nargs="+", + type=str, + default=None, + help="Exact names of the kernels to benchmark. 
If not set, runs all kernels.", + ) + + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..25b96ef56620ea7dbd97846cdd57ab0e97a6dfd1 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..d683835db96a4d0f720e3d86560694ee6af9b828 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,143 @@ +#!/bin/bash + +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. 
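+  # The three patterns below cover the worker process names vLLM has used
+  # over time; `xargs -r` keeps the cleanup from failing when a pattern
+  # matches nothing.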
+ pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445 + pgrep VLLM | xargs -r kill -9 + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + CUDA_VISIBLE_DEVICES=0 vllm serve $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + + CUDA_VISIBLE_DEVICES=1 vllm serve $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + vllm bench serve \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_tp1.json \ + --request-rate "inf" + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. + vllm bench serve \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_tp1_overhead.json \ + --request-rate "$qps" + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx datasets + + cd "$(dirname "$0")" + + cd .. 
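+  # Concatenate sonnet.txt four times so there is enough text to sample
+  # 2048-token prompts from (matches the performance benchmark script).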
+ # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..35c86cc845221ae24395f89394c764d78eed5329 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +# Requirement: 2x GPUs. + + +# Model: meta-llama/Meta-Llama-3.1-8B-Instruct +# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests +# Resource: 2x GPU +# Approaches: +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445 + pgrep VLLM | xargs -r kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 vllm serve $model \ + --port 8100 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + CUDA_VISIBLE_DEVICES=1 vllm serve $model \ + --port 8200 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + wait_for_server 8100 + wait_for_server 8200 + python3 round_robin_proxy.py & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 vllm serve $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + CUDA_VISIBLE_DEVICES=1 vllm serve $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=100 + qps=$1 + prefix_len=50 + input_len=1024 + output_len=$2 + tag=$3 + + vllm bench serve \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename 
"$tag"-qps-"$qps".json \ + --request-rate "$qps" + + sleep 2 +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + (which lsof) || (apt-get -y install lsof) + + pip install quart httpx matplotlib aiohttp datasets + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_output_len=6 + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + + python3 visualize_benchmark_results.py + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 0000000000000000000000000000000000000000..d072c03c440b2bc334e0ee32a63120124ac96c14 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import asyncio +import logging +import os +import time +import uuid +from urllib.parse import urlparse + +import aiohttp +from quart import Quart, Response, make_response, request + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_args(): + """parse command line arguments""" + parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server") + + # Add args + parser.add_argument( + "--timeout", + type=float, + default=6 * 60 * 60, + help="Timeout for backend service requests in seconds (default: 21600)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + ) + parser.add_argument( + "--prefill-url", + type=str, + default="http://localhost:8100", + help="Prefill service base URL (protocol + host[:port])", + ) + parser.add_argument( + "--decode-url", + type=str, + default="http://localhost:8200", + help="Decode service base URL (protocol + host[:port])", + ) + parser.add_argument( + "--kv-host", + type=str, + default="localhost", + help="Hostname or IP used by KV transfer (default: localhost)", + ) + parser.add_argument( + "--prefill-kv-port", + type=int, + default=14579, + help="Prefill KV port (default: 14579)", + ) + parser.add_argument( + "--decode-kv-port", + type=int, + default=14580, + help="Decode KV port (default: 14580)", + ) + + return parser.parse_args() + + +def main(): + """parse command line arguments""" + args = parse_args() + + # Initialize configuration using command line parameters + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) + PREFILL_SERVICE_URL = args.prefill_url + DECODE_SERVICE_URL = args.decode_url + PORT = args.port + + PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}" + DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}" + + logger.info( + "Proxy resolved KV addresses -> prefill: %s, decode: %s", + PREFILL_KV_ADDR, + DECODE_KV_ADDR, + ) + + app = Quart(__name__) + + # Attach the configuration object to the application instance so 
helper + # coroutines can read the resolved backend URLs and timeouts without using + # globals. + app.config.update( + { + "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, + "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, + "DECODE_SERVICE_URL": DECODE_SERVICE_URL, + "PREFILL_KV_ADDR": PREFILL_KV_ADDR, + "DECODE_KV_ADDR": DECODE_KV_ADDR, + } + ) + + def _normalize_base_url(url: str) -> str: + """Remove any trailing slash so path joins behave predictably.""" + return url.rstrip("/") + + def _get_host_port(url: str) -> str: + """Return the hostname:port portion for logging and KV headers.""" + parsed = urlparse(url) + host = parsed.hostname or "localhost" + port = parsed.port + if port is None: + port = 80 if parsed.scheme == "http" else 443 + return f"{host}:{port}" + + PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL) + DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL) + KV_TARGET = _get_host_port(DECODE_SERVICE_URL) + + def _build_headers(request_id: str) -> dict[str, str]: + """Construct the headers expected by vLLM's P2P disagg connector.""" + headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET} + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + async def _run_prefill( + request_path: str, + payload: dict, + headers: dict[str, str], + request_id: str, + ): + url = f"{PREFILL_BASE}{request_path}" + start_ts = time.perf_counter() + logger.info("[prefill] start request_id=%s url=%s", request_id, url) + try: + async with ( + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, + session.post(url=url, json=payload, headers=headers) as resp, + ): + if resp.status != 200: + error_text = await resp.text() + raise RuntimeError( + f"Prefill backend error {resp.status}: {error_text}" + ) + await resp.read() + logger.info( + "[prefill] done request_id=%s status=%s elapsed=%.2fs", + request_id, + resp.status, + time.perf_counter() - start_ts, + ) + except asyncio.TimeoutError as exc: + raise RuntimeError(f"Prefill service timeout at {url}") from exc + except aiohttp.ClientError as exc: + raise RuntimeError(f"Prefill service unavailable at {url}") from exc + + async def _stream_decode( + request_path: str, + payload: dict, + headers: dict[str, str], + request_id: str, + ): + url = f"{DECODE_BASE}{request_path}" + # Stream tokens from the decode service once the prefill stage has + # materialized KV caches on the target workers. 
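+        # Failures below are surfaced as JSON error chunks on the stream
+        # instead of being raised, so the client still receives a body even
+        # when the decode backend errors out mid-stream.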
+ logger.info("[decode] start request_id=%s url=%s", request_id, url) + try: + async with ( + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, + session.post(url=url, json=payload, headers=headers) as resp, + ): + if resp.status != 200: + error_text = await resp.text() + logger.error( + "Decode backend error %s - %s", resp.status, error_text + ) + err_msg = ( + '{"error": "Decode backend error ' + str(resp.status) + '"}' + ) + yield err_msg.encode() + return + logger.info( + "[decode] streaming response request_id=%s status=%s", + request_id, + resp.status, + ) + async for chunk_bytes in resp.content.iter_chunked(1024): + yield chunk_bytes + logger.info("[decode] finished streaming request_id=%s", request_id) + except asyncio.TimeoutError: + logger.error("Decode service timeout at %s", url) + yield b'{"error": "Decode service timeout"}' + except aiohttp.ClientError as exc: + logger.error("Decode service error at %s: %s", url, exc) + yield b'{"error": "Decode service unavailable"}' + + async def process_request(): + """Process a single request through prefill and decode stages""" + try: + original_request_data = await request.get_json() + + # Create prefill request (max_tokens=1) + prefill_request = original_request_data.copy() + prefill_request["max_tokens"] = 1 + if "max_completion_tokens" in prefill_request: + prefill_request["max_completion_tokens"] = 1 + + # Execute prefill stage + # The request id encodes both KV socket addresses so the backend can + # shuttle tensors directly via NCCL once the prefill response + # completes. + request_id = ( + f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_" + f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}" + ) + + headers = _build_headers(request_id) + await _run_prefill(request.path, prefill_request, headers, request_id) + + # Execute decode stage and stream response + # Pass the unmodified user request so the decode phase can continue + # sampling with the already-populated KV cache. 
+ generator = _stream_decode( + request.path, original_request_data, headers, request_id + ) + response = await make_response(generator) + response.timeout = None # Disable timeout for streaming response + return response + + except Exception: + logger.exception("Error processing request") + return Response( + response=b'{"error": "Internal server error"}', + status=500, + content_type="application/json", + ) + + @app.route("/v1/completions", methods=["POST"]) + async def handle_request(): + """Handle incoming API requests with concurrency and rate limiting""" + try: + return await process_request() + except asyncio.CancelledError: + logger.warning("Request cancelled") + return Response( + response=b'{"error": "Request cancelled"}', + status=503, + content_type="application/json", + ) + + # Start the Quart server with host can be set to 0.0.0.0 + app.run(port=PORT) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..b1df2f255822dad046f5dfcdc1d6538006463510 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import itertools + +import aiohttp +from aiohttp import web + + +class RoundRobinProxy: + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) + + async def handle_request(self, request): + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse( + status=response.status, headers=response.headers + ) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) + + +async def main(): + proxy = RoundRobinProxy([8100, 8200]) + app = web.Application() + app.router.add_route("*", "/{path:.*}", proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8000) + await site.start() + + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 0000000000000000000000000000000000000000..74fa56d076cf14bc066468be24f6053b45166001 --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import matplotlib.pyplot as plt +import pandas as pd + +if __name__ == "__main__": + data = [] + for name in ["disagg_prefill", "chunked_prefill"]: + for qps in [2, 4, 6, 8]: + with open(f"results/{name}-qps-{qps}.json") as f: + x = json.load(f) + x["name"] = name + x["qps"] = qps + data.append(x) + + df = 
pd.DataFrame.from_dict(data) + dis_df = df[df["name"] == "disagg_prefill"] + chu_df = df[df["name"] == "chunked_prefill"] + + plt.style.use("bmh") + plt.rcParams["font.size"] = 20 + + for key in [ + "mean_ttft_ms", + "median_ttft_ms", + "p99_ttft_ms", + "mean_itl_ms", + "median_itl_ms", + "p99_itl_ms", + ]: + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot( + dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4 + ) + plt.plot( + chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4 + ) + ax.legend() + + ax.set_xlabel("QPS") + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f"results/{key}.png") + plt.close(fig) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..4978a8777ab5c765ca855b06e872a37ca52ba6fb --- /dev/null +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle as pkl +import time +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from itertools import product + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + add_residual: bool + dtype: torch.dtype + group_size: list[int] + + def description(self): + return ( + f"N {self.num_tokens} " + f"x D {self.hidden_size} " + f"x R {self.add_residual} " + f"x DT {self.dtype}" + f"x GS {self.group_size}" + ) + + +def get_bench_params() -> list[bench_params_t]: + ## Test Fixtures + NUM_TOKENS = [2**x for x in range(11)] + HIDDEN_SIZES = list(range(1024, 8129, 1024)) + ADD_RESIDUAL = [True, False] + DTYPES = [torch.bfloat16, torch.float] + GROUP_SIZES = [[1, 64], [1, 128]] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES) + bench_params = list( + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations) + ) + return bench_params + + +# Reference impls +def unfused_int8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _, _ = ops.scaled_int8_quant(torch_out) + + +def unfused_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_fp8_quant(torch_out) + + +def unfused_groupwise_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) 
+ else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + + +def fused_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + out, _ = ops.rms_norm_dynamic_per_token_quant( + x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual + ) + + +def fused_groupwise_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + out, _ = ops.rms_norm_per_block_quant( + x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + group_size, + residual=residual, + is_scale_transposed=True, + ) + + +# Bench functions +def bench_fn( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor, + quant_dtype: torch.dtype, + group_size: list[int], + label: str, + sub_label: str, + fn: Callable, + description: str, +) -> TMeasurement: + min_run_time = 1 + + globals = { + "rms_norm_layer": rms_norm_layer, + "x": x, + "residual": residual, + "quant_dtype": quant_dtype, + "group_size": group_size, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]: + # Make inputs + layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + # Make inputs + scale = 1 / params.hidden_size + x = ( + torch.randn( + params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda" + ) + * scale + ) + residual = ( + (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None + ) + + timers = [] + + # unfused int8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.int8, + params.group_size, + label, + sub_label, + unfused_int8_impl, + "unfused_int8_impl", + ) + ) + + # unfused fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_fp8_impl, + "unfused_fp8_impl", + ) + ) + + # fused int8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.int8, + params.group_size, + label, + sub_label, + fused_impl, + "fused_int8_impl", + ) + ) + + # fused fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_impl, + "fused_fp8_impl", + ) + ) + + # unfused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_groupwise_fp8_impl, + "unfused_groupwise_fp8_impl", + ) + ) + + # fused groupwise fp8 impl. 
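+    # Single fused op (norm + per-block quant), to compare against the
+    # separate norm and per-token-group quant calls timed above.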
+ timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_groupwise_impl, + "fused_groupwise_fp8_impl", + ) + ) + + print_timers(timers) + + return timers + + +# launch bench +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +@default_vllm_config() +def main(): + torch.set_default_device("cuda") + bench_params = get_bench_params() + + timers = [] + for bp in tqdm(bench_params): + timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + print_timers(timers) + + # pickle all the results + timestamp = int(time.time()) + with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: + pkl.dump(timers, f) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..04921dafbdbea0a1b581e6210ba0560dcc603316 --- /dev/null +++ b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from enum import Enum +from itertools import product +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +from .utils import ArgPool, Bench, CudaGraphBenchParams + +GROUP_SIZE = 128 +FLOAT8_T = torch.float8_e4m3fn + + +def print_timers(timers: list[TMeasurement], cuda_graph_nops: int): + print( + f"Note : The timings reported above is for {cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + f"Please divide by {cuda_graph_nops} for single invocation " + "timings." + ) + compare = TBenchmark.Compare(timers) + compare.print() + + +class ImplType(Enum): + SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1 + REFERENCE = 2 + + def get_impl(self): + if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return silu_mul_per_token_group_quant_fp8_colmajor + elif self == ImplType.REFERENCE: + return reference + raise ValueError(f"Unrecognized ImplType {self}") + + +@dataclass +class BenchmarkTensors: + input: torch.Tensor + output: torch.Tensor + + # Reference act output tensor + ref_act_out: torch.Tensor + ref_quant_out: torch.Tensor + + @staticmethod + def make(T: int, N: int) -> "BenchmarkTensors": + assert T % GROUP_SIZE == 0 + assert N % (GROUP_SIZE * 2) == 0 + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + # silu_mul_per_token_group_quant_fp8_colmajor output. + output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to( + FLOAT8_T + ) + + # reference output. 
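+        # The reference path runs silu_and_mul and the quant kernel as two
+        # separate steps, so it needs its own activation and quant buffers.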
+ ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + ref_quant_out = torch.empty( + (T, N // 2), dtype=torch.bfloat16, device="cuda" + ).to(FLOAT8_T) + + return BenchmarkTensors( + input=input, + output=output, + ref_act_out=ref_act_out, + ref_quant_out=ref_quant_out, + ) + + @property + def T(self): + return self.input.size(0) + + @property + def N(self): + return self.input.size(1) + + def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]: + if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return { + "input": self.input, + "output": self.output, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + elif impl_type == ImplType.REFERENCE: + return { + "input": self.input, + "act_out": self.ref_act_out, + "quant_out": self.ref_quant_out, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + raise ValueError(f"Unrecognized impl_type {impl_type}") + + +def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + assert quant_out.size() == x.size() + # Allocate the scale tensor column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_q = quant_out + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_T) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference( + input: torch.Tensor, + act_out: torch.Tensor, + quant_out: torch.Tensor, + use_ue8m0: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops._C.silu_and_mul(act_out, input) + return reference_quant(act_out, quant_out, use_ue8m0) + + +def bench_impl( + bench_tensors: list[BenchmarkTensors], impl_type: ImplType +) -> TMeasurement: + T = bench_tensors[0].T + N = bench_tensors[0].N + + arg_pool_size = len(bench_tensors) + kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors] + + # warmup + for kwargs in kwargs_list: + impl_type.get_impl()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + cuda_graph_params = None + cuda_graph_params = CudaGraphBenchParams(arg_pool_size) + timer = None + with Bench( + cuda_graph_params, + "silu-mul-quant", + f"num_tokens={T}, N={N}", + impl_type.name, + impl_type.get_impl(), + **kwargs, + ) as bench: + timer = bench.run() + return timer + + +def test_correctness(T: int, N: int): + print(f"Testing num_tokens={T}, N={N} ...") + + bench_tensor = BenchmarkTensors.make(T, N) + + def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]: + return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl)) + + # reference output + ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE) + + # test ouptut + out_q, out_s = output_from_impl( + ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + + torch.testing.assert_close(ref_out_q.to(torch.float32), 
out_q.to(torch.float32)) + torch.testing.assert_close(ref_out_s, out_s) + + +def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]: + timers = [] + for N, T in product(Ns, Ts): + test_correctness(T, N) + + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(T, N) for _ in range(arg_pool_size) + ] + + silu_mul_quant_timer = bench_impl( + bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + timers.append(silu_mul_quant_timer) + reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE) + timers.append(reference_timer) + + print_timers( + [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size + ) + + print_timers(timers, cuda_graph_nops=arg_pool_size) + + return timers + + +if __name__ == "__main__": + T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)] + N = [2048, 4096, 8192] + + print(f"T = {T}, N = {N}") + run(T, N, arg_pool_size=8) diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cec02b7cad727ca8125beb61b80b5175fc54e3 --- /dev/null +++ b/benchmarks/kernels/benchmark_activation.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# benchmark custom activation op performance +import itertools + +import torch + +import vllm.model_executor.layers.activation # noqa F401 +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.custom_op import op_registry +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed + +batch_size_range = [1, 16, 128] +seq_len_range = [1, 16, 64, 1024, 4096] +intermediate_size = [3072, 9728, 12288] +configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) + + +@default_vllm_config() +def benchmark_activation( + batch_size: int, + seq_len: int, + intermediate_size: int, + provider: str, + func_name: str, + dtype: torch.dtype, +): + device = "cuda" + num_tokens = batch_size * seq_len + dim = intermediate_size + set_random_seed(42) + torch.set_default_device(device) + + if func_name == "gelu_and_mul": + layer = op_registry[func_name](approximate="none") + elif func_name == "gelu_and_mul_tanh": + layer = op_registry["gelu_and_mul"](approximate="tanh") + elif func_name == "fatrelu_and_mul": + threshold = 0.5 + layer = op_registry[func_name](threshold) + else: + layer = op_registry[func_name]() + + x = torch.randn(num_tokens, dim, dtype=dtype, device=device) + compiled_layer = torch.compile(layer.forward_native) + + if provider == "custom": + fn = lambda: layer(x) + elif provider == "compiled": + fn = lambda: compiled_layer(x) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the custom activation op.") + parser.add_argument( + "--func-name", + type=str, + choices=[ + "mul_and_silu", + "silu_and_mul", + "gelu_and_mul", + "gelu_and_mul_tanh", + "fatrelu_and_mul", + "swigluoai_and_mul", + "gelu_new", + "gelu_fast", + "quick_gelu", + ], + default="silu_and_mul", + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + args = parser.parse_args() + assert args + + func_name = 
args.func_name + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + perf_report = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "intermediate_size"], + x_vals=configs, + line_arg="provider", + line_vals=["custom", "compiled"], + line_names=["Custom OP", "Compiled"], + styles=[("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"{func_name}-op-performance", + args={}, + ) + ) + + perf_report( + lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation( + batch_size, seq_len, intermediate_size, provider, func_name, dtype + ) + ).run(print_data=True) diff --git a/benchmarks/kernels/benchmark_block_fp8_gemm.py b/benchmarks/kernels/benchmark_block_fp8_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..8d50c3828206dfed74f3f95cc4a517e96f5e3b56 --- /dev/null +++ b/benchmarks/kernels/benchmark_block_fp8_gemm.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os + +# Disable DeepGEMM for this benchmark to use CUTLASS +os.environ["VLLM_USE_DEEP_GEMM"] = "0" + +import torch + +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton as vllm_triton + +assert current_platform.is_cuda(), ( + "Only support benchmarking w8a8 block fp8 kernel on CUDA device." +) + +# DeepSeek-V3 weight shapes +DEEPSEEK_V3_SHAPES = [ + (512 + 64, 7168), + (2112, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + (18432 * 2, 7168), + (24576, 1536), + (12288, 7168), + (4096, 7168), + (7168, 2048), +] + + +@default_vllm_config() +def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): + """Build runner function for w8a8 block fp8 matmul.""" + factor_for_scale = 1e-2 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp) + A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + + # Create quantized weight tensor + B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + # Create weight scales + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) + * factor_for_scale + ) + + # Create W8A8BlockFp8LinearOp instance + weight_group_shape = GroupShape(block_n, block_k) + act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization + + linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=weight_group_shape, + act_quant_group_shape=act_quant_group_shape, + cutlass_block_fp8_supported=use_cutlass, + use_aiter_and_is_supported=False, + ) + + def run(): + return linear_op.apply( + input=A_ref, + weight=B, + weight_scale=Bs, + input_scale=None, + bias=None, + ) + + return run + + +# Determine available providers +available_providers = ["torch-bf16", "w8a8-block-fp8-triton"] +plot_title = 
"BF16 vs W8A8 Block FP8 GEMMs" + +if CUTLASS_BLOCK_FP8_SUPPORTED: + available_providers.append("w8a8-block-fp8-cutlass") + + +@vllm_triton.testing.perf_report( + vllm_triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=available_providers, + line_names=available_providers, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs W8A8 Block FP8 GEMMs", + args={}, + ) +) +def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): + M = batch_size + device = "cuda" + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + a = torch.randn((M, K), device=device, dtype=torch.bfloat16) + b = torch.randn((N, K), device=device, dtype=torch.bfloat16) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + elif provider == "w8a8-block-fp8-triton": + run_w8a8_triton = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=False + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_triton(), quantiles=quantiles + ) + elif provider == "w8a8-block-fp8-cutlass": + run_w8a8_cutlass = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=True + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_cutlass(), quantiles=quantiles + ) + else: + raise ValueError(f"Unknown provider: {provider}") + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +if __name__ == "__main__": + block_size = (128, 128) + + for N, K in DEEPSEEK_V3_SHAPES: + print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}") + + print(f"TFLOP/s comparison (block_size={block_size}):") + benchmark_tflops.run( + print_data=True, + # show_plots=False, + # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}", + N=N, + K=K, + block_size=block_size, + ) + + print("\nBenchmark finished!") diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..bd116e36a7166e2f0bc95dcb988e11a18fc6c316 --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe +kernel. Both kernels take in fp8 quantized weights and 16-bit activations, +but use different quantization strategies and backends. 
+""" + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.platforms import current_platform +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.worker.workspace import init_workspace_manager + +# Weight shapes for different models: [num_experts, topk, hidden_size, +# intermediate_size] +WEIGHT_SHAPES_MOE = { + "mixtral-8x7b": [ + [8, 2, 4096, 14336], + ], + "deepseek-v2": [ + [160, 6, 5120, 12288], + ], + "custom-small": [ + [8, 2, 2048, 7168], + ], + "glm45-fp8": [ + [128, 8, 4096, 1408], + ], + "Llama-4-Maverick-17B-128E-Instruct-FP8": [ + [128, 1, 5120, 8192], + ], +} + +DEFAULT_MODELS = [ + "mixtral-8x7b", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = [False, True] + +FP8_DTYPE = current_platform.fp8_dtype() + + +def bench_run( + results: list, + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + init_workspace_manager(torch.cuda.current_device()) + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + + # Create input activations + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + + # Create weights + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + # Create FP8 quantized weights and scales for both kernels + w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE) + + # Create scales based on quantization strategy + if per_out_ch: + # Per-channel quantization + w1_scale = torch.empty( + (num_experts, 2 * n, 1), device=device, dtype=torch.float32 + ) + w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32) + else: + # Per-tensor quantization + w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + # Quantize weights + for expert in range(num_experts): + if per_out_ch: + # Per-channel quantization - not yet implemented properly + # For now, fall back to per-tensor quantization + w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert]) + # Expand scalar scales to the expected per-channel shape + w1_scale[expert] = w1_scale_temp.expand(2 * n, 1) + w2_scale[expert] = w2_scale_temp.expand(k, 1) + else: + # Per-tensor quantization + w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert]) + # Store scalar scales in [1, 1] tensors + w1_scale[expert, 0, 0] = w1_scale_temp + w2_scale[expert, 0, 0] = w2_scale_temp + + # Prepare weights for CUTLASS (no transpose needed) + w1_fp8q_cutlass = w1_fp8q # Keep original [E, 2N, K] + w2_fp8q_cutlass = 
w2_fp8q # Keep original [E, K, N] + + # Create router scores and get topk + score = torch.randn((m, num_experts), device=device, dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization + # Force per-tensor quantization for all cases to match working e2e setup + a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32) + a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32) + + # Force per-tensor quantization for all cases + per_act_token = False + + # Pre-create quantization config to avoid creating it inside CUDA graph + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + moe_config = make_dummy_moe_config( + num_experts=num_experts, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=a.dtype, + ) + fn = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + CutlassExpertsFp8( + moe_config=moe_config, + quant_config=quant_config, + ), + ) + + # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly) + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + # Capture 10 invocations like benchmark_moe.py + for _ in range(10): + fn( + a, + w1_fp8q_cutlass, + w2_fp8q_cutlass, + topk_weights, + topk_ids, + activation=MoEActivation.SILU, + global_num_experts=num_experts, + ) + torch.cuda.synchronize() + + # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly) + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + # Capture 10 invocations like benchmark_moe.py + for _ in range(10): + fused_experts( + a, + w1_fp8q, + w2_fp8q, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + torch.cuda.synchronize() + + def bench_cuda_graph(graph, num_warmup=5, num_iters=100): + """Benchmark CUDA graph using events like benchmark_moe.py""" + # Warmup + for _ in range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + # Timing + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies = [] + for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + + # Divide by 10 since graph contains 10 calls + return sum(latencies) / (num_iters * 10) + + # Benchmark parameters + num_warmup = 5 + num_iters = 100 + + # Benchmark only CUDA graphs (more reliable and faster) + # Benchmark Triton MoE with CUDA graphs + triton_graph_time = bench_cuda_graph( + triton_graph, num_warmup=num_warmup, num_iters=num_iters + ) + + # Benchmark CUTLASS MoE with CUDA graphs + cutlass_graph_time = bench_cuda_graph( + cutlass_graph, num_warmup=num_warmup, num_iters=num_iters + ) + + # Convert ms to us and return results + triton_time_us = triton_graph_time * 1000 + cutlass_time_us = cutlass_graph_time * 1000 + + return { + "batch_size": m, + "triton_time_us": triton_time_us, + "cutlass_time_us": cutlass_time_us, + } + + +def main(args): + # Initialize workspace manager (required for CUTLASS MoE kernels) + device = torch.device("cuda:0") + 
init_workspace_manager(device) + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + all_results = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in args.per_act_token_opts: + for per_out_ch in args.per_out_ch_opts: + print( + f"\n=== {model}, experts={num_experts}, topk={topk}," + f"per_act={per_act_token}, per_out_ch={per_out_ch} ===" + ) + + config_results = [] + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + result = bench_run( + [], # Not used anymore + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + if result: + config_results.append(result) + + # Print results table for this configuration + if config_results: + print( + f"\n{'Batch Size':<12}" + f"{'Triton (us)':<15}" + f"{'CUTLASS (us)':<15}" + ) + print("-" * 45) + for result in config_results: + print( + f"{result['batch_size']:<12}" + f"{result['triton_time_us']:<15.2f}" + f"{result['cutlass_time_us']:<15.2f}" + ) + + all_results.extend(config_results) + + print(f"\nTotal benchmarks completed: {len(all_results)}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE + across specified models/shapes/batches + + Example usage: + python benchmark_cutlass_moe_fp8.py \ + --model "Llama-4-Maverick-17B-128E-Instruct-FP8" \ + --tp-sizes 8 \ + --batch-size 2 4 8 \ + --per-act-token-opts false \ + --per-out-ch-opts false + + """ + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument( + "--per-act-token-opts", + nargs="+", + type=lambda x: x.lower() == "true", + default=[False, True], + help="Per-activation token quantization options (true/false)", + ) + parser.add_argument( + "--per-out-ch-opts", + nargs="+", + type=lambda x: x.lower() == "true", + default=[False, True], + help="Per-output channel quantization options (true/false)", + ) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb1489dadf2efc0febd750169642b1a6f8698ea --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py @@ -0,0 +1,540 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe +kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit +activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8) +and 16-bit activations. 
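+
+For the NVFP4 path, each expert weight gets a per-expert global scale plus FP8
+block scales over groups of 16 elements (quant_blocksize below). A rough,
+illustrative sketch of how this benchmark derives the global scale before
+calling ops.scaled_fp4_quant:
+
+    amax = w[expert].abs().max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax
+    w_fp4, w_blockscale = ops.scaled_fp4_quant(w[expert], global_scale)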
+""" + +import nvtx +import torch +import torch.utils.benchmark as benchmark + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, +) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.scalar_type import scalar_types +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.worker.workspace import init_workspace_manager + +WEIGHT_SHAPES_MOE = { + "nvidia/DeepSeek-R1-FP4": [ + [256, 8, 2048, 7168], + ], +} + +DEFAULT_MODELS = [ + "nvidia/DeepSeek-R1-FP4", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + _, a_fp8_scale = ops.scaled_fp8_quant(a) + + w1_fp8q = torch.empty( + (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn + ) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn) + w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + w1_fp8q_notransp = w1_fp8q.clone() + w2_fp8q_notransp = w2_fp8q.clone() + w1_fp8q = w1_fp8q.transpose(1, 2) + w2_fp8q = w2_fp8q.transpose(1, 2) + + score = torch.randn((m, num_experts), device=device, dtype=dtype) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + quant_blocksize = 16 + w1_blockscale = torch.empty( + (num_experts, 2 * n, k // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn, + ) + w2_blockscale = torch.empty( + (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn + ) + + # n_b_scales = 2 * n if per_out_ch else 1 + # k_b_scales = k if per_out_ch else 1 + w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8) + w2_fp4 = torch.empty((num_experts, k, n // 
2), device=device, dtype=torch.uint8) + + w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_e = w1[expert] + w2_e = w2[expert] + w1_amax = torch.abs(w1_e).max().to(torch.float32) + w2_amax = torch.abs(w2_e).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1_e, w1_gs[expert] + ) + + w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2_e, w2_gs[expert] + ) + + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + + def run_cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + w1_gs: torch.Tensor, + w2_gs: torch.Tensor, + a1_gs: torch.Tensor, + a2_gs: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + num_repeats: int, + ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) + + moe_config = make_dummy_moe_config( + num_experts=num_experts, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=a.dtype, + ) + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + CutlassExpertsFp4( + moe_config=moe_config, + quant_config=quant_config, + ), + ) + + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp4", color="green"): + kernel( + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + + def run_cutlass_from_graph( + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) + moe_config = make_dummy_moe_config() + + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + CutlassExpertsFp4( + moe_config=moe_config, + quant_config=quant_config, + ), + ) + + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return kernel( + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + 
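+    # Note: run_cutlass_from_graph (above) and run_triton_from_graph (below)
+    # are the bodies captured into the CUDA graphs; each invokes its kernel
+    # once under a minimal VllmConfig context.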
+ def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + ) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_fp8q_notransp": w1_fp8q_notransp, + "w2_fp8q_notransp": w2_fp8q_notransp, + "w1_fp8scale": w1_fp8scale, + "w2_fp8scale": w2_fp8scale, + "a_fp8_scale": a_fp8_scale, + # Cutlass params + "a": a, + "a1_gscale": a1_gs, + "w1_fp4": w1_fp4, + "w1_blockscale": w1_blockscale, + "w1_alphas": w1_gs, + "a2_gscale": a2_gs, + "w2_fp4": w2_fp4, + "w2_blockscale": w2_blockscale, + "w2_alphas": w2_gs, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "m": m, + "n": n, + "k": k, + "e": num_experts, + "device": device, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe_fp4": run_cutlass_moe_fp4, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + + run_cutlass_moe_fp4( + a, + w1_fp4, + w2_fp4, + w1_blockscale, + w2_blockscale, + w1_gs, + w2_gs, + a1_gs, + a2_gs, + topk_weights, + topk_ids, + m, + n, + k, + num_experts, + device, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, 
a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + +def main(args): + # Initialize workspace manager (required for CUTLASS MoE kernels) + device = torch.device("cuda:0") + init_workspace_manager(device) + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py new file mode 100644 index 0000000000000000000000000000000000000000..d1005461ab932e2fa01be85e037a2c46bddc18b1 --- /dev/null +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark script for device communicators: +CustomAllreduce (oneshot, twoshot), PyNcclCommunicator, +and SymmMemCommunicator (multimem, two-shot). + +for NCCL symmetric memory you need to set the environment variables +NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does +not use fast NVLS implementation for all reduce. 
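+
+For example (illustrative; adjust --nproc_per_node to your GPU count), the
+environment variables can be passed on the torchrun command line:
+
+    NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1 \
+        torchrun --nproc_per_node=2 benchmark_device_communicators.py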
+ +Usage: + torchrun --nproc_per_node= benchmark_device_communicators.py [options] + +Example: + torchrun --nproc_per_node=2 benchmark_device_communicators.py + --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100 +""" + +import json +import os +import time +from collections.abc import Callable +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce +from vllm.distributed.device_communicators.flashinfer_all_reduce import ( + FlashInferAllReduce, +) +from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator, + register_nccl_symmetric_ops, +) +from vllm.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id, +) +from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator +from vllm.logger import init_logger +from vllm.utils.argparse_utils import FlexibleArgumentParser + +logger = init_logger(__name__) + +# Default sequence lengths to benchmark +DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192] + +# Fixed hidden size and dtype for all benchmarks +HIDDEN_SIZE = 8192 +BENCHMARK_DTYPE = torch.bfloat16 + +# CUDA graph settings +CUDA_GRAPH_CAPTURE_CYCLES = 10 + + +class CommunicatorBenchmark: + """Benchmark class for testing device communicators.""" + + def __init__( + self, + rank: int, + world_size: int, + device: torch.device, + cpu_group: ProcessGroup, + sequence_lengths: list[int], + ): + self.rank = rank + self.world_size = world_size + self.device = device + self.cpu_group = cpu_group + + # Calculate max_size_override based on largest sequence length + max_seq_len = max(sequence_lengths) + max_tensor_elements = max_seq_len * HIDDEN_SIZE + self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1 + + # Initialize communicators + self.custom_allreduce = None + self.pynccl_comm = None + self.symm_mem_comm = None + self.symm_mem_comm_multimem = None + self.symm_mem_comm_two_shot = None + self.fi_ar_comm = None + + self._init_communicators() + + def _init_communicators(self): + """Initialize all available communicators.""" + try: + self.custom_allreduce = CustomAllreduce( + group=self.cpu_group, + device=self.device, + max_size=self.max_size_override, + ) + if not self.custom_allreduce.disabled: + logger.info("Rank %s: CustomAllreduce initialized", self.rank) + else: + logger.info("Rank %s: CustomAllreduce disabled", self.rank) + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e + ) + self.custom_allreduce = None + + try: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, device=self.device + ) + if not self.pynccl_comm.disabled: + logger.info("Rank %s: PyNcclCommunicator initialized", self.rank) + register_nccl_symmetric_ops(self.pynccl_comm) + else: + logger.info("Rank %s: PyNcclCommunicator disabled", self.rank) + self.pynccl_comm = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e + ) + self.pynccl_comm = None + + # Initialize variants for SymmMemCommunicator + try: + self.symm_mem_comm_multimem = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + force_multimem=True, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_multimem.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank + ) + else: + 
self.symm_mem_comm_multimem = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s", + self.rank, + e, + ) + self.symm_mem_comm_multimem = None + + try: + self.symm_mem_comm_two_shot = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + force_multimem=False, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_two_shot.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank + ) + else: + self.symm_mem_comm_two_shot = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s", + self.rank, + e, + ) + self.symm_mem_comm_two_shot = None + + try: + self.fi_ar_comm = FlashInferAllReduce( + group=self.cpu_group, + device=self.device, + ) + if not self.fi_ar_comm.disabled: + logger.info("Rank %s: FlashInferAllReduce initialized", self.rank) + else: + logger.info("Rank %s: FlashInferAllReduce disabled", self.rank) + self.fi_ar_comm = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e + ) + self.fi_ar_comm = None + + def benchmark_allreduce( + self, sequence_length: int, num_warmup: int, num_trials: int + ) -> dict[str, float]: + """Benchmark allreduce operations for all available communicators.""" + + results = {} + + # Define communicators with their benchmark functions + communicators = [] + + if self.custom_allreduce is not None: + comm = self.custom_allreduce + # CustomAllreduce one-shot + communicators.append( + ( + "ca_1stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + {"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"}, + None, # no destroy function + ) + ) + # CustomAllreduce two-shot + communicators.append( + ( + "ca_2stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + {"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"}, + None, # no destroy function + ) + ) + + if self.pynccl_comm is not None: + comm = self.pynccl_comm + communicators.append( + ( + "pynccl", + lambda t, c=comm: c.all_reduce(t), + lambda t: True, # Always available if initialized + nullcontext(), + {}, # no env variable needed + None, # no destroy function + ) + ) + communicators.append( + ( + "pynccl-symm", + lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t), + lambda t: True, # Always available if initialized + nullcontext(), + {}, # no env variable needed + None, # no destroy function + ) + ) + + if self.symm_mem_comm_multimem is not None: + comm = self.symm_mem_comm_multimem + communicators.append( + ( + "symm_mem_multimem", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + {}, # no env variable needed + None, # no destroy function + ) + ) + + if self.symm_mem_comm_two_shot is not None: + comm = self.symm_mem_comm_two_shot + communicators.append( + ( + "symm_mem_two_shot", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + {}, # no env variable needed + None, # no destroy function needed + ) + ) + + if self.fi_ar_comm is not None: + comm = self.fi_ar_comm + communicators.append( + ( + "flashinfer_trtllm", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_fi_ar(t), + nullcontext(), + {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"}, + lambda c=comm: c.destroy(), + ) + ) + communicators.append( + ( + "flashinfer_mnnvl", + 
lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_fi_ar(t), + nullcontext(), + {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"}, + lambda c=comm: c.destroy(), + ) + ) + + # Benchmark each communicator + for ( + name, + allreduce_fn, + should_use_fn, + context, + env_dict, + destroy_fn, + ) in communicators: + # Save original values and apply new environment variables + saved_env = {key: os.environ.get(key) for key in env_dict} + for key, value in env_dict.items(): + os.environ[key] = value + try: + latency = self.benchmark_allreduce_single( + sequence_length, + allreduce_fn, + should_use_fn, + context, + num_warmup, + num_trials, + ) + if latency is not None: + results[name] = latency + finally: + if destroy_fn is not None: + destroy_fn() + # Restore environment variables to their original state + for key, original_value in saved_env.items(): + if original_value is None: + os.environ.pop(key, None) + else: + os.environ[key] = original_value + + return results + + def benchmark_allreduce_single( + self, + sequence_length: int, + allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None], + should_use_fn: Callable[[torch.Tensor], bool], + context, + num_warmup: int, + num_trials: int, + ) -> float | None: + """Benchmark method with CUDA graph optimization.""" + try: + # Create test tensor (2D: sequence_length x hidden_size) + tensor = torch.randn( + sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device + ) + if not should_use_fn(tensor): + return None + + torch.cuda.synchronize() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + graph_input = tensor.clone() + + # Warmup before capture + for _ in range(3): + allreduce_fn(graph_input) + + # Capture the graph using context manager + with context: + graph = torch.cuda.CUDAGraph() + graph_pool = torch.cuda.graph_pool_handle() + set_graph_pool_id(graph_pool) + with torch.cuda.graph(graph, pool=graph_pool, stream=stream): + for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): + allreduce_fn(graph_input) + + torch.cuda.synchronize() + for _ in range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(num_trials): + graph.replay() + torch.cuda.synchronize() + + end_time = time.perf_counter() + + # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES + return ( + (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000 + ) + + except Exception as e: + logger.error("CUDA graph benchmark failed: %s", e) + raise RuntimeError( + f"CUDA graph benchmark failed for communicator: {e}" + ) from e + + +def _calculate_speedup_info(comm_results: dict[str, float]) -> str: + """Calculate speedup information for a single tensor size.""" + if not comm_results: + return "N/A" + + # Find the fastest communicator + fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k]) + fastest_time = comm_results[fastest_comm] + + # Calculate speedup vs PyNccl if available + if "pynccl" in comm_results: + pynccl_time = comm_results["pynccl"] + speedup = pynccl_time / fastest_time + return f"{fastest_comm} ({speedup:.2f}x)" + else: + return f"{fastest_comm} (N/A)" + + +def print_results( + results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int +): + """Print benchmark results in a formatted table.""" + + print(f"\n{'=' * 130}") + print("Device Communicator Benchmark Results") + print( + f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, " + f"Hidden Size: {HIDDEN_SIZE}" + ) + print(f"{'=' * 
130}") + + # Get all communicator names + all_comms = set() + for size_results in results.values(): + all_comms.update(size_results.keys()) + + all_comms = sorted(list(all_comms)) + + # Print header + header = f"{'Tensor Shape':<20}{'Tensor Size':<15}" + for comm in all_comms: + header += f"{comm:<20}" + header += f"{'Best (Speedup vs PyNccl)':<30}" + print(header) + print("-" * len(header)) + + # Print results for each sequence length + for seq_len in sequence_lengths: + if seq_len in results: + # Calculate tensor size in elements and bytes + tensor_elements = seq_len * HIDDEN_SIZE + tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize + + # Format tensor size (MB) + tensor_size_mb = tensor_bytes / (1024 * 1024) + tensor_size_str = f"{tensor_size_mb:.2f} MB" + + # Format tensor shape + tensor_shape = f"({seq_len}, {HIDDEN_SIZE})" + + row = f"{tensor_shape:<20}{tensor_size_str:<15}" + for comm in all_comms: + if comm in results[seq_len]: + row += f"{results[seq_len][comm]:<20.3f}" + else: + row += f"{'N/A':<20}" + + # Calculate speedup information + speedup_info = _calculate_speedup_info(results[seq_len]) + row += f"{speedup_info:<30}" + + print(row) + + print(f"{'=' * 130}") + print("All times are in milliseconds (ms) per allreduce operation") + print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)") + + +def main(): + parser = FlexibleArgumentParser(description="Benchmark device communicators") + + parser.add_argument( + "--sequence-lengths", + type=int, + nargs="+", + default=DEFAULT_SEQUENCE_LENGTHS, + help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)", + ) + + parser.add_argument( + "--num-warmup", type=int, default=5, help="Number of warmup iterations" + ) + + parser.add_argument( + "--num-trials", type=int, default=50, help="Number of benchmark trials" + ) + + parser.add_argument("--output-json", type=str, help="Output results to JSON file") + + args = parser.parse_args() + + # Initialize distributed + if not dist.is_initialized(): + dist.init_process_group(backend="gloo") + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Set device + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + # Get CPU process group + cpu_group = dist.new_group(backend="gloo") + + # Disable USE_SYMM_MEM to avoid affecting the max_sizes + # in symm_mem and custom_all_reduce for benchmark + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + # Initialize benchmark + benchmark = CommunicatorBenchmark( + rank, world_size, device, cpu_group, args.sequence_lengths + ) + + # Run benchmarks + all_results = {} + + for seq_len in args.sequence_lengths: + if rank == 0: + logger.info( + "Benchmarking sequence length: %s (tensor shape: %s x %s)", + seq_len, + seq_len, + HIDDEN_SIZE, + ) + + results = benchmark.benchmark_allreduce( + sequence_length=seq_len, + num_warmup=args.num_warmup, + num_trials=args.num_trials, + ) + + all_results[seq_len] = results + + # Synchronize between ranks + dist.barrier() + + # Print results (only rank 0) + if rank == 0: + print_results(all_results, args.sequence_lengths, world_size) + + # Save to JSON if requested + if args.output_json: + # Add speedup information to results + enhanced_results = {} + for seq_len, comm_results in all_results.items(): + enhanced_results[seq_len] = { + "timings": comm_results, + "speedup_info": _calculate_speedup_info(comm_results), + } + + output_data = { + "world_size": world_size, + "dtype": str(BENCHMARK_DTYPE), + "hidden_size": HIDDEN_SIZE, + "sequence_lengths": 
args.sequence_lengths, + "num_warmup": args.num_warmup, + "num_trials": args.num_trials, + "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES, + "results": enhanced_results, + } + + with open(args.output_json, "w") as f: + json.dump(output_data, f, indent=2) + + logger.info("Results saved to %s", args.output_json) + + # Cleanup + if cpu_group != dist.group.WORLD: + dist.destroy_process_group(cpu_group) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_fp8_gemm.py b/benchmarks/kernels/benchmark_fp8_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..920961899038061c13125bdd37f9843b5e4d548a --- /dev/null +++ b/benchmarks/kernels/benchmark_fp8_gemm.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "fp8-tensor-w-token-a": dict( + w="tensor", a="token", no_a_quant=False, enabled=False + ), + "fp8-tensor-w-tensor-a": dict( + w="tensor", a="tensor", no_a_quant=False, enabled=True + ), + "fp8-channel-w-token-a": dict( + w="channel", a="token", no_a_quant=False, enabled=True + ), + "fp8-channel-w-tensor-a": dict( + w="channel", a="tensor", no_a_quant=False, enabled=False + ), + "fp8-tensor-w-token-a-noquant": dict( + w="tensor", a="token", no_a_quant=True, enabled=False + ), + "fp8-tensor-w-tensor-a-noquant": dict( + w="tensor", a="tensor", no_a_quant=True, enabled=True + ), + "fp8-channel-w-token-a-noquant": dict( + w="channel", a="token", no_a_quant=True, enabled=True + ), + "fp8-channel-w-tensor-a-noquant": dict( + w="channel", a="tensor", no_a_quant=True, enabled=False + ), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str): + if w_type == "tensor": + scale_b = torch.ones(1, device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + else: + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True) + return b_fp8.t(), scale_b_fp8 + + +def build_fp8_runner(cfg, a, b, dtype, device): + b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device) + + scale_a_const = ( + torch.ones(1, device=device, dtype=torch.float32) + if cfg["a"] == "tensor" + else None + ) + + if cfg["no_a_quant"]: + if cfg["a"] == "tensor": + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) + else: + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) + + def run(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + return run + + if cfg["a"] == "tensor": + + def run(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + else: + + def run(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is 
better)", + plot_name="BF16 vs FP8 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_fp8_runner(cfg, a, b, dtype, device) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_fp8_res_n{N}_k{K}", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 0000000000000000000000000000000000000000..e18f6a7580fbfb4598f42461c03e7d4a99d85ac5 --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's allreduce_fusion with trtllm backend + (fused allreduce + rmsnorm + optional FP8/FP4 quant) +2. FlashInfer's allreduce_fusion with mnnvl backend + (fused allreduce + rmsnorm only, no quantization support) +3. 
Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time + +import pandas as pd +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.distributed import ( + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +TorchDistBackend = None +try: + import flashinfer.comm as flashinfer_comm # type: ignore + from flashinfer.comm.mnnvl import ( # type: ignore + TorchDistBackend, + ) + + if not ( + hasattr(flashinfer_comm, "allreduce_fusion") + and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace") + ): + flashinfer_comm = None + logger.warning("FlashInfer comm module found but missing allreduce_fusion API") +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensors for FlashInfer (keyed by backend name) +_FI_WORKSPACES: dict = {} + +# Backends to benchmark +FLASHINFER_BACKENDS = ["trtllm", "mnnvl"] + + +def setup_flashinfer_workspace( + backend: str, + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + dtype: torch.dtype, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global FI_WORKSPACES + + if flashinfer_comm is None: + return None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None + + try: + kwargs = {} + if TorchDistBackend is not None: + kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD) + + workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + **kwargs, + ) + + _FI_WORKSPACES[backend] = workspace + return workspace + except Exception as e: + logger.error( + "Failed to setup FlashInfer workspace (backend=%s): %s", backend, e + ) + return None + + +def cleanup_flashinfer_workspaces(): + """Cleanup all FlashInfer workspaces.""" + if flashinfer_comm is None: + return + + for backend, workspace in _FI_WORKSPACES.items(): + try: + workspace.destroy() + except Exception as e: + 
logger.error( + "Failed to cleanup FlashInfer workspace (backend=%s): %s", + backend, + e, + ) + _FI_WORKSPACES.clear() + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + max_token_num: int = 1024, + ): + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_flashinfer_fused_allreduce_kwargs(self): + return { + "launch_with_pdl": self.launch_with_pdl, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + workspace: object, + use_oneshot: bool, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or workspace is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + layout_code = None + if workspace.backend == "trtllm": + layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4 + + flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + quant_out=None, + scale_out=None, + layout_code=layout_code, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + workspace: object, + use_oneshot: bool = True, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization. + + Note: Only supported by the trtllm backend. + """ + if flashinfer_comm is None or workspace is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + workspace: object, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization. + + Note: Only supported by the trtllm backend. 
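+
+    quant_out is expected to pack two FP4 values per uint8 element, and
+    output_scale receives the block scale factors (swizzled 128x4 layout),
+    matching the buffers allocated in create_test_tensors below.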
+ """ + if flashinfer_comm is None or workspace is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), + ) + + +class VllmFusedAllreduce: + def __init__(self, hidden_dim, dtype): + self.rms_eps = 1e-6 + self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype) + self.fp8_quant = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + ) + + def allreduce_rmsnorm( + self, input_tensor: torch.Tensor, residual: torch.Tensor | None + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + return self.rms_norm(allreduce_out, residual) + + def allreduce_rmsnorm_fp8_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + scale_factor: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out + else: + rms_out, residual_out = rms_out + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out, residual_out + + def allreduce_rmsnorm_fp4_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, output_scale + else: + rms_out, residual_out = rms_out + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, residual_out, output_scale + + +def create_test_tensors( + num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ 
in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: FlashInferFusedAllReduceParams | None, + workspaces: dict, + quant_modes: set[str], + no_oneshot: bool, +): + """Run all benchmarks for given configuration. + + Args: + allreduce_params: Shared parameters for FlashInfer fused allreduce. + workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace. + quant_modes: Set of quantization modes: "none", "fp8", "fp4". + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + use_oneshot_options = [False] if no_oneshot else [True, False] + + if "none" in quant_modes: + # Standard AllReduce + RMSNorm + # Re-create VllmFusedAllreduce per config so CustomOp binds the + # correct forward method (native vs custom kernel). 
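+        # In CompilationConfig.custom_ops, "-rms_norm" selects the native
+        # PyTorch RMSNorm implementation and "+rms_norm" the custom CUDA kernel.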
+ for custom_op in ["-rms_norm", "+rms_norm"]: + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op])) + ): + try: + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + suffix = ( + "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm" + ) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm, + input_tensor, + residual=residual, + ) + results[f"standard_allreduce_{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results[f"standard_allreduce_{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + standard_allreduce_rmsnorm_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm (all backends) + for backend, workspace in workspaces.items(): + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + workspace=workspace, + use_oneshot=use_oneshot, + ) + results[key] = time_ms + except Exception as e: + logger.error( + "FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s", + backend, + e, + ) + results[key] = float("inf") + + if "fp8" in quant_modes: + # Standard AllReduce + RMSNorm + FP8 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]: + op_suffix = suffix + ( + "_custom_quant_fp8" + if "+" in quant_fp8_custom_op + else "_native_quant_fp8" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op, quant_fp8_custom_op] + ) + ) + ): + try: + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results[f"standard_allreduce{op_suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results[f"standard_allreduce{op_suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=["-rms_norm", "-quant_fp8"] + ) + ) + ): + try: + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + 
input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only) + if "trtllm" in workspaces: + trtllm_ws = workspaces["trtllm"] + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + workspace=trtllm_ws, + use_oneshot=use_oneshot, + ) + results[key] = time_ms + except Exception as e: + logger.error( + "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s", + e, + ) + results[key] = float("inf") + + if "fp4" in quant_modes and current_platform.has_device_capability(100): + # Standard AllReduce + RMSNorm + FP4 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op] + ) + ) + ): + try: + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + quant_out=fp4_quant_out, + input_global_scale=scale_fp4, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only) + if "trtllm" in workspaces: + trtllm_ws = workspaces["trtllm"] + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + workspace=trtllm_ws, + 
quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=use_oneshot, + ) + results[key] = time_ms + except Exception as e: + logger.error( + "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s", + e, + ) + results[key] = float("inf") + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results( + results_dict, + num_tokens, + hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, +): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print( + f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + 
f"quant_modes={','.join(sorted(list(quant_modes)))}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + lines: list[str] = [] + lines.append("# FlashInfer Fused Collective Operations Benchmark Results") + lines.append("") + lines.append(f"**World Size:** {world_size} ") + lines.append(f"**Hidden Dimension:** {args.hidden_dim} ") + lines.append(f"**Warmup Iterations:** {args.warmup} ") + lines.append(f"**Benchmark Trials:** {args.trials} ") + modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A" + lines.append(f"**Quantization Modes:** {modes} ") + lines.append("") + lines.append("---") + lines.append("") + + for entry in all_results: + num_tokens = entry["num_tokens"] + dtype = entry["dtype"] + use_residual = entry["use_residual"] + results_dict = entry["results"] + input_size_mb = entry["input_size_mb"] + residual_str = "with residual" if use_residual else "no residual" + + lines.append( + f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}" + ) + lines.append(f"**Input Size:** {input_size_mb:.2f} MB") + lines.append("") + + prepared = prepare_results_with_speedups(results_dict) + # Build DataFrame for markdown export + rows = [ + { + "Operation": r["operation"].replace("_", " ").title(), + "Time (ms)": r["time_str"], + "Speedup": r["speedup_str"], + } + for r in prepared + ] + df = pd.DataFrame(rows) + if df.empty: + lines.append("No results.") + else: + lines.append(df.to_markdown(index=False)) + lines.append("") + + return "\n".join(lines) + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "a") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--num-tokens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Numbers of tokens to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + parser.add_argument( + "--quant-modes", + type=str, + default="none,fp8,fp4", + help=( + "Comma-separated quantization modes to run: none, fp8, fp4. 
" + "Default: none,fp8,fp4" + ), + ) + + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." + ) + + # Parse quantization modes + valid_quant_modes = {"none", "fp8", "fp4"} + raw_modes = [ + m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip() + ] + quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"} + invalid = sorted(list(quant_modes - valid_quant_modes)) + if invalid: + raise ValueError( + f"Invalid --quant-modes entries: {','.join(invalid)}. " + f"Valid options are: {','.join(sorted(valid_quant_modes))}." 
+ ) + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes)))) + if flashinfer_comm is not None: + logger.info( + "FlashInfer available - will benchmark fused operations", + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + + configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) + + # Setup FlashInfer workspaces for all backends + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes) + workspace_dtype = ( + torch.float32 + if max_element_size == 4 + else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16) + ) + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * max_element_size + ) + + for backend in FLASHINFER_BACKENDS: + setup_flashinfer_workspace( + backend=backend, + world_size=world_size, + rank=rank, + hidden_dim=args.hidden_dim, + max_token_num=max_num_token, + dtype=workspace_dtype, + ) + + if _FI_WORKSPACES: + allreduce_params = FlashInferFusedAllReduceParams( + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for num_tokens, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s", + num_tokens, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + num_tokens, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + workspaces=_FI_WORKSPACES, + quant_modes=quant_modes, + no_oneshot=args.no_oneshot, + ) + + # Store results for markdown export + if rank == 0: + # Calculate input size in MB + input_size_mb = ( + num_tokens * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) + all_results.append( + { + "num_tokens": num_tokens, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_modes": sorted(list(quant_modes)), + "input_size_mb": input_size_mb, + "results": results, + } + ) + + print_results( + results, + num_tokens, + args.hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + cleanup_flashinfer_workspaces() + + dist.barrier() + + +if __name__ == "__main__": + from vllm.config import VllmConfig, set_current_vllm_config + + with set_current_vllm_config(VllmConfig()): + main() diff --git a/benchmarks/kernels/benchmark_fused_topk.py b/benchmarks/kernels/benchmark_fused_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..72bf2d97cc9fde335b839641fbb341c1e3e75dad --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_topk.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch + +from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk +from 
vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +num_tokens_range = [2**i for i in range(0, 8, 2)] +num_experts_range = [16, 32, 64, 128, 256, 512] +topk_range = [3, 4] +configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) + + +def torch_topk( + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "softmax", +): + if scoring_func == "softmax": + scores = torch.softmax(gating_output.float(), dim=-1) + else: + scores = torch.sigmoid(gating_output.float()) + topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_benchmark(scoring_func): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_experts", "topk"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "vllm"], + line_names=["Torch", "vLLM"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name=f"fused-topk-perf-{scoring_func}", + args={}, + ) + ) + def benchmark(num_tokens, num_experts, topk, provider): + dtype = torch.bfloat16 + hidden_size = 1024 + renormalize = True + hidden_states = torch.randn( + (num_tokens, hidden_size), dtype=dtype, device="cuda" + ) + gating_output = torch.randn( + (num_tokens, num_experts), dtype=dtype, device="cuda" + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch_topk( + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: fused_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.") + parser.add_argument("--scoring-func", type=str, default="softmax") + parser.add_argument("--save-path", type=str, default="./configs/fused_topk/") + args = parser.parse_args() + + # Get the benchmark function + benchmark = get_benchmark(args.scoring_func) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..60ec94b878ce2c661ab92d65d05ad0a880bb264f --- /dev/null +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -0,0 +1,429 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES_MOE + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from 
vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, + fused_topk, +) +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.v1.worker.workspace import init_workspace_manager + +DEFAULT_MODELS = [ + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "deepseek-ai/DeepSeek-V2-Lite", + "ibm-granite/granite-3.0-1b-a400m", + "ibm-granite/granite-3.0-3b-a800m", +] +DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + init_workspace_manager(torch.cuda.current_device()) + label = "Quant Matmul" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 + + _, a_scale = ops.scaled_fp8_quant(a) + + w1_q = torch.empty( + (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn + ) + w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + + for expert in range(num_experts): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + score = torch.randn((m, num_experts), device="cuda", dtype=dtype) + + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, renormalize=False + ) + + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + + def run_cutlass_moe( + a: torch.Tensor, + a_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + per_act_token: bool, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + moe_config = make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ) + + fn = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + CutlassExpertsFp8( + moe_config=moe_config, + quant_config=quant_config, + ), + ) + + for _ in range(num_repeats): + fn(a, w1, w2, topk_weights, topk_ids) + + def run_cutlass_from_graph( + a: 
torch.Tensor, + a_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + moe_config = make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ) + + fn = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + CutlassExpertsFp8( + moe_config=moe_config, + quant_config=quant_config, + ), + ) + + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fn(a, w1, w2, topk_weights, topk_ids) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + ) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + # Cutlass params + "a_scale": a_scale, + "w1_q": w1_q, + "w2_q": w2_q, + "w1_scale": w1_scale, + "w2_scale": w2_scale, + "per_act_token": per_act_token, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "a": a, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe": run_cutlass_moe, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + run_cutlass_moe( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, 
+ topk_ids, + per_act_token, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm_moe", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + +def main(args): + # Initialize workspace manager (required for CUTLASS MoE kernels) + device = torch.device("cuda:0") + init_workspace_manager(device) + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in DEFAULT_BATCH_SIZES: + mkn = (size_m, size_k, size_n) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_int8_gemm.py b/benchmarks/kernels/benchmark_int8_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c6d64404d0dc6958aeb675bf6b893623649ffa --- /dev/null +++ b/benchmarks/kernels/benchmark_int8_gemm.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "int8-tensor-w-token-a": dict( + w="tensor", a="token", no_a_quant=False, enabled=False + ), + "int8-tensor-w-tensor-a": dict( + w="tensor", a="tensor", no_a_quant=False, enabled=True + ), + "int8-channel-w-token-a": dict( + w="channel", a="token", no_a_quant=False, enabled=True + ), + "int8-channel-w-tensor-a": dict( + w="channel", 
a="tensor", no_a_quant=False, enabled=False + ), + "int8-tensor-w-token-a-noquant": dict( + w="tensor", a="token", no_a_quant=True, enabled=False + ), + "int8-tensor-w-tensor-a-noquant": dict( + w="tensor", a="tensor", no_a_quant=True, enabled=True + ), + "int8-channel-w-token-a-noquant": dict( + w="channel", a="token", no_a_quant=True, enabled=True + ), + "int8-channel-w-tensor-a-noquant": dict( + w="channel", a="tensor", no_a_quant=True, enabled=False + ), +} + + +def _quant_weight(b, w_type, device): + if w_type == "tensor": + scale_b = torch.ones(1, device=device, dtype=torch.float32) + b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b) + assert scale_b_int8.numel() == 1 + else: # channel + b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b) + assert scale_b_int8.numel() == b.shape[0] + return b_int8.t(), scale_b_int8 + + +def build_int8_runner(cfg, a, b, dtype, device): + # quant before running the kernel + b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device) + + scale_a_const = None + if cfg["a"] == "tensor": + scale_a_const = torch.ones(1, device=device, dtype=torch.float32) + + # no quant, create activation ahead + if cfg["no_a_quant"]: + if cfg["a"] == "tensor": + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const) + else: # token + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a) + + def run_quant(): + return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) + + return run_quant + + # dynamic quant, create activation inside + if cfg["a"] == "tensor": + + def run_quant(): + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const) + return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) + + else: # token + + def run_quant(): + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a) + return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) + + return run_quant + + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=[k for k in _enabled], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs INT8 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_int8_runner(cfg, a, b, dtype, device) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + help="List of 
models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_int8_res_n{N}_k{K}", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1c1cf09efbdecc66ccc9743018e60b57853233 --- /dev/null +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time + +import torch + +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed + + +@torch.inference_mode() +@default_vllm_config() +def main( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: + set_random_seed(seed) + torch.set_default_device("cuda") + + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + layer(x, residual) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStop() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--add-residual", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. 
" + "If --profile is set, this number is ignored", + ) + + args = parser.parse_args() + print(args) + + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca3cf78f0fb22bee49becc5f4325398930a0c04 --- /dev/null +++ b/benchmarks/kernels/benchmark_lora.py @@ -0,0 +1,1490 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import json +import pickle +import time +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum, auto +from itertools import product +from pathlib import Path +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import ArgPool, Bench, CudaGraphBenchParams +from weight_shapes import WEIGHT_SHAPES + +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs +from vllm.triton_utils import HAS_TRITON, triton + +if HAS_TRITON: + from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora + LoRAKernelMeta, + fused_moe_lora_expand, + fused_moe_lora_shrink, + lora_expand, + lora_shrink, + ) + from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + _LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora + ) + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm import _custom_ops as ops +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.math_utils import round_up + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_TP_SIZES = [1] +DEFAULT_BATCH_SIZES = [ + 1, + 16, + 32, + 64, + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 640, + 768, + 896, + 1024, + 2048, + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, +] +DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384] +DEFAULT_LORA_RANKS = [16] +DEFAULT_NUM_LORAS = [1, 2, 3, 4] +DEFAULT_SORT_BY_LORA_IDS = [False, True] +DEFAULT_SEQ_LENGTHS = [1] +DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] +DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k +DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts + + +# Utilities +def dtype_to_str(dtype: torch.dtype): + if dtype == torch.float16: + return "f16" + if dtype == torch.bfloat16: + return "bf16" + if dtype == torch.float32: + return "f32" + raise ValueError(f"Unsupported dtype {dtype}") + + +def make_rand_lora_weight_tensor( + k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda" +) -> torch.Tensor: + # LoRA weights column major + return torch.rand((num_loras, n, k), dtype=dtype).to(device) + + +def make_rand_tensors( + a_shape: tuple[int, ...], + b_shape: tuple[int, ...], + c_shape: tuple[int, ...], + a_dtype: torch.dtype, + b_dtype: torch.dtype, + c_dtype: torch.dtype, + num_slices: int, + device: str = "cuda", +) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: + """ + Make LoRA input/output matrices. 
+ """ + A = torch.rand(a_shape, dtype=a_dtype).to(device) + + # LoRA weights column major + Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)] + + C = torch.zeros(c_shape, dtype=c_dtype).to(device) + return A, Bs, C + + +def make_prompt_lora_mapping( + num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str +) -> torch.Tensor: + """ + All prompts are mapped to a LoRA ID in range [0, num_active_loras). + where 0 refers to first lora, 1 refers to second lora and so on. + """ + assert num_active_loras > 0 + + if not sort_by_lora_id: + return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long) + + # Divide LoRAs equally and in order. + part_size = num_prompts // num_active_loras + part_size = max(part_size, 1) + + lora_id = 0 + prompt_lora_mapping = [] + while len(prompt_lora_mapping) < num_prompts: + prompt_lora_mapping.extend([lora_id] * part_size) + lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id + return torch.tensor( + prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device + ) + + +def make_token_lora_mapping( + num_tokens: int, + num_prompts: int, + prompt_lora_mapping: torch.Tensor, + seq_len_tensor: torch.Tensor, + device: str, +): + """ + Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor + """ + assert prompt_lora_mapping.shape[0] == num_prompts + + # token to lora index mapping + token_lora_mapping = [0] * num_tokens + current_offset = 0 + for b_id in range(num_prompts): + lora_index = prompt_lora_mapping[b_id].item() + s = current_offset + e = s + seq_len_tensor[b_id].item() + token_lora_mapping[s:e] = [lora_index] * (e - s) + current_offset += seq_len_tensor[b_id].item() + + return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) + + +def ref_group_gemm( + ref_out: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, + scaling: float, + add_inputs: bool | None, +): + """ + Torch group gemm reference implementation to test correctness of + benchmarking operations. + """ + batches = seq_lens_cpu.size(0) + out_list = [] + current_offset = 0 + for lora_index, b_length in zip(range(batches), seq_lens_cpu): + x = input[current_offset : b_length + current_offset, :] + current_offset += b_length + w = lora_weights[prompt_lora_mapping_cpu[lora_index]] + result = torch.nn.functional.linear(x, w) + result *= scaling + out_list.append(result) + + cat_result = torch.cat(out_list, dim=0) + + if add_inputs: + ref_out += cat_result + else: + ref_out.copy_(cat_result) + + +class OpType(Enum): + """ + LoRA Ops to benchmark and its properties. 
+ """ + + LORA_SHRINK = auto() + LORA_EXPAND = auto() + ## Adding support for fused moe lora + FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink + FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand + FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink + FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand + + @staticmethod + def from_str(s: str) -> "OpType": + if s.lower() == "lora_shrink": + return OpType.LORA_SHRINK + if s.lower() == "lora_expand": + return OpType.LORA_EXPAND + # Adding support for fused moe lora, both in gate_up and down + if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink + return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK + if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand + return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND + if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink + return OpType.FUSED_MOE_LORA_DOWN_SHRINK + if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand + return OpType.FUSED_MOE_LORA_DOWN_EXPAND + raise ValueError(f"Unrecognized str {s} to convert to OpType") + + def is_shrink_fn(self) -> bool: + return self in [OpType.LORA_SHRINK] + + def is_expand_fn(self) -> bool: + return self in [OpType.LORA_EXPAND] + + def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_gate_up_fn( + self, + ) -> bool: ## adding for fused MoE LoRA Gate/Up + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + ] + + def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down + return self in [ + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_shrink_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ] + + def is_fused_moe_lora_expand_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def num_slices(self) -> list[int]: + if self.is_fused_moe_lora_gate_up_fn(): + return [2] + elif self.is_fused_moe_lora_down_fn(): + return [1] + return [1, 2, 3] + + def mkn( + self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int + ) -> tuple[int, int, int]: + num_tokens = batch_size * seq_length + if self.is_shrink_fn() or self.is_fused_moe_lora_fn(): + m = num_tokens + k = hidden_size + n = lora_rank + elif self.is_expand_fn(): + m = num_tokens + k = lora_rank + n = hidden_size + return m, k, n + + def matmul_dtypes( + self, op_dtype: torch.dtype + ) -> tuple[torch.dtype, torch.dtype, torch.dtype]: + """ + return a type, b type and c type for A x B = C + """ + if self.is_shrink_fn(): + return op_dtype, op_dtype, torch.float32 + elif self.is_expand_fn(): + return torch.float32, op_dtype, op_dtype + else: + assert self.is_fused_moe_lora_fn() + return op_dtype, op_dtype, op_dtype + + def matmul_shapes_fused_moe_lora( + self, + m: int, + n: int, + k: int, + num_loras: int, + num_slices: int, + top_k_num: int, + num_experts: int, + ) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]: + if self.is_fused_moe_lora_shrink_fn(): + input_shape = ( + (m * top_k_num, n) + if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + 
else (m, n) + ) + output_shape = (num_slices, m, top_k_num, k) + weight_shape = (num_loras, num_experts, k, n) + else: + assert self.is_fused_moe_lora_expand_fn() + input_shape = (num_slices, m, top_k_num, k) + output_shape = (m, top_k_num, n * num_slices) + weight_shape = (num_loras, num_experts, n, k) + return (input_shape, weight_shape, output_shape) + + def matmul_shapes( + self, + batch_size: int, + seq_length: int, + hidden_size: int, + lora_rank: int, + num_loras: int, + num_slices: int, + top_k_num: int | None = None, + num_experts: int | None = None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + """ + Given num_slices, return the shapes of the A, B, and C matrices + in A x B = C, for the op_type + """ + m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank) + + b_shape = (num_loras, n, k) # col-major + if self in [OpType.LORA_SHRINK]: + # LoRA shrink kernels support num_slices inherently in the kernel. + return ((m, k), b_shape, (num_slices, m, n)) + if self in [OpType.LORA_EXPAND]: + # LoRA expand kernels support num_slices inherently in the kernel + return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self.is_fused_moe_lora_fn(): + return self.matmul_shapes_fused_moe_lora( + m, + k, + n, + num_loras, + num_slices, + top_k_num, + num_experts, + ) + raise ValueError(f"Unrecognized op_type {self}") + + def bench_fn(self) -> Callable: + if self == OpType.LORA_SHRINK: + return lora_shrink + if self == OpType.LORA_EXPAND: + return lora_expand + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ]: + return fused_moe_lora_shrink + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ]: + return fused_moe_lora_expand + + raise ValueError(f"Unrecognized optype {self}") + + def run_ref_group_gemm( + self, + output: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + **kwargs, + ) -> Callable: + """Each benchmark operation expects the input, lora_weights and outputs + in a slightly different format. Refer to self.matmul_shapes(). + run_ref_group_gemm accounts for those differences in executing a + reference group gemm for correctness testing. 
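+
+        For LORA_SHRINK the reference gemm is run once per slice into
+        output[slice_idx]; for LORA_EXPAND it is run once per slice into the
+        corresponding hidden_size-wide column block of the output.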
+ """ + w_dtype = lora_weights[0].dtype + num_slices = len(lora_weights) + if self in [OpType.LORA_SHRINK]: + for slice_idx in range(num_slices): + ref_group_gemm( + ref_out=output[slice_idx, :], + input=input, + lora_weights=lora_weights[slice_idx], + **kwargs, + ) + elif self in [OpType.LORA_EXPAND]: + hidden_size = lora_weights[0].shape[1] + for slice_idx in range(num_slices): + slice_offset = slice_idx * hidden_size + ref_group_gemm( + ref_out=output[:, slice_offset : slice_offset + hidden_size], + input=input[slice_idx].clone().to(dtype=w_dtype), + lora_weights=lora_weights[slice_idx], + **kwargs, + ) + else: + raise ValueError(f"Unrecognized optype {self}") + + +@dataclass +class BenchmarkContext: + """ + LoRA benchmark context + """ + + batch_size: int + hidden_size: int + num_loras: int + num_active_loras: int + lora_rank: int + sort_by_lora_id: bool + dtype: torch.dtype + seq_length: int | None = None + num_experts: int | None = None # num_experts for MoE based ops + top_k_num: int | None = None # top_k for MoE based ops + num_slices: int | None = None # num_slices for slice based ops + + def with_seq_length(self, seq_length: int) -> "BenchmarkContext": + ctx = copy.copy(self) + ctx.seq_length = seq_length + return ctx + + def with_num_slices(self, num_slices: int) -> "BenchmarkContext": + ctx = copy.copy(self) + ctx.num_slices = num_slices + return ctx + + def bench_label(self) -> str: + return f"lora-{self.dtype}" + + def bench_sublabel(self, op_type: OpType) -> str: + m, k, n = op_type.mkn( + self.batch_size, self.seq_length, self.hidden_size, self.lora_rank + ) + desc = { + "bs": self.batch_size, + "sl": self.seq_length, + "m": m, + "k": k, + "n": n, + "num_loras": self.num_loras, + "sort_by_lora": self.sort_by_lora_id, + "num_slices": self.num_slices, + } + return json.dumps(desc) + + +@dataclass +class BenchmarkTensors: + """ + Input/Output tensors used for benchmarks + """ + + # matmul tensors + input: torch.Tensor + lora_weights_lst: list[torch.Tensor] + output: torch.Tensor + # LoRA kernel metadata + lora_kernel_meta: LoRAKernelMeta + # Metadata tensors used in testing correctness + seq_lens: torch.Tensor + prompt_lora_mapping: torch.Tensor + + def io_types(self) -> str: + return ( + f"{dtype_to_str(self.input.dtype)}x" + f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" + f"{dtype_to_str(self.output.dtype)}" + ) + + def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType): + return ( + size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size + ) + + @staticmethod + def make( + ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" + ) -> "BenchmarkTensors": + # Make input / output matmul tensors. + a_shape, b_shape, c_shape = op_type.matmul_shapes( + ctx.batch_size, + ctx.seq_length, + ctx.hidden_size, + ctx.lora_rank, + ctx.num_loras, + ctx.num_slices, + ctx.top_k_num, + ctx.num_experts, + ) + a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) + input_tensor, lora_weights, output_tensor = make_rand_tensors( + a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices + ) + + # Make metadata tensors. + # Keep the metadata tensors in the CPU for further processing if needed. + # The tensors get moved to the GPU before benchmarking. + assert ctx.num_active_loras <= ctx.num_loras + total_tokens = ctx.batch_size * ctx.seq_length + + # Make metadata tensors involved in correctness testing. 
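+        # Every prompt is assigned exactly ctx.seq_length tokens, so the
+        # seq_len_tensor prepared below is constant (randint over a width-1 range).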
+        # Prepare seq lens tensor
+        seq_len_tensor = torch.randint(
+            ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
+        )
+        assert total_tokens == seq_len_tensor.sum()
+        # Prepare prompt lora indices tensor
+        prompt_lora_indices_tensor = make_prompt_lora_mapping(
+            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
+        )
+
+        # Make LoRAKernelMeta
+        token_lora_indices_tensor = make_token_lora_mapping(
+            total_tokens,
+            ctx.batch_size,
+            prompt_lora_indices_tensor,
+            seq_len_tensor,
+            "cpu",
+        )
+        lora_kernel_meta = LoRAKernelMeta.make(
+            max_loras=ctx.num_loras,
+            max_num_tokens=token_lora_indices_tensor.size(0),
+            device="cpu",
+        )
+        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)
+
+        return BenchmarkTensors(
+            input_tensor,
+            lora_weights,
+            output_tensor,
+            lora_kernel_meta,
+            seq_len_tensor,
+            prompt_lora_indices_tensor,
+        )
+
+    def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None:
+        """
+        Raises an AssertionError when the tensors don't conform to the
+        expected shapes.
+        """
+        num_tokens = (
+            self.input.shape[1]
+            if op_type.is_fused_moe_lora_expand_fn()
+            else self.input.shape[-2]
+        )
+        # check metadata tensors
+        ## In down shrink case, each token is repeated top_k_num times
+        assert num_tokens == self.get_num_tokens(
+            torch.sum(self.seq_lens), ctx.top_k_num, op_type
+        ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}"
+        num_seqs = self.seq_lens.shape[0]
+        # assert self.seq_start_loc.shape[0] == num_seqs
+        ## In down shrink case, each prompt corresponds to top_k_num sequences
+        assert self.prompt_lora_mapping.shape[0] == num_seqs
+        assert self.get_num_tokens(
+            self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
+        )
+
+    def to_device(self, device: str):
+        """
+        Transfer tensors to the given device if they aren't already on it.
+        """
+
+        def to_device(tensor: torch.Tensor):
+            if tensor.device != device:
+                tensor = tensor.to(device=device)
+            return tensor
+
+        self.input = to_device(self.input)
+        self.output = to_device(self.output)
+        self.seq_lens = to_device(self.seq_lens)
+        self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
+        for i in range(len(self.lora_weights_lst)):
+            self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
+
+        # LoRA meta
+        for field_name in LoRAKernelMeta.__dataclass_fields__:
+            field = getattr(self.lora_kernel_meta, field_name)
+            assert isinstance(field, torch.Tensor)
+            setattr(
+                self.lora_kernel_meta,
+                field_name,
+                to_device(field) if field_name != "no_lora_flag_cpu" else field,
+            )
+
+    def metadata(
+        self, ctx: BenchmarkContext, op_type: OpType
+    ) -> tuple[int, int, int, int]:
+        """
+        Return num_seqs, num_tokens, max_seq_len and num_slices.
+        """
+        num_seqs = self.seq_lens.shape[0]
+        num_tokens = self.get_num_tokens(
+            self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
+        )
+        max_seq_len = torch.max(self.seq_lens).item()
+        num_slices = len(self.lora_weights_lst)
+        return num_seqs, num_tokens, max_seq_len, num_slices
+
+    def fused_moe_lora_data_prepare(
+        self,
+        block_size: int,
+        token_lora_mapping: torch.Tensor,
+        ctx: BenchmarkContext,
+    ):
+        def moe_lora_align_block_size(
+            topk_ids: torch.Tensor,
+            token_lora_mapping: torch.Tensor,
+            block_size: int,
+            num_experts: int,
+            max_loras: int,
+            expert_map: torch.Tensor | None = None,
+            pad_sorted_ids: bool = False,
+        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            """
+            Aligns tokens and experts into block-sized chunks for LoRA-based
+            mixture-of-experts (MoE) execution.
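+
+            Returns (sorted_ids, expert_ids, num_tokens_post_pad); the first
+            two are flattened across LoRA adapters and num_tokens_post_pad
+            holds one entry per adapter. All three are filled by
+            ops.moe_lora_align_block_size.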
+ """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + num_tokens = ctx.batch_size + curr_topk_ids = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + topk_weights = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + + (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = ( + moe_lora_align_block_size( + topk_ids=curr_topk_ids, + token_lora_mapping=token_lora_mapping, + block_size=block_size, + num_experts=ctx.num_experts, + max_loras=ctx.num_loras, + ) + ) + + sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1) + expert_ids = expert_ids_lora.view(ctx.num_loras, -1) + num_tokens_post_padded = num_tokens_post_padded_lora + return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) + + def as_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape [num_tokens, hidden_size] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [num_loras, lora_rank, hidden_size] + assert len(lw_shape) == 3 + assert lw_shape[2] == hidden_size + lora_rank = lw_shape[1] + # Expected output shape [num_slices, num_tokens, lora_rank] + assert len(o_shape) == 3 + assert o_shape == (num_slices, num_tokens, lora_rank) + + return { + "inputs": self.input, + "lora_a_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "scaling": 1.0, + "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, + } + + def as_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape : [num_slices, num_tokens, lora_rank] + assert len(i_shape) == 3 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[2] + # Expected lora weight shape : [num_lora, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape : [num_tokens, hidden_size * num_slices] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size * num_slices) + + return { + "inputs": self.input, + "lora_b_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "offset_start": 0, + "add_inputs": add_inputs, + "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, + } + + def as_fused_moe_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape : [num_tokens, hidden_size] for gate_up + # Expected input shape : [top_k_num * num_tokens, hidden_size] for down + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size] + assert len(lw_shape) == 4 + assert lw_shape[-1] == hidden_size + lora_rank = lw_shape[-2] + # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(o_shape) == 4 + assert ( + o_shape + == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank) + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + ) + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "qcurr_hidden_states": self.input, + "lora_a_stacked": self.lora_weights_lst, + "a_intermediate_cache1": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"], + "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"], + "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"], + "shrink_group_size_m": 
kernel_config["GROUP_SIZE_M"],
+            "shrink_num_warps": kernel_config["NUM_WARPS"],
+            "shrink_num_stages": kernel_config["NUM_STAGES"],
+            "shrink_split_k": kernel_config.get("SPLIT_K", 1),
+            "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
+        }
+
+    def as_fused_moe_lora_expand_kwargs(
+        self, ctx: BenchmarkContext, op_type: OpType
+    ) -> dict[str, Any]:
+        self.sanity_check(ctx, op_type)
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata(ctx, op_type)
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = (
+            self.input.shape,
+            self.lora_weights_lst[0].shape,
+            self.output.shape,
+        )
+
+        # Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
+        assert len(i_shape) == 4
+        assert i_shape[0] == num_slices
+        assert i_shape[1] == num_tokens
+        lora_rank = i_shape[-1]
+        # Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
+        assert len(lw_shape) == 4
+        assert lw_shape[-1] == lora_rank
+        hidden_size = lw_shape[-2]
+        # Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
+        assert len(o_shape) == 3
+        assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices)
+
+        kernel_config = get_lora_op_configs(
+            op_type.name.lower(),
+            max_loras=lw_shape[0],
+            batch=num_tokens,
+            hidden_size=hidden_size,
+            rank=lora_rank,
+            num_slices=num_slices,
+            add_inputs=False,
+        )
+
+        (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
+            self.fused_moe_lora_data_prepare(
+                block_size=kernel_config["BLOCK_SIZE_M"],
+                token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
+                ctx=ctx,
+            )
+        )
+
+        return {
+            "a_intermediate_cache1": self.input,
+            "lora_b_stacked": self.lora_weights_lst,
+            "output": self.output,
+            "topk_weights": topk_weights,
+            "sorted_token_ids": sorted_token_ids,
+            "expert_ids": expert_ids,
+            "num_tokens_post_padded": num_tokens_post_padded,
+            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
+            "top_k_num": ctx.top_k_num,
+            "device": self.input.device,
+            "N": lora_rank,
+            "M": topk_weights.shape[0],
+            "EM": sorted_token_ids.shape[1],
+            "K": self.input.shape[1],
+            "num_tokens": num_tokens,
+            "num_experts": ctx.num_experts,
+            "num_slices": num_slices,
+            "max_lora_rank": lora_rank,
+            "w1_output_dim_size": lw_shape[2],
+            "expand_block_size_m": kernel_config["BLOCK_SIZE_M"],
+            "expand_block_size_n": kernel_config["BLOCK_SIZE_N"],
+            "expand_block_size_k": kernel_config["BLOCK_SIZE_K"],
+            "expand_group_size_m": kernel_config["GROUP_SIZE_M"],
+            "expand_num_warps": kernel_config["NUM_WARPS"],
+            "expand_num_stages": kernel_config["NUM_STAGES"],
+            "expand_split_k": kernel_config.get("SPLIT_K", 1),
+            "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
+        }
+
+    def bench_fn_kwargs(
+        self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None
+    ) -> dict[str, Any]:
+        if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
+            assert add_inputs is None
+        else:
+            assert add_inputs is not None
+
+        if op_type == OpType.LORA_SHRINK:
+            return self.as_lora_shrink_kwargs(ctx, op_type)
+        if op_type == OpType.LORA_EXPAND:
+            return self.as_lora_expand_kwargs(ctx, op_type, add_inputs)
+        if op_type.is_fused_moe_lora_shrink_fn():
+            return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type)
+        if op_type.is_fused_moe_lora_expand_fn():
+            return self.as_fused_moe_lora_expand_kwargs(ctx, op_type)
+        raise ValueError(f"Unrecognized optype {op_type}")
+
+    def test_correctness(
+        self,
+        ctx: BenchmarkContext,
+        op_type: OpType,
+        expand_fn_add_inputs: bool | None,
+    ) -> bool:
+        """
+        Test correctness of the op_type implementation against a grouped gemm
+        reference implementation.
+        """
+        seq_lens_cpu = self.seq_lens.to(device="cpu")
+        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")
+        ref_output = self.output.clone()
+
+        self.output.zero_()
+        op_type.bench_fn()(**self.bench_fn_kwargs(ctx, op_type, expand_fn_add_inputs))
+
+        op_type.run_ref_group_gemm(
+            ref_output,
+            self.input,
+            self.lora_weights_lst,
+            seq_lens_cpu=seq_lens_cpu,
+            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
+            scaling=1.0,
+            add_inputs=expand_fn_add_inputs,
+        )
+
+        rtol, atol = {
+            torch.float16: (6e-2, 6e-2),
+            torch.bfloat16: (6e-2, 6e-2),
+            torch.float32: (1e-2, 1e-2),
+        }[self.output.dtype]
+
+        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)
+
+
+def bench_optype(
+    ctx: BenchmarkContext,
+    arg_pool_size: int,
+    op_type: OpType,
+    cuda_graph_nops: int | None = None,
+    expand_fn_add_inputs: bool | None = None,
+    test_correctness: bool = False,
+) -> TMeasurement:
+    assert arg_pool_size >= 1
+    if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
+        assert expand_fn_add_inputs is None
+    else:
+        assert expand_fn_add_inputs is not None
+
+    # BenchmarkContext -> BenchmarkTensors
+    bench_tensors: list[BenchmarkTensors] = [
+        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
+    ]
+    for bt in bench_tensors:
+        bt.sanity_check(ctx, op_type)
+
+    # Test correctness of our implementation.
+    if test_correctness:
+        assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], (
+            f"Correctness testing is not supported for {op_type.name}."
+        )
+        assert all(
+            [
+                bt.test_correctness(ctx, op_type, expand_fn_add_inputs)
+                for bt in bench_tensors
+            ]
+        )
+
+    # BenchmarkTensors -> dict (kwargs)
+    kwargs_list = [
+        bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs)
+        for bt in bench_tensors
+    ]
+
+    # Clear LoRA optimization hash-maps.
+    _LORA_A_PTR_DICT.clear()
+    _LORA_B_PTR_DICT.clear()
+    _LORA_PTR_DICT.clear()
+    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
+    for kwargs in kwargs_list:
+        op_type.bench_fn()(**kwargs)
+    torch.cuda.synchronize()
+
+    # Merge into a single kwargs and qualify arguments as ArgPool
+    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
+    for _kwargs in kwargs_list:
+        for k, v in _kwargs.items():
+            kwargs[k].values.append(v)
+
+    describe_args = (
+        f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
+    )
+    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"
+
+    cuda_graph_params = None
+    if cuda_graph_nops:
+        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
+    timer = None
+    with Bench(
+        cuda_graph_params,
+        ctx.bench_label(),
+        ctx.bench_sublabel(op_type),
+        description,
+        op_type.bench_fn(),
+        **kwargs,
+    ) as bench:
+        timer = bench.run()
+    return timer
+
+
+def bench_torch_mm(
+    ctx: BenchmarkContext,
+    arg_pool_size: int,
+    op_type: OpType,
+    cuda_graph_nops: int | None = None,
+) -> TMeasurement:
+    """
+    Benchmark basic torch.mm as a roofline.
+
+    When all the input tokens have the same LoRA ID, the LoRA kernels are just
+    a matmul. This torch.mm benchmark serves as a roofline for that case.
+
+    The input op_type is used to determine the m, k, n dimensions for the matmul.
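+
+    The n dimension is scaled by ctx.num_slices so the roofline covers the same
+    amount of work as the sliced LoRA kernels.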
+    """
+
+    batch_size, hidden_size, lora_rank, seq_length, dtype = (
+        ctx.batch_size,
+        ctx.hidden_size,
+        ctx.lora_rank,
+        ctx.seq_length,
+        ctx.dtype,
+    )
+
+    m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank)
+    # For a fairer comparison.
+    n = n * ctx.num_slices
+
+    # Get matmul input and output tensors for A x B = C
+    As, Bs, Cs = [], [], []
+    for _ in range(arg_pool_size):
+        As.append(torch.rand((m, k), dtype=dtype).to("cuda"))
+        Bs.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
+        Cs.append(torch.rand((m, n), dtype=dtype).to("cuda"))
+
+    # Make torch.mm kwargs
+    mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}
+
+    description = (
+        f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}"
+        f"x{dtype_to_str(dtype)}"
+        f"=>{dtype_to_str(dtype)})"
+    )
+    cuda_graph_params = None
+    if cuda_graph_nops:
+        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
+    with Bench(
+        cuda_graph_params,
+        ctx.bench_label(),
+        ctx.bench_sublabel(op_type),
+        description,
+        torch.mm,
+        **mm_kwargs,
+    ) as bench:
+        return bench.run()
+
+
+# runner
+def use_cuda_graph_recommendation() -> str:
+    return """
+            Triton kernels have a significant launch overhead when
+            launched directly via Python. This overhead is more noticeable
+            for small problem sizes. For these cases, it is recommended
+            to use the script with `--cuda-graph-nops N` to benchmark N
+            consecutive invocations of the benchmarking operations from
+            inside a CUDA Graph. Note that the returned measurement is for N
+            invocations of the operation.
+            """
+
+
+def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None):
+    compare = TBenchmark.Compare(timers)
+    compare.print()
+
+    if args and args.cuda_graph_nops:
+        print(
+            f"Note: The timings reported above are for {args.cuda_graph_nops} "
+            "consecutive invocations of the benchmarking functions. "
+            f"Please divide by {args.cuda_graph_nops} for single invocation "
+            "timings."
+        )
+
+    print(
+        "Note on comparison with torch.mm: The torch.mm numbers are "
+        "benchmark numbers of a simple matmul emulating the single lora "
+        "case. It is provided as a roofline for comparing our LoRA kernel "
+        "implementations. It is expected that the LoRA kernels will be "
+        "slower than torch.mm in cases where num_loras is large. But for "
+        "small num_loras the goal should be to match the torch.mm numbers."
+ ) + + +def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): + if args.cuda_graph_nops is not None: + assert args.cuda_graph_nops > 0 + print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph") + else: + print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") + + timers = [] + for bench_ctx in bench_ctxs: + for seq_len in args.seq_lengths: + bench_ops: list[OpType] = args.op_types + seq_len_timers = [] + for bench_op in bench_ops: + for num_slices in bench_op.num_slices(): + _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( + num_slices + ) + # Benchmark torch.mm as a roofline + seq_len_timers.append( + bench_torch_mm( + _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops + ) + ) + + # Benchmark bench_op + expand_fn_add_inputs = ( + [None] + if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn() + else args.expand_fn_add_inputs + ) + for add_input_arg in expand_fn_add_inputs: + seq_len_timers.append( + bench_optype( + _ctx, + args.arg_pool_size, + bench_op, + args.cuda_graph_nops, + add_input_arg, + args.test_correctness, + ) + ) + + print_timers(seq_len_timers) + timers.extend(seq_len_timers) + + # Result stdout dump + print("== All Results ====") + print_timers(timers, args) + + if args.output_directory: + # Result file dump + od = Path(args.output_directory) + if not od.exists(): + od.mkdir() + + timestamp = int(time.time()) + pkl_file = od / f"lora_bench-{timestamp}.pkl" + print(f"Writing benchmarks to {pkl_file}") + with open(pkl_file, "wb") as f: + pickle.dump(timers, f) + + +def as_benchmark_contexts( + hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace +) -> list[BenchmarkContext]: + ctxs: list[BenchmarkContext] = [] + for ( + batch_size, + hidden_size, + lora_rank, + num_loras, + sort_by_lora_id, + top_k_num, + num_experts, + ) in product( # noqa + args.batch_sizes, + list(hidden_sizes), + lora_ranks, + args.num_loras, + args.sort_by_lora_id, + args.top_k_nums, + args.num_experts, + ): + ctxs.append( + BenchmarkContext( + batch_size=batch_size, + hidden_size=hidden_size, + lora_rank=lora_rank, + num_loras=num_loras, + num_active_loras=args.num_active_loras + if args.num_active_loras + else num_loras, + # To be filled based on the OpType to benchmark + seq_length=None, + sort_by_lora_id=sort_by_lora_id, + dtype=args.dtype, + top_k_num=top_k_num, + num_experts=num_experts, + # To be filled based on the OpType to benchmark + num_slices=None, + ) + ) + + return ctxs + + +def run_list_bench(args: argparse.Namespace): + print(args) + + print( + "List bench :\n" + f" Hidden Sizes {args.hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}" + ) + + # Get all benchmarking contexts + bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) + + run(args, bench_contexts) + + +def run_range_bench(args: argparse.Namespace): + print(args) + + hidden_sizes = list( + range( + args.hidden_sizes_start, + args.hidden_sizes_end + 1, + args.hidden_sizes_increment, + ) + ) + lora_ranks = list( + range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment) + ) + + print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args + ) + + run(args, bench_contexts) + + +def run_model_bench(args: argparse.Namespace): + print(args) + + def 
hidden_sizes_from_model(model: str, tp_size: int) -> set[int]: + hidden_sizes = set() + for KN, tp_split_dim in WEIGHT_SHAPES[model]: + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + hidden_sizes.add(KN[1]) + return hidden_sizes + + # Get all hidden sizes + hidden_sizes: set[int] = set() + for model_name, tp_size in product(args.models, args.tp_sizes): + hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size)) + + print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}") + + # Get all benchmarking contexts + bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( + hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) + + run(args, bench_contexts) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "torch.float16": + return torch.float16 + if dt == "torch.bfloat16": + return torch.bfloat16 + raise ValueError("unsupported dtype") + + def get_bool(s: str) -> bool: + return s.lower() in ["true", "1"] + + def add_common_command_args(p: argparse.ArgumentParser): + p.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['torch.float16', 'torch.bfloat16']", + ) + + p.add_argument( + "--arg-pool-size", + type=int, + default=32, + help="Run profiles with a pool of input/output/meta tensors instead" + "of simply reusing the same tensors for all runs. A bigger arg-pool" + "mitigates hardware caching effects during benchmarking.", + ) + + p.add_argument( + "--cuda-graph-nops", + type=int, + help=( + "when set profiling is done using cudagraph, " + "with the given number of operations in a graph." + "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument." + ), + ) + p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS) + p.add_argument( + "--num-active-loras", + type=int, + default=None, + help="Active LoRAs. 
When None, all LoRAs are active", + ) + p.add_argument( + "--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS, + ) + p.add_argument( + "--op-types", nargs="+", type=OpType.from_str, default=list(OpType) + ) + p.add_argument( + "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS + ) + p.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + p.add_argument( + "--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS, + ) + p.add_argument( + "-o", + "--output-directory", + type=str, + help=( + "Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file" + ), + ) + + p.add_argument( + "--test-correctness", + action="store_true", + help=( + "When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking" + ), + ) + + p.add_argument( + "--top-k-nums", + nargs="+", + type=int, + default=DEFAULT_TOP_K_NUMS, + help="Top-K values for MoE LoRA operations", + ) + + p.add_argument( + "--num-experts", + nargs="+", + type=int, + default=DEFAULT_NUM_EXPERTS, + help="Number of experts for MoE LoRA operations", + ) + + parser = FlexibleArgumentParser( + description=f""" +Benchmark LoRA kernels: + {use_cuda_graph_recommendation()} + + list_bench example: + python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + + model_bench example: + python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 + + range_bench example: + python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + subparsers = parser.add_subparsers(dest="cmd", required=True) + + list_parser = subparsers.add_parser("list_bench") + list_parser.add_argument( + "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES + ) + list_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) + add_common_command_args(list_parser) + list_parser.set_defaults(func=run_list_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--hidden-sizes-start", type=int, required=True) + range_parser.add_argument("--hidden-sizes-end", type=int, required=True) + range_parser.add_argument("--hidden-sizes-increment", type=int, required=True) + range_parser.add_argument("--lora-ranks-start", type=int, required=True) + range_parser.add_argument("--lora-ranks-end", type=int, required=True) + range_parser.add_argument("--lora-ranks-increment", type=int, required=True) + add_common_command_args(range_parser) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + 
default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) + add_common_command_args(model_parser) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6f09866555c171f9f2eb77beb86ac9edd45357 --- /dev/null +++ b/benchmarks/kernels/benchmark_machete.py @@ -0,0 +1,745 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import itertools +import math +import os +import pickle as pkl +import time +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from itertools import product + +import pandas as pd +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + marlin_permute_scales, + marlin_zero_points, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, + quantize_weights, +) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils.argparse_utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] +DEFAULT_TP_SIZES = [1] + +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx + + +def terse_type_name(dt): + return { + torch.bfloat16: "bf16", + torch.float16: "fp16", + torch.int8: "int8", + torch.float8_e4m3fn: "fp8", + torch.float: "float", + torch.int: "int", + }[dt] + + +@dataclass +class BenchmarkTensors: + w_ref: torch.Tensor + a: torch.Tensor + + w_q: torch.Tensor + group_size: int | None + wtype: ScalarType + w_g_s: torch.Tensor + w_g_zp: torch.Tensor | None + w_ch_s: torch.Tensor | None + w_tok_s: torch.Tensor | None + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: torch.dtype | None + group_scale_type: torch.dtype | None + group_zero_type: torch.dtype | None + channel_scale_type: torch.dtype | None + token_scale_type: torch.dtype | None + + +def rand_data(shape, dtype=torch.float16, scale=1): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) + else: + return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") + + +def quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: torch.dtype | None, + group_size: int | None, + zero_points: bool = False, +): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True, + ) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + return w_ref, w_q, w_s, w_zp + + +def create_bench_tensors( + shape: tuple[int, int, int], types: 
TypeConfig, group_size: int | None +) -> list[BenchmarkTensors]: + m, n, k = shape + + # we want to make sure that weights don't fit into L2 cache between runs so + # we construct enough weights to exceed L2 cache, which is 50mb on a H100 + # so we target total weight size > 2*50mb + num_weights = math.ceil( + 2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits) + ) + + a = rand_data((m, k), types.act_type, scale=5) + + benchmark_tensors: list[BenchmarkTensors] = [] + for _ in range(num_weights): + w = rand_data((k, n), types.act_type, scale=5) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + w_ref = w_ref.to(torch.float32) + + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) + + benchmark_tensors.append( + BenchmarkTensors( + w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) + ) + + return benchmark_tensors + + +def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable: + a = bt.a + w = bt.w_ref.to(bt.a.dtype) # use float reference tensor + if a.dtype not in [torch.float16, torch.bfloat16]: + a = a.to(torch.float16) + w = w.to(torch.float16) + return lambda: torch.matmul(a, w) + + +def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: + if bt.w_ch_s is not None and bt.w_tok_s is not None: + scale_a = bt.w_tok_s.to(torch.float32) + scale_b = bt.w_ch_s.to(torch.float32) + else: + scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() + return lambda: ops.cutlass_scaled_mm( + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16 + ) + + +def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: + device = bt.a.device + + workspace = MarlinWorkspace( + bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) + + if bt.w_g_zp is None: + w_zp = torch.empty(0, dtype=torch.int, device=device) + else: + w_zp = marlin_zero_points( + bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) + + if bt.group_size is None: + w_s = torch.tensor([], device="cuda", dtype=torch.half) + else: + w_s = marlin_permute_scales( + bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size + ) + + sort_indices = torch.empty(0, dtype=torch.int, device=device) + g_idx = torch.empty(0, dtype=torch.int, device=device) + w_q = ops.gptq_marlin_repack( + bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) + + if bt.a.dtype.is_floating_point: + assert bt.w_ch_s is None + assert bt.w_tok_s is None + assert bt.group_size is not None + + fn = lambda: ops.marlin_gemm( + a=bt.a, + c=None, + b_q_weight=w_q, + b_bias=None, + b_scales=w_s, + a_scales=None, + global_scale=None, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + 
size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True, + is_zp_float=False, + ) + else: + assert bt.a.dtype == torch.int8 + assert bt.wtype == scalar_types.uint4b8 + raise NotImplementedError("QQQ is not supported anymore") + + return fn + + +def machete_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.machete_prepack_B( + w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype + ) + + w_g_zp = bt.w_g_zp + if w_g_zp is not None: + w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype)) + + return lambda: ops.machete_mm( + a=bt.a, + b_q=w_q, + b_type=bt.wtype, + b_group_scales=bt.w_g_s, + b_group_zeros=w_g_zp, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + out_type=out_type, + schedule=schedule, + ) + + +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + +# impl + +# bench + + +def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]): + min_run_time = 1 if not NVTX_PROFILE else 0.1 + res = TBenchmark.Timer( + stmt=""" + for fn in fns: + fn() + """, + globals={"fns": fns}, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + if NVTX_PROFILE: + with ( + nvtx.annotate("mm-bench"), + nvtx.annotate(f"{label}|{sub_label}|{description}"), + ): + fns[0]() + + return res + + +_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None +_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None + + +def bench( + types: TypeConfig, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + sweep_schedules: bool = True, +) -> list[TMeasurement]: + benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) + sub_label += f", L={len(benchmark_tensors)}" + + name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}" + if types.group_scale_type is not None: + name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" + if types.group_zero_type is not None: + name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}" + if group_size is not None: + name_type_string += f"-G{group_size}" + if types.channel_scale_type is not None: + name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}" + if types.token_scale_type is not None: + name_type_string += f"-TS{terse_type_name(types.token_scale_type)}" + + timers = [] + # pytorch impl + timers.append( + bench_fns( + label, + sub_label, + "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) + + if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", + [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) + + if types.act_type != torch.float8_e4m3fn: + timers.append( + bench_fns( + label, + sub_label, + f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) for bt in 
benchmark_tensors], + ) + ) + + # machete + timers.append( + bench_fns( + label, + sub_label, + f"machete ({name_type_string})", + [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 ({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + + if sweep_schedules: + global _SWEEP_SCHEDULES_RESULTS + + print("Finding best schedule for machete") + best = None + best_schedule = None + schedules = ops.machete_supported_schedules( + a_type=types.act_type, + b_type=types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_zero_type, + token_scales_type=types.token_scale_type, + channel_scales_type=types.channel_scale_type, + out_type=types.output_type, + ) + + if schedules is None or len(schedules) == 0: + raise ValueError("No schedules found to sweep") + + for schedule in reversed(schedules): + schedule_M = int(schedule.split("_")[0].split("x")[1]) + + # Prune known bad schedules + if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: + continue + + res = bench_fns( + label, + sub_label, + "machete_best", + [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule + ) + for bt in benchmark_tensors + ], + ) + + results_row = { + "M": m, + "K": k, + "N": n, + "group_size": group_size, + "schedule": schedule, + "median": res.median, + } + if _SWEEP_SCHEDULES_RESULTS is None: + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + + print(f" {res.median:5.5} ", schedule) + if not best or res.median < best.median: + best = res + best_schedule = schedule + print("Best schedule:", best_schedule) + timers.append(best) + + return timers + + +# runner +def print_timers(timers: list[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: + types = TypeConfig( + act_type=args.act_type, + weight_type=scalar_types.uint4b8 + if args.group_zero_type is None + else scalar_types.uint4, + output_type=args.out_type, + group_scale_type=args.group_scale_type, + group_zero_type=args.group_zero_type, + channel_scale_type=args.channel_scale_type, + token_scale_type=args.token_scale_type, + ) + + results: list[TMeasurement] = [] + for m, k, n in MKNs: + timers = bench( + types, + args.group_size, + m, + k, + n, + f"{args.act_type}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=args.sweep_schedules, + ) + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output( + data: list[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + 
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(",")) + m_end, k_end, n_end = (int(x) for x in args.dim_end.split(",")) + m_increment, k_increment, n_increment = ( + int(x) for x in args.dim_increment.split(",") + ) + Ms = list(range(m_start, m_end + 1, m_increment)) + Ks = list(range(k_start, k_end + 1, k_increment)) + Ns = list(range(n_start, n_end + 1, n_increment)) + MKNs = list(product(Ms, Ks, Ns)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args, MKNs) + model_bench_data.append(data) + + type_string = f"{args.act_type}" + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {type_string} {model}-TP{tp_size} ====") + print_timers(data) + + timestr = time.strftime("%Y%m%d-%H%M%S") + + all_results = [] + for d in model_bench_data: + all_results.extend(d) + + # pickle all data + with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: + args_dict = vars(args) + args_dict.pop("func") + pkl.dump( + { + "args": args_dict, + "results": all_results, + }, + f, + ) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "int8": torch.int8, + "float8_e4m3fn": torch.float8_e4m3fn, + "int": torch.int, + "float": torch.float, + }[dt] + + class ToTorchDtype(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, to_torch_dtype(values)) + + parser = FlexibleArgumentParser( + description=""" +Benchmark Machete GEMM. + + To run square GEMMs: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
+ """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--act-type", + action=ToTorchDtype, + required=True, + choices=["bfloat16", "float16", "int8", "float8_e4m3fn"], + ) + parser.add_argument( + "--group-scale-type", + action=ToTorchDtype, + choices=["bfloat16", "float16"], + ) + parser.add_argument( + "--group-zero-type", + type=to_torch_dtype, + choices=["bfloat16", "float16"], + ) + parser.add_argument( + "--channel-scale-type", + action=ToTorchDtype, + choices=["float"], + ) + parser.add_argument( + "--token-scale-type", + action=ToTorchDtype, + choices=["float"], + ) + parser.add_argument( + "--out-type", + action=ToTorchDtype, + choices=["bfloat16", "float16"], + ) + parser.add_argument( + "--group-size", + type=int, + help="Available options are ['None', '-1', '128'], default=128", + default=128, + ) + parser.add_argument( + "--sweep-schedules", + action="store_true", + help="Run a sweep over all supported schedules", + ) + parser.add_argument( + "--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv", + ) + subparsers = parser.add_subparsers(dest="cmd", required=True) + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument( + "--dim-start", + type=str, + required=True, + help="Start value for M,K,N as common separated list", + ) + range_parser.add_argument( + "--dim-end", + type=str, + required=True, + help="End value (inclusive) for M,K,N as common separated list", + ) + range_parser.add_argument( + "--dim-increment", + type=str, + required=True, + help="Increment value for M,K,N as common separated list", + ) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + + _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out + args.func(args) + + if _SWEEP_SCHEDULES_RESULTS is not None: + _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..c0019a51cdd0e452c33b6eb2dc36abd9880a1508 --- /dev/null +++ b/benchmarks/kernels/benchmark_marlin.py @@ -0,0 +1,365 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_SUPPORTED_QUANT_TYPES, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + MARLIN_SUPPORTED_GROUP_SIZES, + query_marlin_supported_quant_types, +) 
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + FP4_MARLIN_SUPPORTED_GROUP_SIZES, + rand_marlin_weight_fp4_like, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + marlin_quant_fp8_torch, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace, + awq_marlin_quantize, + marlin_quantize, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils.argparse_utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + + +def bench_run( + results: list[benchmark.Measurement], + model: str, + act_order: bool, + is_k_full: bool, + quant_type: ScalarType, + group_size: int, + size_m: int, + size_k: int, + size_n: int, +): + label = "Quant Matmul" + sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( + model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n + ) + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, size_k).to(torch.half).cuda() + b = torch.rand(size_k, size_n).to(torch.half).cuda() + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + if act_order and (group_size == -1 or group_size == size_k or has_zp): + return + if size_k % group_size != 0: + return + + repack_supported = group_size in MARLIN_SUPPORTED_GROUP_SIZES + allspark_supported = ( + quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 + and not act_order + and is_k_full + ) + + def gen_marlin_params(): + # Marlin quant + marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None + if quant_type == scalar_types.float4_e2m1f: + if group_size != 16 or act_order: + return + marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( + b.T, group_size + ) + elif quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128] or act_order: + return + marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size) + elif group_size == 16: + return + elif has_zp: + marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b, quant_type, group_size + ) + else: + marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = ( + marlin_quantize(b, quant_type, group_size, act_order) + ) + return ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_s2, + marlin_zp, + marlin_g_idx, + marlin_sort_indices, + ) + + def gen_repack_params(): + q_w_gptq = None + repack_sort_indices = None + if repack_supported: + (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( + b, quant_type, group_size, act_order + ) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + return q_w_gptq, repack_sort_indices + + def gen_allspark_params(): + qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = ( + CUBLAS_M_THRESHOLD + ) = None + nonlocal allspark_supported + if allspark_supported: + properties = torch.cuda.get_device_properties(b.device.index) + sm_count = properties.multi_processor_count + sm_version = 
properties.major * 10 + properties.minor + + supported_arch = sm_version >= 80 and sm_version < 90 + allspark_supported = allspark_supported and supported_arch + if supported_arch: + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) + qw = qw.to(torch.uint8) + + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp + ) + CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD + return ( + qw_reorder, + s_reorder, + zp_reorder, + sm_count, + sm_version, + CUBLAS_M_THRESHOLD, + ) + + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_s2, + marlin_zp, + marlin_g_idx, + marlin_sort_indices, + ) = gen_marlin_params() + q_w_gptq, repack_sort_indices = gen_repack_params() + qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = ( + gen_allspark_params() + ) + + # Prepare + marlin_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) + + globals = { + # Gen params + "quant_type": quant_type, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "a": a, + # Marlin params + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_s2": marlin_s2, + "marlin_zp": marlin_zp, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_workspace": marlin_workspace, + "is_k_full": is_k_full, + # GPTQ params + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + # AllSpark W8A16 params + "qw_reorder": qw_reorder, + "s_reorder": s_reorder, + "zp_reorder": zp_reorder, + "sm_count": sm_count, + "sm_version": sm_version, + "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, + # Kernels + "marlin_gemm": ops.marlin_gemm, + "gptq_marlin_repack": ops.gptq_marlin_repack, + "allspark_w8a16_gemm": ops.allspark_w8a16_gemm, + } + + min_run_time = 1 + + # Warmup pytorch + for _ in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange(min_run_time=min_run_time) + ) + + results.append( + benchmark.Timer( + stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="marlin_gemm", + ).blocked_autorange(min_run_time=min_run_time) + ) + + results.append( + benchmark.Timer( + stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="marlin_gemm_fp32", + ).blocked_autorange(min_run_time=min_run_time) + ) + + if repack_supported: + results.append( + benchmark.Timer( + stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time) + ) + + if allspark_supported: + results.append( + benchmark.Timer( + stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 + 
globals=globals, + label=label, + sub_label=sub_label, + description="allspark_w8a16_gemm_fp32", + ).blocked_autorange(min_run_time=min_run_time) + ) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + results: list[benchmark.Measurement] = [] + + for model in args.models: + for layer in WEIGHT_SHAPES[model]: + size_k = layer[0] + size_n = layer[1] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for act_order in ACT_ORDER_OPTS: + if ( + len(args.limit_act_order) > 0 + and act_order not in args.limit_act_order + ): + continue + + for is_k_full in K_FULL_OPTS: + if ( + len(args.limit_k_full) > 0 + and is_k_full not in args.limit_k_full + ): + continue + + for quant_type in query_marlin_supported_quant_types(): + if ( + len(args.limit_num_bits) > 0 + and quant_type.size_bits not in args.limit_num_bits + ): + continue + + for group_size in ( + MARLIN_SUPPORTED_GROUP_SIZES + + FP4_MARLIN_SUPPORTED_GROUP_SIZES + ): + if ( + len(args.limit_group_size) > 0 + and group_size not in args.limit_group_size + ): + continue + + # For act_order, the group_size must be less than + # size_k + if act_order and (group_size == size_k or group_size == -1): + continue + + for size_m in args.batch_sizes: + bench_run( + results, + model, + act_order, + is_k_full, + quant_type, + group_size, + size_m, + size_k, + size_n, + ) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 +# +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) + parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) + parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_mla_k_concat.py b/benchmarks/kernels/benchmark_mla_k_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3b6c8f12003e0049dddd4d057c6c31a4aa5dfb --- /dev/null +++ b/benchmarks/kernels/benchmark_mla_k_concat.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation +in MLA (Multi-head Latent Attention) prefill. + +This validates that the optimization from commit 8d4142bd is beneficial across +various batch sizes, not just the originally tested batch size of 32768. 
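+
+Run it directly (the script takes no command-line arguments):
+
+    python benchmarks/kernels/benchmark_mla_k_concat.py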
+""" + +import time +from collections.abc import Callable + +import torch + +# DeepSeek-V3 MLA dimensions +NUM_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +PE_DIM = 64 + + +def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor: + """Original torch.cat approach with expand.""" + return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + +def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor: + """Optimized direct copy approach (avoids expand + cat overhead).""" + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + + +def benchmark_method( + method: Callable, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + num_warmup: int = 10, + num_iters: int = 100, +) -> float: + """Benchmark a concatenation method and return mean latency in ms.""" + # Warmup + for _ in range(num_warmup): + _ = method(k_nope, k_pe) + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(num_iters): + _ = method(k_nope, k_pe) + torch.cuda.synchronize() + end = time.perf_counter() + + return (end - start) / num_iters * 1000 # Convert to ms + + +@torch.inference_mode() +def run_benchmark(dtype: torch.dtype, dtype_name: str): + """Run benchmark for a specific dtype.""" + torch.set_default_device("cuda") + + # Batch sizes to test (powers of 2 from 32 to 65536) + batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536] + + print("=" * 80) + print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation") + print("=" * 80) + print( + f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], " + f"k_pe=[B, 1, {PE_DIM}]" + ) + print(f"dtype: {dtype_name}") + print() + print( + f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | " + f"{'Speedup':>8} | {'Reduction':>10}" + ) + print("-" * 70) + + results = [] + for batch_size in batch_sizes: + # Create input tensors (generate in float32 then convert for FP8 compatibility) + k_nope = torch.randn( + batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda" + ).to(dtype) + k_pe = torch.randn( + batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda" + ).to(dtype) + + # Benchmark both methods + cat_time = benchmark_method(cat_method, k_nope, k_pe) + direct_time = benchmark_method(direct_copy_method, k_nope, k_pe) + + speedup = cat_time / direct_time + reduction = (1 - direct_time / cat_time) * 100 + + results.append((batch_size, cat_time, direct_time, speedup, reduction)) + + print( + f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | " + f"{speedup:>7.2f}x | {reduction:>9.1f}%" + ) + + print("=" * 80) + + # Summary statistics + speedups = [r[3] for r in results] + print("\nSpeedup summary:") + print(f" Min: {min(speedups):.2f}x") + print(f" Max: {max(speedups):.2f}x") + print(f" Mean: {sum(speedups) / len(speedups):.2f}x") + + # Find crossover point + crossover_batch = None + for batch_size, _, _, speedup, _ in results: + if speedup >= 1.0: + crossover_batch = batch_size + break + + print("\nConclusion:") + if crossover_batch: + print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}") + # Filter for large batches (>= 512 which is typical for prefill) + large_batch_speedups = [r[3] for r in results if r[0] >= 512] + if large_batch_speedups: + avg_large = sum(large_batch_speedups) / len(large_batch_speedups) + print(f" - For batch 
sizes >= 512: avg speedup = {avg_large:.2f}x") + print(" - MLA prefill typically uses large batches, so optimization is effective") + + return results + + +@torch.inference_mode() +def main(): + # Test bfloat16 + print("\n") + run_benchmark(torch.bfloat16, "bfloat16") + + # Test float8_e4m3fn + print("\n") + run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..4abeaefd774a11b12c163dc88e53c98d7c28b632 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe.py @@ -0,0 +1,1041 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import gc +import json +import os +import time +from contextlib import nullcontext +from datetime import datetime +from itertools import product +from typing import Any, TypedDict + +import ray +import torch +from ray.experimental.tqdm_ray import tqdm + +from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, + _get_config_dtype_str, +) +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, +) +from vllm.transformers_utils.config import get_config +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import set_random_seed + +FP8_DTYPE = current_platform.fp8_dtype() + +# Default interval for clearing Triton JIT cache during tuning +# Set to 0 to disable automatic cache clearing +_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL" +TRITON_CACHE_CLEAR_INTERVAL = int(os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50")) + + +def clear_triton_cache(): + """Clear Triton JIT compilation cache and Python/CUDA memory. + + This helps prevent OOM during tuning with large models (many experts). 
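+
+    The clearing interval is controlled by the VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL
+    environment variable (default: every 50 tuned configs; set it to 0 to disable
+    periodic clearing).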
+ """ + # Force Python garbage collection + gc.collect() + + # Clear CUDA memory cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Try to clear Triton's runtime cache + try: + if ( + hasattr(triton, "runtime") + and hasattr(triton.runtime, "cache") + and hasattr(triton.runtime.cache, "clear") + ): + triton.runtime.cache.clear() + except ImportError: + # Triton not installed, skip cache clearing + pass + except AttributeError: + # Triton version doesn't have expected cache API + pass + except Exception as e: + print(f"Warning: Failed to clear Triton cache: {e}") + + # Additional garbage collection after clearing caches + gc.collect() + + +def ensure_divisibility(numerator, denominator, text): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format( + text, numerator, denominator + ) + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool = False, + num_iters: int = 100, + block_quant_shape: list[int] = None, + use_deep_gemm: bool = False, +) -> float: + init_dtype = torch.float16 if use_fp8_w8a8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + if use_int4_w4a16: + # Int4 packed weights: 2 int4 values per uint8 byte + # K dimension is packed (halved) + intermediate_size = shard_intermediate_size // 2 # after silu_and_mul + w1 = torch.randint( + 0, + 255, + ( + num_experts, + shard_intermediate_size, + hidden_size // 2, # int4 packing + ), + dtype=torch.uint8, + ) + w2 = torch.randint( + 0, + 255, + ( + num_experts, + hidden_size, + intermediate_size // 2, # int4 packing + ), + dtype=torch.uint8, + ) + elif use_int8_w8a16: + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) + else: + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_int4_w4a16: + if block_quant_shape is None: + raise ValueError("block_quant_shape is required for int4_w4a16") + group_size = block_quant_shape[1] + # Scales shape: (E, N, K // group_size) in fp16 + w1_scale = torch.rand( + (num_experts, shard_intermediate_size, hidden_size // group_size), + dtype=dtype, + ) + w2_scale = torch.rand( + (num_experts, hidden_size, intermediate_size // group_size), + dtype=dtype, + ) + elif use_int8_w8a16: + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_deep_gemm: + # we use the default block shape for deepgemm + block_quant_shape = [128, 128] + if use_fp8_w8a8: + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + E = num_experts + N = shard_intermediate_size // 2 + K = hidden_size + 
factor_for_scale = 1e-2 + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + w1_scale = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_scale = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + + w1 = w1.to(FP8_DTYPE) + w2 = w2.to(FP8_DTYPE) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + from vllm.model_executor.layers.fused_moe import override_config + + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = FusedMoEQuantConfig.make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + weight_dtype="int4" if use_int4_w4a16 else None, + ) + + deep_gemm_experts = None + if use_deep_gemm: + moe_config = ( + FusedMoEConfig( + num_experts=num_experts, + experts_per_token=topk, + hidden_dim=hidden_size, + intermediate_size_per_partition=shard_intermediate_size, + num_local_experts=num_experts, + num_logical_experts=num_experts, + activation=MoEActivation.SILU, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=init_dtype, + routing_method=RoutingMethodType.TopK, + device="cuda", + ), + ) + deep_gemm_experts = mk.FusedMoEKernel( + prepare_finalize=maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + fused_experts=TritonOrDeepGemmExperts( + moe_config=moe_config, + quant_config=quant_config, + ), + inplace=not disable_inplace(), + ) + + with override_config(config): + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + + inplace = not disable_inplace() + if use_deep_gemm: + return deep_gemm_experts.apply( + x, + w1, + w2, + topk_weights, + topk_ids, + activation=MoEActivation.SILU, + global_num_experts=num_experts, + apply_router_weight_on_input=False, + expert_map=False, + ) + return fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + quant_config=quant_config, + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def get_rocm_tuning_space(use_fp16): + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + if not use_fp16: + block_k_range.remove(16) # 
BLOCK_K=16 not supported for fp8 + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + num_stage_range = [2] + waves_per_eu_range = [0, 1, 2, 4] + matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] + kpack_range = [1, 2] if use_fp16 else [] + + param_ranges = { + "BLOCK_SIZE_M": block_mn_range, + "BLOCK_SIZE_N": block_mn_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + "waves_per_eu": waves_per_eu_range, + } + if use_fp16: + param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range + param_ranges["kpack"] = kpack_range + + return param_ranges + + +def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]: + configs: list[BenchmarkConfig] = [] + + if current_platform.is_rocm(): + param_ranges = get_rocm_tuning_space(use_fp16) + else: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + block_m_range = [16, 32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [64, 128, 256] + num_warps_range = [4, 8] + group_m_range = [1, 16, 32, 64] + num_stage_range = [2, 3, 4, 5] + + param_ranges = { + "BLOCK_SIZE_M": block_m_range, + "BLOCK_SIZE_N": block_n_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + # Remove configs that are not compatible with fp8 block quantization + # BLOCK_SIZE_K must be a multiple of block_k + # BLOCK_SIZE_N must be a multiple of block_n + if block_quant_shape is not None and not use_fp16: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + for config in configs[:]: + if ( + config["BLOCK_SIZE_K"] % block_k != 0 + or config["BLOCK_SIZE_N"] % block_n != 0 + ): + configs.remove(config) + return configs + + +def prune_rocm_search_space( + num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk +): + N1, K1 = shard_intermediate_size, hidden_size + N2, K2 = hidden_size, shard_intermediate_size // 2 + pruned_space_1 = prune_rocm_configs( + num_tokens * topk, N1, K1, search_space, is_fp16 + ) + pruned_space_2 = prune_rocm_configs( + num_tokens * topk, N2, K2, search_space, is_fp16 + ) + search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) + return search_space + + +# The following code is inspired by ROCm/Triton GEMM tuning script: +# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89 +def prune_rocm_configs(M, N, K, configs, is_fp16=True): + pruned_configs = [] + elemBytes_a = 2 if is_fp16 else 1 + elemBytes_b = 2 if is_fp16 else 1 + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + + if is_fp16: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 
64: + continue + SPLIT_K = config.get("SPLIT_K", 1) + GROUP_M = config.get("GROUP_SIZE_M") + if is_fp16: + if ( + matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N + ): + continue + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: + continue + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = ( + BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + ) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def merge_unique_dicts(list1, list2): + result = [] + combined_list = list1.copy() + combined_list.extend(list2) + for dictionary in combined_list: + if dictionary not in result: + result.append(dictionary) + return result + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + set_random_seed(seed) + self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. + self.device_id = int(ray.get_gpu_ids()[0]) + + def benchmark( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool = False, + block_quant_shape: list[int] = None, + use_deep_gemm: bool = False, + ) -> tuple[dict[str, int], float]: + # local import to allow serialization by ray + + set_random_seed(self.seed) + dtype_str = _get_config_dtype_str( + dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8, + use_int4_w4a16=use_int4_w4a16, + ) + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
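+        # For example, with the default Mixtral-8x7B model
+        # (intermediate_size = 14336) at the default --tp-size 2:
+        # shard_intermediate_size = 2 * 14336 // 2 = 14336, so the config
+        # file is looked up with N = 14336 // 2 = 7168.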
+ block_n = block_quant_shape[0] if block_quant_shape else None + block_k = block_quant_shape[1] if block_quant_shape else None + op_config = get_moe_configs( + num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k + ) + if op_config is None: + config = get_default_config( + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + block_quant_shape, + ) + else: + config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + num_iters=100, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + ) + return config, kernel_time + + def tune( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + search_space: list[dict[str, int]], + block_quant_shape: list[int], + use_deep_gemm: bool, + ) -> dict[str, int]: + # local import to allow serialization by ray + from vllm.platforms import current_platform + + best_config = None + best_time = float("inf") + if current_platform.is_rocm(): + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16) + search_space = prune_rocm_search_space( + num_tokens, + shard_intermediate_size, + hidden_size, + search_space, + is_fp16, + topk, + ) + + need_device_guard = False + if current_platform.is_rocm(): + visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None) + if visible_device != f"{self.device_id}": + need_device_guard = True + + with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): + for idx, config in enumerate(tqdm(search_space)): + try: + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + num_iters=20, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. 
+ continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + + # Periodically clear Triton JIT cache to prevent OOM + # This is especially important for large models with many experts + if ( + TRITON_CACHE_CLEAR_INTERVAL > 0 + and idx > 0 + and idx % TRITON_CACHE_CLEAR_INTERVAL == 0 + ): + clear_triton_cache() + + # Final cleanup after tuning completes + clear_triton_cache() + + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + assert best_config is not None + return best_config + + +def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), + **( + {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]} + if "matrix_instr_nonkdim" in config + else {} + ), + **({"kpack": config["kpack"]} if "kpack" in config else {}), + **({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}), + } + + +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_quant_shape: list[int], + save_dir: str, +) -> None: + dtype_str = _get_config_dtype_str( + dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8, + use_int4_w4a16=use_int4_w4a16, + ) + + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
+ filename = get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) + os.makedirs(save_dir, exist_ok=True) + filename = os.path.join(save_dir, filename) + print(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump({"triton_version": triton.__version__, **configs}, f, indent=4) + f.write("\n") + + +def get_compressed_tensors_block_structure(config, default_value=None): + config_groups = config.get("config_groups", {}) + if len(config_groups) != 1: + return default_value + group = next(iter(config_groups.values())) + weights = group.get("weights", {}) + block_structure = weights.get("block_structure", default_value) + return block_structure + + +def get_weight_block_size_safety(config, default_value=None): + quantization_config = getattr(config, "quantization_config", {}) + if isinstance(quantization_config, dict): + if "weight_block_size" in quantization_config: + return quantization_config["weight_block_size"] + return get_compressed_tensors_block_structure( + quantization_config, default_value + ) + return default_value + + +def get_model_params(config): + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + hidden_size = config.hidden_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + hidden_size = config.hidden_size + elif config.architectures[0] in ( + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "DeepseekV32ForCausalLM", + "GlmMoeDsaForCausalLM", + "Glm4MoeForCausalLM", + "Glm4MoeLiteForCausalLM", + "NemotronHForCausalLM", + "MistralLarge3ForCausalLM", + ): + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size + elif config.architectures[0] in ( + "Qwen2MoeForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", + ): + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size + elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration": + text_config = config.get_text_config() + E = text_config.num_experts + topk = text_config.num_experts_per_tok + intermediate_size = text_config.moe_intermediate_size + hidden_size = text_config.hidden_size + elif config.architectures[0] == "HunYuanMoEV1ForCausalLM": + E = config.num_experts + topk = config.moe_topk[0] + intermediate_size = config.moe_intermediate_size[0] + hidden_size = config.hidden_size + elif config.architectures[0] == "Qwen3OmniMoeForConditionalGeneration": + E = config.thinker_config.text_config.num_experts + topk = config.thinker_config.text_config.num_experts_per_tok + intermediate_size = config.thinker_config.text_config.moe_intermediate_size + hidden_size = config.thinker_config.text_config.hidden_size + elif config.architectures[0] == "PixtralForConditionalGeneration": + # Pixtral can contain different LLM architectures, + # recurse to get their parameters + return get_model_params(config.get_text_config()) + else: + # Support for llama4 + config = config.get_text_config() + # Default: Mixtral. 
+ E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + hidden_size = config.hidden_size + return E, topk, intermediate_size, hidden_size + + +def get_quantization_group_size(config) -> int | None: + """Extract the quantization group size from the HF model config. + + This reads directly from the HuggingFace config object (as returned by + ``get_config()``), not from vLLM's quantization config classes. + + Supports AWQ/GPTQ-style configs (direct 'group_size' key) and + compressed-tensors configs (nested inside 'config_groups'). + """ + quantization_config = getattr(config, "quantization_config", {}) + if not isinstance(quantization_config, dict): + return None + # AWQ / GPTQ style: group_size is a top-level key + gs = quantization_config.get("group_size") + if gs is not None: + return gs + # compressed-tensors style: group_size is nested in config_groups + config_groups = quantization_config.get("config_groups", {}) + if not isinstance(config_groups, dict): + return None + for group_cfg in config_groups.values(): + if not isinstance(group_cfg, dict): + continue + weights = group_cfg.get("weights", {}) + if not isinstance(weights, dict): + continue + gs = weights.get("group_size") + if gs is not None: + return gs + return None + + +def main(args: argparse.Namespace): + print(args) + + config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) + if args.model_prefix: + config = getattr(config, args.model_prefix) + E, topk, intermediate_size, hidden_size = get_model_params(config) + enable_ep = bool(args.enable_expert_parallel) + if enable_ep: + ensure_divisibility(E, args.tp_size, "Number of experts") + E = E // args.tp_size + shard_intermediate_size = 2 * intermediate_size + else: + ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") + shard_intermediate_size = 2 * intermediate_size // args.tp_size + dtype = torch.float16 if current_platform.is_rocm() else config.dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + use_int4_w4a16 = args.dtype == "int4_w4a16" + block_quant_shape = get_weight_block_size_safety(config) + if use_int4_w4a16: + group_size = get_quantization_group_size(config) + if group_size is None: + raise ValueError( + "Could not determine group_size from model config. " + "The model's quantization_config must contain a 'group_size' " + "field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' " + "(compressed-tensors)." + ) + # For int4_w4a16, block_shape = [0, group_size] + # block_shape[0]=0 means no block quantization on N dimension + block_quant_shape = [0, group_size] + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = args.batch_size + + use_deep_gemm = bool(args.use_deep_gemm) + + if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ: + # Ray will set ROCR_VISIBLE_DEVICES for device visibility + logger.warning( + "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility." + "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES." 
+ ) + val = os.environ["HIP_VISIBLE_DEVICES"] + os.environ["ROCR_VISIBLE_DEVICES"] = val + del os.environ["HIP_VISIBLE_DEVICES"] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: list[Any]) -> list[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + if args.tune: + # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for + # search space generation (no matrix_instr_nonkdim/kpack exploration). + is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16) + # For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not + # apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless + # of group_size. Skip block_quant_shape filtering to keep the full + # search space (e.g. BLOCK_SIZE_K=64 with group_size=128). + tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape + search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape) + if use_int4_w4a16: + # SPLIT_K is a required kernel constexpr for gptq_awq kernel; + # only SPLIT_K=1 is used at runtime, so fix it during tuning. + for cfg in search_space: + cfg["SPLIT_K"] = 1 + print(f"Start tuning over {len(search_space)} configurations...") + if use_deep_gemm: + raise ValueError( + "Tuning with --use-deep-gemm is not supported as it only tunes Triton " + "kernels. Please remove the flag." + ) + start = time.time() + configs = _distribute( + "tune", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + search_space, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) + best_configs = { + M: sort_config(config) for M, config in zip(batch_sizes, configs) + } + save_configs( + best_configs, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_quant_shape, + args.save_dir, + ) + end = time.time() + print(f"Tuning took {end - start:.2f} seconds") + else: + outputs = _distribute( + "benchmark", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) + + for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}, config: {config}") + print(f"Kernel time: {kernel_time:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2 + ) + parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true") + parser.add_argument( + "--dtype", + type=str, + choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"], + default="auto", + ) + parser.add_argument("--use-deep-gemm", action="store_true") + parser.add_argument( + "--save-dir", type=str, default="./", help="Directory to save tuned results" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, nargs="+", required=False) + 
parser.add_argument("--tune", action="store_true") + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--model-prefix", type=str, required=False) + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9a131f79b0ee4419db8b863193c45f7d6eca7b --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_align_block_size.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import itertools + +import torch + +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size, +) +from vllm.triton_utils import triton + + +def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: + return torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(num_tokens) + ] + ) + + +# test configurations +num_tokens_range = [1, 16, 256, 4096] +num_experts_range = [16, 64, 224, 256, 280, 512] +topk_range = [1, 2, 8] +ep_size_range = [1, 8] +configs = list( + itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range) +) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_experts", "topk", "ep_size"], + x_vals=configs, + line_arg="provider", + line_vals=["vllm"], + line_names=["vLLM"], + plot_name="moe-align-block-size-performance", + args={}, + ) +) +def benchmark(num_tokens, num_experts, topk, ep_size, provider): + """Benchmark function for Triton.""" + block_size = 256 + torch.cuda.manual_seed_all(0) + topk_ids = get_topk_ids(num_tokens, num_experts, topk) + + e_map = None + if ep_size != 1: + local_e = num_experts // ep_size + e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e] + e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "vllm": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: moe_align_block_size( + topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_experts", + type=int, + default=64, + choices=[8, 16, 32, 64, 128, 256], + ) + parser.add_argument( + "--topk", + type=int, + default=8, + choices=[2, 4, 8], + help="Top-k value for correctness check.", + ) + args = parser.parse_args() + + benchmark.run(print_data=True, show_plots=True) diff --git a/benchmarks/kernels/benchmark_moe_defaults.py b/benchmarks/kernels/benchmark_moe_defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..9527878bc3581f10e99760ffeb17d9c37ae62503 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_defaults.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark comparing old vs new default fused MoE configs. + +Runs the triton fused_moe kernel with three configurations for each scenario: + 1. Tuned config (from JSON file, if available) — the target to match + 2. Old default (the hardcoded defaults before this change) + 3. 
New default (the improved defaults) + +Usage: + python benchmarks/kernels/benchmark_moe_defaults.py + +Produces a table showing kernel time (us) and speedup of new vs old defaults. +""" + +import torch + +from vllm.model_executor.layers.fused_moe import fused_topk, override_config +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, + get_default_config, + get_moe_configs, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils.torch_utils import set_random_seed + +FP8_DTYPE = current_platform.fp8_dtype() + + +def old_default_config(M, E, N, K, topk, dtype=None, block_shape=None): + """The original defaults before https://github.com/vllm-project/vllm/pull/34846, + for comparison.""" + if dtype == "fp8_w8a8" and block_shape is not None: + return { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "SPLIT_K": 1, + "num_warps": 4, + "num_stages": 3 if not current_platform.is_rocm() else 2, + } + elif M <= E: + return { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "SPLIT_K": 1, + } + else: + return { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + } + + +def benchmark_config( + config, + M, + E, + N, + K, + topk, + dtype, + use_fp8=False, + block_shape=None, + num_iters=100, +): + """Time a single kernel config. Returns kernel time in microseconds.""" + init_dtype = torch.float16 if use_fp8 else dtype + + a = torch.randn(M, K, device="cuda", dtype=init_dtype) / 10 + w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=init_dtype) / 10 + w2 = torch.randn(E, K, N, device="cuda", dtype=init_dtype) / 10 + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_fp8: + if block_shape is not None: + bsn, bsk = block_shape + n_tiles_w1 = triton.cdiv(2 * N, bsn) + k_tiles_w1 = triton.cdiv(K, bsk) + n_tiles_w2 = triton.cdiv(K, bsn) + k_tiles_w2 = triton.cdiv(N, bsk) + w1_scale = torch.rand( + E, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32 + ) + w2_scale = torch.rand( + E, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32 + ) + else: + w1_scale = torch.rand(E, device="cuda", dtype=torch.float32) + w2_scale = torch.rand(E, device="cuda", dtype=torch.float32) + a1_scale = torch.rand(1, device="cuda", dtype=torch.float32) + a2_scale = torch.rand(1, device="cuda", dtype=torch.float32) + # Only weights are stored in fp8; activations stay in bf16/fp16 + # and get dynamically quantized inside the kernel. 
+ w1 = w1.to(FP8_DTYPE) + w2 = w2.to(FP8_DTYPE) + + quant_config = FusedMoEQuantConfig.make( + quant_dtype=torch.float8_e4m3fn if use_fp8 else None, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) + + gating = torch.randn(M, E, device="cuda", dtype=torch.float32) + + # Warmup + for _ in range(20): + with override_config(config): + topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True) + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + torch.cuda.synchronize() + + # Benchmark + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(num_iters): + with override_config(config): + topk_weights, topk_ids, _ = fused_topk(a, gating, topk, renormalize=True) + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / num_iters * 1000 # ms -> us + + +# Model configurations: (name, E, N, K, topk, dtype_str, use_fp8, block_shape) +# N = moe_intermediate_size // tp_size (the value used in config file lookup) +MODELS = [ + # --- Few experts --- + ("Mixtral bf16", 8, 7168, 4096, 2, None, False, None), + ("Mixtral fp8", 8, 7168, 4096, 2, "fp8_w8a8", True, None), + # --- Many experts: real model shapes at tp=1 --- + # Qwen2-MoE-57B: E=60, topk=4, N=1408, K=2048 + ("Qwen2-MoE bf16", 60, 1408, 2048, 4, None, False, None), + # DeepSeek-V2: E=64, topk=6, N=1407, K=4096 + # (use 1408 to avoid odd alignment; real model is 1407) + ("DeepSeek-V2 bf16", 64, 1408, 4096, 6, None, False, None), + # OLMoE-7B: E=64, topk=8, N=2048, K=2048 + ("OLMoE bf16", 64, 2048, 2048, 8, None, False, None), + # GLM-4-100B-A10B: E=128, topk=8, N=1408, K=4096 + ("GLM-4-MoE bf16", 128, 1408, 4096, 8, None, False, None), + # Qwen3-30B-A3B: E=128, topk=8, N=768, K=2048 + ("Qwen3-MoE bf16", 128, 768, 2048, 8, None, False, None), + # DeepSeek-V3 / MiMo-V2-Flash: E=256, topk=8, N=2048, K=7168 + ("DeepSeek-V3 bf16", 256, 2048, 7168, 8, None, False, None), + # Qwen3.5-70B-A22B (Qwen3-Next): E=512, topk=10, N=512, K=2048 + ("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None), + # E=128 N=1856 bf16 + ("E128 N1856 bf16", 128, 1856, 4096, 8, None, False, None), + # E=256 N=512 bf16 (DS-V3 tp=4) + ("DS-V3 tp4 bf16", 256, 512, 7168, 8, None, False, None), + # E=512 N=512 bf16 (Qwen3-Next tp=1) + ("Qwen3-Next bf16", 512, 512, 2048, 10, None, False, None), + # E=512 N=256 bf16 (Qwen3-Next tp=2) + ("Qwen3-Next tp2", 512, 256, 2048, 10, None, False, None), + # --- FP8 block quant (many experts) --- + # DS-V3 tp=4: E=256, N=512, fp8 block + ("DS-V3 tp4 fp8blk", 256, 512, 7168, 8, "fp8_w8a8", True, [128, 128]), + # DS-V3 tp=8: E=256, N=256, fp8 block + ("DS-V3 tp8 fp8blk", 256, 256, 7168, 8, "fp8_w8a8", True, [128, 128]), + # Qwen3-Next tp=2 fp8 block + ("Qwen3-Next tp2 fp8blk", 512, 256, 2048, 10, "fp8_w8a8", True, [128, 128]), +] + +BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + + +def main(): + set_random_seed(0) + torch.set_default_device("cuda") + dtype = torch.bfloat16 + + for name, E, N, K, topk, dtype_str, use_fp8, block_shape in MODELS: + print(f"\n{'=' * 90}") + print(f" {name} (E={E}, N={N}, K={K}, topk={topk})") + print(f"{'=' * 90}") + + # Try to load tuned config + block_n = block_shape[0] if block_shape else None + block_k = block_shape[1] if block_shape else None + tuned = get_moe_configs(E, N, 
dtype_str, block_n, block_k) + has_tuned = tuned is not None + print(f" Tuned config available: {has_tuned}") + + hdr = ( + f"{'Batch':>6} | {'Tuned (us)':>11} | {'Old (us)':>11} | " + f"{'New (us)':>11} | {'New/Old':>8} | {'New/Tuned':>10}" + ) + print(f" {hdr}") + print(f" {'-' * len(hdr)}") + + for M in BATCH_SIZES: + old_cfg = old_default_config(M, E, N, K, topk, dtype_str, block_shape) + new_cfg = get_default_config(M, E, N, K, topk, dtype_str, block_shape) + + if has_tuned: + tuned_cfg = tuned[min(tuned.keys(), key=lambda x: abs(x - M))] + t_tuned = benchmark_config( + tuned_cfg, + M, + E, + N, + K, + topk, + dtype, + use_fp8=use_fp8, + block_shape=block_shape, + ) + else: + t_tuned = None + + t_old = benchmark_config( + old_cfg, + M, + E, + N, + K, + topk, + dtype, + use_fp8=use_fp8, + block_shape=block_shape, + ) + t_new = benchmark_config( + new_cfg, + M, + E, + N, + K, + topk, + dtype, + use_fp8=use_fp8, + block_shape=block_shape, + ) + + ratio_new_old = t_new / t_old + tuned_str = f"{t_tuned:11.2f}" if t_tuned else f"{'N/A':>11}" + ratio_tuned = f"{t_new / t_tuned:10.2f}x" if t_tuned else f"{'N/A':>10}" + # flag regressions where new default is >5% slower than old + marker = " <--" if ratio_new_old > 1.05 else "" + + print( + f" {M:>6} | {tuned_str} | {t_old:11.2f} | {t_new:11.2f} " + f"| {ratio_new_old:7.2f}x | {ratio_tuned}{marker}" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a1d33038fdef348873a0a739cfd2af00a5172e --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +from typing import Any, TypedDict + +import ray +import torch +from transformers import AutoConfig + +from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, + moe_unpermute, +) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.platforms import current_platform +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import set_random_seed + +FP8_DTYPE = current_platform.fp8_dtype() + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_permute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, +) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + # output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + qhidden_states = hidden_states + + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False + ) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + moe_permute( + qhidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=num_experts, + expert_map=None, + ) 
+ + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def benchmark_unpermute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, +) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + if use_fp8_w8a8: + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + qhidden_states = hidden_states + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False + ) + + def prepare(): + ( + permuted_hidden_states, + _, + first_token_off, + inv_perm_idx, + _, + ) = moe_permute( + qhidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=num_experts, + expert_map=None, + ) + # convert to fp16/bf16 as gemm output + return ( + permuted_hidden_states.to(dtype), + first_token_off, + inv_perm_idx, + ) + + def run(input: tuple): + (permuted_hidden_states, first_token_off, inv_perm_idx) = input + output = torch.empty_like(hidden_states) + moe_unpermute( + output, + permuted_hidden_states, + topk_weights, + inv_perm_idx, + first_token_off, + ) + + # JIT compilation & warmup + input = prepare() + run(input) + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run(input) + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + set_random_seed(seed) + self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. 
+ self.device_id = int(ray.get_gpu_ids()[0]) + + def benchmark( + self, + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + ) -> tuple[float, float]: + set_random_seed(self.seed) + + permute_time = benchmark_permute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + ) + unpermute_time = benchmark_unpermute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + ) + return permute_time, unpermute_time + + +def get_weight_block_size_safety(config, default_value=None): + quantization_config = getattr(config, "quantization_config", {}) + if isinstance(quantization_config, dict): + return quantization_config.get("weight_block_size", default_value) + return default_value + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code + ) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + elif ( + config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM" + or config.architectures[0] == "Glm4MoeForCausalLM" + or config.architectures[0] == "Glm4MoeLiteForCausalLM" + ): + E = config.n_routed_experts + topk = config.num_experts_per_tok + elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: + E = config.num_experts + topk = config.num_experts_per_tok + + else: + # Support for llama4 + config = config.get_text_config() + # Default: Mixtral. 
+ E = config.num_local_experts + topk = config.num_experts_per_tok + + hidden_size = config.hidden_size + dtype = torch.float16 if current_platform.is_rocm() else config.dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: list[Any]) -> list[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + outputs = _distribute( + "benchmark", + [ + ( + batch_size, + E, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + ) + for batch_size in batch_sizes + ], + ) + + for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}") + print(f"Permute time: {permute:.2f} us") + print(f"Unpermute time: {unpermute:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py new file mode 100644 index 0000000000000000000000000000000000000000..2c086870c42a2fc7a1c9ae5d56e59090b865a680 --- /dev/null +++ b/benchmarks/kernels/benchmark_mrope.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models). +# It generates test data, runs benchmarks, and saves results to a CSV file. 
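+#
+# A quick way to inspect a finished run (a sketch, assuming pandas is
+# installed; the real filename carries the date/time suffix described below):
+#
+#   import pandas as pd
+#   df = pd.read_csv("mrope_benchmark_results_<timestamp>.csv")
+#   print(df.sort_values("speedup", ascending=False).head())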
+# +# The CSV file (named with current date/time) contains these columns: +# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position, +# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99, +# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max, +# speedup +# +# == Usage Examples == +# +# Single model benchmark: +# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \ +# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 +# +# All models benchmark: +# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \ +# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 +# +# All models with different TP sizes: +# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \ +# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 +# +# All models with different token counts: +# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \ +# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384 +import csv +import os +import time +from datetime import datetime +from typing import Any + +import numpy as np +import torch + +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.transformers_utils.config import get_config +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import set_random_seed + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def generate_test_data( + num_tokens: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + max_position_embeddings: int, + dtype: torch.dtype, + device: torch.device, +): + """Generate test data for given configuration.""" + # Create 2D positions (3, num_tokens) for multimodal case + positions = torch.randint( + 0, max_position_embeddings // 4, (3, num_tokens), device=device + ) + + # Create query and key tensors + query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device) + key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device) + + return positions, query, key + + +def calculate_stats(times: list[float]) -> dict[str, float]: + """Calculate statistics from a list of times.""" + times_array = np.array(times) + return { + "mean": np.mean(times_array), + "median": np.median(times_array), + "p99": np.percentile(times_array, 99), + "min": np.min(times_array), + "max": np.max(times_array), + } + + +@default_vllm_config() +def benchmark_mrope( + model_name: str, + num_tokens: int, + head_dim: int, + tp_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 8192, + is_neox_style: bool = True, + rope_parameters: dict[str, Any] | None = None, + dtype: torch.dtype = torch.bfloat16, + seed: int = 0, + warmup_iter: int = 10, + benchmark_iter: int = 100, + csv_writer=None, +): + set_random_seed(seed) + torch.set_default_device(device) + # the parameters to compute the q k v size based on tp_size + mrope_helper_class = get_rope( + head_size=head_dim, + max_position=max_position, + is_neox_style=is_neox_style, + rope_parameters=rope_parameters, + dtype=dtype, + ).to(device=device) + + print(80 * "=") + print( + f"Evaluating model: {model_name} " + f"with tp_size: {tp_size} " + f"and num_tokens: {num_tokens}, " + f"dtype: {dtype}" + ) + + # create q k v input tensors + # create rotary pos emb input tensors + positions, 
query, key = generate_test_data( + num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device + ) + + # Warm up + for _ in range(warmup_iter): + mrope_helper_class.forward_native( + positions, + query.clone(), + key.clone(), + ) + + mrope_helper_class.forward_cuda( + positions, + query.clone(), + key.clone(), + ) + + torch.cuda.synchronize() + + # Time reference implementation + torch_times = [] + for _ in range(benchmark_iter): + query_clone = query.clone() + key_clone = key.clone() + torch.cuda.synchronize() + start_time = time.time() + + mrope_helper_class.forward_native( + positions, + query_clone, + key_clone, + ) + + torch.cuda.synchronize() + torch_times.append(time.time() - start_time) + + # Time triton kernel implementation + triton_times = [] + for _ in range(benchmark_iter): + query_clone = query.clone() + key_clone = key.clone() + torch.cuda.synchronize() + start_time = time.time() + mrope_helper_class.forward_cuda( + positions, + query_clone, + key_clone, + ) + torch.cuda.synchronize() + triton_times.append(time.time() - start_time) + + # Calculate statistics + torch_stats = calculate_stats(torch_times) + triton_stats = calculate_stats(triton_times) + print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):") + + print( + f"Torch implementation: " + f"mean={torch_stats['mean']:.8f}s, " + f"median={torch_stats['median']:.8f}s, " + f"p99={torch_stats['p99']:.8f}s" + ) + + print( + f"Triton implementation: " + f"mean={triton_stats['mean']:.8f}s, " + f"median={triton_stats['median']:.8f}s, " + f"p99={triton_stats['p99']:.8f}s" + ) + + print( + f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x" + ) + + # Write to CSV + if csv_writer: + row = [ + model_name, + tp_size, + num_tokens, + num_heads, + num_kv_heads, + head_dim, + max_position, + is_neox_style, + str(rope_parameters), + str(dtype).split(".")[-1], + torch_stats["mean"], + torch_stats["median"], + torch_stats["p99"], + torch_stats["min"], + torch_stats["max"], + triton_stats["mean"], + triton_stats["median"], + triton_stats["p99"], + triton_stats["min"], + triton_stats["max"], + torch_stats["mean"] / triton_stats["mean"], # speedup + ] + csv_writer.writerow(row) + + return torch_stats, triton_stats + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the rotary embedding kernels." 
+ ) + parser.add_argument("--model-name", type=str, default="") + parser.add_argument("--tp-size", type=int, default=1) + parser.add_argument("--warmup-iter", type=int, default=10) + parser.add_argument("--benchmark-iter", type=int, default=100) + parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num-tokens", type=int, nargs="+", required=False) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv") + args = parser.parse_args() + print(args) + + # Create CSV file for results + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv" + + with open(csv_filename, "w", newline="") as csvfile: + csv_writer = csv.writer(csvfile) + # Write header + header = [ + "model_name", + "tp_size", + "num_tokens", + "num_heads", + "num_kv_heads", + "head_dim", + "max_position", + "is_neox_style", + "rope_parameters", + "dtype", + "torch_mean", + "torch_median", + "torch_p99", + "torch_min", + "torch_max", + "triton_mean", + "triton_median", + "triton_p99", + "triton_min", + "triton_max", + "speedup", + ] + csv_writer.writerow(header) + + model_tp_dict = {} + if args.model_name == "": + model_tp_dict = { + "Qwen/Qwen2-VL-2B-Instruct": [1], + "Qwen/Qwen2-VL-7B-Instruct": [1], + "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8], + "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8], + "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8], + "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8], + } + else: + model_tp_dict[args.model_name] = [args.tp_size] + + if args.num_tokens is None: + num_tokens_list = [2**i for i in range(0, 18)] + else: + num_tokens_list = args.num_tokens + + for model_name, tp_list in model_tp_dict.items(): + config = get_config(model_name, trust_remote_code=args.trust_remote_code) + for tp_size in tp_list: + # get the model config + total_num_kv_heads = config.num_key_value_heads + total_num_heads = config.num_attention_heads + num_heads = total_num_heads // tp_size + num_kv_heads = max(1, total_num_kv_heads // tp_size) + head_dim = config.hidden_size // total_num_heads + q_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + is_neox_style = True + rope_parameters = config.rope_parameters + max_position = config.max_position_embeddings + + for num_tokens in num_tokens_list: + benchmark_mrope( + model_name=model_name, + num_tokens=num_tokens, + head_dim=head_dim, + tp_size=tp_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + is_neox_style=is_neox_style, + rope_parameters=rope_parameters, + dtype=getattr(torch, args.dtype), + seed=args.seed, + warmup_iter=args.warmup_iter, + benchmark_iter=args.benchmark_iter, + csv_writer=csv_writer, + ) + + print(f"Benchmark results saved to {csv_filename}") diff --git a/benchmarks/kernels/benchmark_mxfp4_qutlass.py b/benchmarks/kernels/benchmark_mxfp4_qutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc7721876a1739055f851f86468ce01fb3fd670 --- /dev/null +++ b/benchmarks/kernels/benchmark_mxfp4_qutlass.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "mxfp4": dict(no_a_quant=False, enabled=True), + "mxfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_mxfp4( + b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx( + b, forward_hadamard_matrix, method="abs_max" + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton") + return weight_hf_e2m1, weight_hf_scale_block + + +def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device): + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4( + b, forward_hadamard_matrix, device + ) + alpha = torch.tensor([1.0], device="cuda") + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + + def run(): + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs MXFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + 
else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_mxfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_mxfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_nvfp4_gemm.py b/benchmarks/kernels/benchmark_nvfp4_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..6b19eb113f3e77f93043ae2e92ffaf55e771b14e --- /dev/null +++ b/benchmarks/kernels/benchmark_nvfp4_gemm.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools +import os + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.triton_utils import triton + +if not current_platform.has_device_capability(100): + raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)") + + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "nvfp4": dict(no_a_quant=False, enabled=True), + "nvfp4-noquant": dict(no_a_quant=True, enabled=True), + "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True), + "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True), +} + +_needs_fbgemm = any( + v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False) +) +if _needs_fbgemm: + try: + from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import ( + triton_scale_nvfp4_quant, + ) + except ImportError: + print( + "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. " + "These providers will be skipped. Please install fbgemm_gpu with: " + "'pip install fbgemm-gpu-genai' to run them." + ) + # Disable FBGEMM providers so the benchmark can run. 
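+        # Only entries that set fbgemm=True are switched off; the plain
+        # nvfp4/cutlass providers remain enabled.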
+ for cfg in PROVIDER_CFGS.values(): + if cfg.get("fbgemm"): + cfg["enabled"] = False + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg): + # Compute global scale for weight + b_amax = torch.abs(b).max().to(torch.float32) + b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax + if "fbgemm" in cfg and cfg["fbgemm"]: + b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale) + else: + b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale) + return b_fp4, scale_b_fp4, b_global_scale + + +def build_nvfp4_runner(cfg, a, b, dtype, device): + b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg) + + # Compute global scale for activation + # NOTE: This is generally provided ahead-of-time by the model checkpoint. + a_amax = torch.abs(a).max().to(torch.float32) + a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax + + # Alpha for the GEMM operation + alpha = 1.0 / (a_global_scale * b_global_scale) + if "fbgemm" in cfg and cfg["fbgemm"]: + if cfg["no_a_quant"]: + a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale) + + def run(): + return torch.ops.fbgemm.f4f4bf16( + a_fp4, + b_fp4, + scale_a_fp4, + scale_b_fp4, + global_scale=alpha, + use_mx=False, + ) + + return run + else: + + def run(): + a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale) + return torch.ops.fbgemm.f4f4bf16( + a_fp4, + b_fp4, + scale_a_fp4, + scale_b_fp4, + global_scale=alpha, + use_mx=False, + ) + + return run + + if cfg["no_a_quant"]: + # Pre-quantize activation + a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale) + + def run(): + return ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype + ) + + return run + + # Quantize activation on-the-fly + def run(): + a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale) + return ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs NVFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_nvfp4_runner(cfg, a, b, dtype, device) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + 
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:") + save_dir = f"bench_nvfp4_res_n{N}_k{K}" + os.makedirs(save_dir, exist_ok=True) + + benchmark.run( + print_data=True, + show_plots=True, + save_path=save_dir, + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_nvfp4_quant.py b/benchmarks/kernels/benchmark_nvfp4_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..c48353820b98dfc8d2b136f2cf02a7bd3c098a90 --- /dev/null +++ b/benchmarks/kernels/benchmark_nvfp4_quant.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.triton_utils import triton +from vllm.utils.flashinfer import flashinfer_fp4_quantize + +if not current_platform.has_device_capability(100): + raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)") + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +PROVIDER_CFGS = { + "vllm": dict(backend="vllm", is_sf_swizzled_layout=False, enabled=True), + "vllm-swizzle": dict(backend="vllm", is_sf_swizzled_layout=True, enabled=True), + "flashinfer": dict(backend="flashinfer", is_sf_swizzled_layout=False, enabled=True), + "flashinfer-swizzle": dict( + backend="flashinfer", is_sf_swizzled_layout=True, enabled=True + ), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor: + """Compute global scale for FP4 quantization.""" + amax = torch.abs(tensor).max().to(torch.float32) + return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="us (lower is better)", + plot_name="NVFP4 Input Quantization Latency (us)", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + # Create input tensor + a = torch.randn((M, K), device=device, dtype=dtype) + + # Compute global scale for activation + a_global_scale = compute_global_scale(a) + + quantiles = [0.5, 0.2, 0.8] + + cfg = PROVIDER_CFGS[provider] + + if cfg["backend"] == "vllm": + # vLLM's FP4 quantization + if cfg["is_sf_swizzled_layout"]: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.scaled_fp4_quant( + a, a_global_scale, is_sf_swizzled_layout=True + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.scaled_fp4_quant( + a, a_global_scale, is_sf_swizzled_layout=False + ), + quantiles=quantiles, + ) + elif cfg["backend"] == "flashinfer": + # FlashInfer's FP4 quantization + if cfg["is_sf_swizzled_layout"]: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: flashinfer_fp4_quantize( + a, a_global_scale, is_sf_swizzled_layout=True + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: flashinfer_fp4_quantize( + a, 
a_global_scale, is_sf_swizzled_layout=False + ), + quantiles=quantiles, + ) + + # Convert ms to us for better readability at small batch sizes + to_us = lambda t_ms: t_ms * 1000 + return to_us(ms), to_us(max_ms), to_us(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +def _test_accuracy_once( + M: int, K: int, dtype: torch.dtype, device: str, is_sf_swizzled_layout: bool +): + """Test accuracy between vLLM and FlashInfer FP4 quantization.""" + # Create input tensor + a = torch.randn((M, K), device=device, dtype=dtype) + + # Compute global scale + a_global_scale = compute_global_scale(a) + + # vLLM quantization + vllm_fp4, vllm_scale = ops.scaled_fp4_quant( + a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + + # FlashInfer quantization (with swizzled layout to match vLLM's output) + flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize( + a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn) + + # Compare outputs + torch.testing.assert_close( + vllm_fp4, + flashinfer_fp4, + ) + # Compare scales + torch.testing.assert_close( + vllm_scale, + flashinfer_scale, + ) + print( + f"M={M}, K={K}, dtype={dtype}, is_sf_swizzled_layout={is_sf_swizzled_layout}: PASSED" # noqa: E501 + ) + + +def test_accuracy(): + """Run accuracy tests across various shapes.""" + print("\n" + "=" * 60) + print("Running accuracy tests: vLLM vs FlashInfer") + print("=" * 60) + + device = "cuda" + dtype = torch.bfloat16 + + # Test various batch sizes and hidden dimensions + Ms = [1, 1024] + Ks = [4096] + + for is_sf_swizzled_layout in [True, False]: + for M in Ms: + for K in Ks: + _test_accuracy_once(M, K, dtype, device, is_sf_swizzled_layout) + + print("\nAll accuracy tests passed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark NVFP4 quantization: vLLM vs FlashInfer" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + parser.add_argument( + "--save-path", + type=str, + default=None, + help="Path to save benchmark results", + ) + parser.add_argument( + "--accuracy", + action="store_true", + help="Run accuracy tests", + ) + args = parser.parse_args() + + if args.accuracy: + test_accuracy() + + for K, N, model in prepare_shapes(args): + print(f"\n{model}, N={N} K={K}") + benchmark.run( + print_data=True, + save_path=args.save_path, + N=N, + K=K, + ) + + print("\nBenchmark finished!") diff --git a/benchmarks/kernels/benchmark_nvfp4_qutlass.py b/benchmarks/kernels/benchmark_nvfp4_qutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..6fecc816f9466ce3d2cf7e0f1780da0ea50cf39d --- /dev/null +++ b/benchmarks/kernels/benchmark_nvfp4_qutlass.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "nvfp4": dict(no_a_quant=False, enabled=True), + "nvfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_nvfp4( + b: torch.Tensor, + forward_hadamard_matrix: torch.Tensor, + global_scale: torch.Tensor, + device: str, + M: int, + N: int, + K: int, +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv( + b, forward_hadamard_matrix, global_scale + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return weight_hf_e2m1, weight_hf_scale_block + + +def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K): + alpha = torch.tensor([1.0], device="cuda") + global_scale = torch.tensor([1.0], device="cuda") + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4( + b, forward_hadamard_matrix, global_scale, device, M, N, K + ) + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + + def run(): + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs NVFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + 
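+    # do_bench_cudagraph returns timings at the requested quantiles in order:
+    # the median first, then the 0.2 and 0.8 quantiles (min_ms/max_ms below).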
+ if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_nvfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [16, 32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_nvfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..be871d3d1aa082b510748c46f4a08ae94579237c --- /dev/null +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + create_kv_caches_with_random, + set_random_seed, +) + +logger = init_logger(__name__) + +NUM_BLOCKS = 128 * 1024 +PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 + + +@torch.inference_mode() +def main( + version: str, + num_seqs: int, + seq_len: int, + num_query_heads: int, + num_kv_heads: int, + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + do_profile: bool, + device: str = "cuda", + kv_cache_dtype: str | None = None, +) -> None: + set_random_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + query = torch.empty( + num_seqs, num_query_heads, head_size, dtype=dtype, device=device + ) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device) + + seq_lens = [seq_len for _ in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device) + + # Create the block tables. 
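+    # Each sequence gets max_num_blocks_per_seq random block indices out of the
+    # NUM_BLOCKS-entry cache; overlaps between sequences are fine for timing.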
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables_lst: list[list[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) + ] + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device) + + # Create the KV cache. + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Prepare for the paged attention kernel. + output = torch.empty_like(query) + if version == "v2": + if current_platform.is_rocm(): + global PARTITION_SIZE + if not args.custom_paged_attn and not current_platform.is_navi(): + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE + tmp_output = torch.empty( + size=(num_seqs, num_query_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_query_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + # Using default kv_scale + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + for _ in range(num_iters): + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + elif version == "v2": + if not args.custom_paged_attn: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + None, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + raise ValueError(f"Invalid version: {version}") + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStop() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=3, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=100, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == "__main__": + logger.warning( + "This script benchmarks the paged attention kernel. " + "By default this is no longer used in vLLM inference." 
+ ) + + parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") + parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--num-query-heads", type=int, default=64) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--use-alibi", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], + default="auto", + help="Data type for kv cache storage. If 'auto', will use model " + "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)", + ) + parser.add_argument( + "--custom-paged-attn", action="store_true", help="Use custom paged attention" + ) + args = parser.parse_args() + print(args) + + if args.num_query_heads % args.num_kv_heads != 0: + raise ValueError("num_query_heads must be divisible by num_kv_heads") + main( + version=args.version, + num_seqs=args.batch_size, + seq_len=args.seq_len, + num_query_heads=args.num_query_heads, + num_kv_heads=args.num_kv_heads, + head_size=args.head_size, + block_size=args.block_size, + use_alibi=args.use_alibi, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + kv_cache_dtype=args.kv_cache_dtype, + ) diff --git a/benchmarks/kernels/benchmark_per_token_group_quant.py b/benchmarks/kernels/benchmark_per_token_group_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..eba4d510258b67ba22e59d3000a1516048ba71b1 --- /dev/null +++ b/benchmarks/kernels/benchmark_per_token_group_quant.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import math +from collections.abc import Callable +from contextlib import contextmanager +from unittest.mock import patch + +import torch + +from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils +from vllm.platforms import current_platform + + +@contextmanager +def _triton_mode(): + """Temporarily force the Triton fallback path""" + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + yield + + +def _time_cuda( + fn: Callable[[], tuple[torch.Tensor, torch.Tensor]], + warmup_iters: int, + bench_iters: int, +) -> float: + # warmup + for _ in range(warmup_iters): + fn() + torch.cuda.synchronize() + + start = torch.Event(enable_timing=True) + end = torch.Event(enable_timing=True) + + start.record() + for _ in range(bench_iters): + fn() + end.record() + torch.cuda.synchronize() + + return start.elapsed_time(end) / bench_iters # ms/iter + + +def _run_single( + shape: tuple[int, int], + group_size: int, + dtype: str, + *, + column_major: bool = False, + scale_ue8m0: bool = False, + warmup_iters: int, + bench_iters: int, +) -> None: + num_tokens, hidden_dim = shape + + device = torch.device("cuda") + torch.manual_seed(42) + x = torch.randn(num_tokens, hidden_dim, device=device, 
dtype=torch.bfloat16) * 8 + + if dtype == "fp8": + + def cuda_impl(): + return fp8_utils.per_token_group_quant_fp8( + x, + group_size, + column_major_scales=column_major, + use_ue8m0=scale_ue8m0, + ) + + def triton_impl(): + with _triton_mode(): + return fp8_utils.per_token_group_quant_fp8( + x, + group_size, + column_major_scales=column_major, + use_ue8m0=scale_ue8m0, + ) + elif dtype == "int8": + + def cuda_impl(): + return int8_utils.per_token_group_quant_int8(x, group_size) + + def triton_impl(): + with _triton_mode(): + return int8_utils.per_token_group_quant_int8(x, group_size) + else: + raise ValueError("dtype must be 'fp8' or 'int8'") + + cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters) + triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters) + + speedup = triton_ms / cuda_ms if cuda_ms else math.inf + + cfg_desc = ( + f"shape={shape} gs={group_size:<3} col_major={column_major:<5} " + f"ue8m0={scale_ue8m0:<5} dtype={dtype}" + ) + print( + f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | " + f"speed-up ×{speedup:5.2f}" + ) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--warmup-iters", type=int, default=10) + parser.add_argument("--bench-iters", type=int, default=100) + parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both") + return parser.parse_args() + + +if __name__ == "__main__": + if not current_platform.is_cuda(): + raise RuntimeError("CUDA device is required to run this benchmark.") + + args = parse_args() + warmup_iters, bench_iters = args.warmup_iters, args.bench_iters + + shapes = [(32, 128), (64, 256), (16, 512)] + group_sizes = [64, 128] + + dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype] + + header = ( + "Configuration".ljust(55) + + " | " + + "CUDA (ms)".center(12) + + " | " + + "Triton (ms)".center(13) + + " | " + + "Speed-up" + ) + print(header) + print("-" * len(header)) + + for dtype in dtypes: + for shape in shapes: + for gs in group_sizes: + if dtype == "fp8": + for col_major in (False, True): + for ue8m0 in (False, True): + _run_single( + shape, + gs, + dtype, + column_major=col_major, + scale_ue8m0=ue8m0, + warmup_iters=warmup_iters, + bench_iters=bench_iters, + ) + else: # INT8 has no col-major / ue8m0 switches + _run_single( + shape, + gs, + dtype, + warmup_iters=warmup_iters, + bench_iters=bench_iters, + ) diff --git a/benchmarks/kernels/benchmark_per_token_quant_fp8.py b/benchmarks/kernels/benchmark_per_token_quant_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce97e30368b735a5c860c9d7549ffbb42e610e8 --- /dev/null +++ b/benchmarks/kernels/benchmark_per_token_quant_fp8.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from collections.abc import Callable +from unittest.mock import patch + +import pandas as pd +import torch + +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + + +def with_triton_mode(fn): + """Temporarily force the Triton fallback path""" + + def wrapped(*args, **kwargs): + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + return fn(*args, 
**kwargs) + + return wrapped + + +# TODO(luka): use standalone_compile utility +def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): + def inner(*args): + torch._dynamo.mark_dynamic(args[arg_index], dim_index) + return fn(*args) + + return inner + + +def bench_compile(fn: Callable): + # recompile for different shapes + fwd = torch.compile(fn, fullgraph=True, dynamic=False) + + # First dim is explicitly dynamic to simulate vLLM usage + return with_dyn_arg(fwd, 0, 0) + + +torch._dynamo.config.recompile_limit = 8888 + + +def calculate_diff( + batch_size: int, + hidden_size: int, + group_shape: GroupShape, + dtype: torch.dtype, +): + """Calculate the difference between Inductor and CUDA implementations.""" + device = torch.device("cuda") + x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device) + + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False) + + torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x) + torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x) + cuda_out, cuda_scale = quant_fp8.forward_cuda(x) + + try: + torch.testing.assert_close( + cuda_out.to(torch.float32), + torch_out.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5) + torch.testing.assert_close( + cuda_out.to(torch.float32), + torch_eager_out.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5) + print("✅ All implementations match") + except AssertionError as e: + print("❌ Implementations differ") + print(e) + + +configs = [] + + +@default_vllm_config() +def benchmark_quantization( + batch_size, + hidden_size, + provider, + group_shape: GroupShape, + col_major: bool, + dtype: torch.dtype, +): + device = torch.device("cuda") + + x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major) + + if provider == "torch": + fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone()) + elif provider == "cuda": + fn = lambda: quant_fp8.forward_cuda(x.clone()) + elif provider == "triton": + if not group_shape.is_per_group(): + # Triton only supported for per-group + return 0, 0, 0 + + fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone()) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +# TODO(luka) extract to utils +def compute_geomean_speedups( + df: pd.DataFrame, + baseline_col: str, + speedup_cols: list[str], + groupby_cols: list[str] | None = None, +) -> pd.DataFrame: + """ + Compute geometric mean speedups over a baseline column. + + Args: + df: Input dataframe + baseline_col: Column to use as baseline + speedup_cols: Columns to compute speedups for + groupby_cols: Columns to group by. If None, compute over entire df. 
+ + Returns: + pd.DataFrame with geometric mean speedups + """ + from scipy.stats import gmean + + def geo_speedup(group: pd.DataFrame) -> pd.Series: + ratios = { + col: (group[baseline_col] / group[col]).values for col in speedup_cols + } + return pd.Series({col: gmean(vals) for col, vals in ratios.items()}) + + if groupby_cols is None: + result = geo_speedup(df).to_frame().T + else: + result = ( + df.groupby(groupby_cols) + .apply(geo_speedup, include_groups=False) + .reset_index() + ) + + return result + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the various implementations of QuantFP8 (dynamic-only)" + ) + parser.add_argument("-c", "--check", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + parser.add_argument( + "--hidden-sizes", + type=int, + nargs="+", + default=[896, 1024, 2048, 4096, 7168], + help="Hidden sizes to benchmark", + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=[1, 16, 128, 512, 1024], + help="Batch sizes to benchmark", + ) + parser.add_argument( + "--group-sizes", + type=int, + nargs="+", + default=None, + help="Group sizes for GroupShape(1,N) to benchmark. " + "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)", + ) + parser.add_argument( + "--no-column-major", + action="store_true", + help="Disable column-major scales testing", + ) + + args = parser.parse_args() + assert args + + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + hidden_sizes = args.hidden_sizes + batch_sizes = args.batch_sizes + + if args.group_sizes is not None: + group_shapes = [] + for size in args.group_sizes: + if size == 0: + group_shapes.append(GroupShape.PER_TENSOR) + elif size == -1: + group_shapes.append(GroupShape.PER_TOKEN) + else: + group_shapes.append(GroupShape(1, size)) + else: + group_shapes = [ + GroupShape.PER_TENSOR, + GroupShape.PER_TOKEN, + GroupShape(1, 64), + GroupShape(1, 128), + ] + + column_major_scales = [False] if args.no_column_major else [True, False] + + config_gen = itertools.product( + group_shapes, + column_major_scales, + batch_sizes, + hidden_sizes, + ) + + # filter out column-major scales for non-group, reverse order + configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) + + print(f"Running {len(configs)} configurations:") + print(f" Hidden sizes: {hidden_sizes}") + print(f" Batch sizes: {batch_sizes}") + print(f" Group shapes: {[str(g) for g in group_shapes]}") + print(f" Column major scales: {column_major_scales}") + print() + + if args.check: + for group_shape in group_shapes: + group_size = group_shape[1] + print(f"{group_size=}") + calculate_diff( + batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype + ) + + benchmark = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size", "col_major", "group_shape"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda", "triton"], + line_names=["Torch (Compiled)", "CUDA", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("black", "-")], + ylabel="us", + plot_name="QuantFP8 performance", + args={}, + ) + )(benchmark_quantization) + + df = benchmark.run(print_data=True, dtype=dtype, return_df=True) + + # Print geomean speedups + geo_table_grouped = compute_geomean_speedups( + df, + baseline_col="Torch (Compiled)", + speedup_cols=["CUDA", "Triton"], + groupby_cols=["col_major", "group_shape"], + ) + + print("Speedup over Torch (Compiled)") + 
print(geo_table_grouped.to_string(index=False)) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..9a21cfe94e5be1d69114fe049a6f8167eaf36592 --- /dev/null +++ b/benchmarks/kernels/benchmark_quant.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed + + +@torch.inference_mode() +def main( + num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: + set_random_seed(seed) + torch.set_default_device("cuda") + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if quant_dtype == torch.int8: + ops.scaled_int8_quant(x, scale) + else: + ops.scaled_fp8_quant(x, scale) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStop() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported dtype: {dt}") + + parser = FlexibleArgumentParser( + description="Benchmark the quantization (fp8 or int8) kernel." + ) + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--static-scale", action="store_true") + parser.add_argument( + "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8" + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. 
" + "If --profile is set, this number is ignored", + ) + + args = parser.parse_args() + print(args) + + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..99067d8ac3710fc7f86dcd3017b3a8ea218426de --- /dev/null +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +import time + +import torch +from tabulate import tabulate + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + create_kv_caches_with_random, + set_random_seed, +) + +logger = init_logger(__name__) + + +@torch.inference_mode() +def run_benchmark( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + kv_cache_dtype: str, + num_iters: int, + benchmark_mode: str, + device: str = "cuda", +) -> float: + """Return latency (seconds) for given num_tokens.""" + + if kv_cache_dtype == "fp8" and head_size % 16: + raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.") + + set_random_seed(42) + torch.set_default_device(device) + + # create random key / value tensors [T, H, D]. + key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device) + value = torch.randn_like(key) + + # prepare the slot mapping. + # each token is assigned a unique slot in the KV-cache. + num_slots = block_size * num_blocks + if num_tokens > num_slots: + raise ValueError("num_tokens cannot exceed the total number of cache slots") + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + key_caches, value_caches = create_kv_caches_with_random( + num_blocks, + block_size, + 1, # num_layers + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + # to free unused memory + del key_caches, value_caches + + # compute per-kernel scaling factors for fp8 conversion (if used). 
+ k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) + + function_under_test = lambda: ops.reshape_and_cache( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + + if benchmark_mode == "cudagraph": + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + function_under_test() + torch.cuda.synchronize() + function_under_test = lambda: g.replay() + + def run_cuda_benchmark(n_iters: int) -> float: + nonlocal key, value, key_cache, value_cache, slot_mapping + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(n_iters): + function_under_test() + torch.cuda.synchronize() + end = time.perf_counter() + return (end - start) / n_iters + + # warm-up + run_cuda_benchmark(3) + + lat = run_cuda_benchmark(num_iters) + + # free tensors to mitigate OOM when sweeping + del key, value, key_cache, value_cache, slot_mapping + torch.cuda.empty_cache() + + return lat + + +def main(args): + rows = [] + for exp in range(1, 17): + n_tok = 2**exp + lat = run_benchmark( + num_tokens=n_tok, + num_heads=args.num_heads, + head_size=args.head_size, + block_size=args.block_size, + num_blocks=args.num_blocks, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + kv_cache_dtype=args.kv_cache_dtype, + num_iters=args.iters, + benchmark_mode=args.mode, + device="cuda", + ) + rows.append([n_tok, lat * 1e6]) # convert to microseconds + + print(f"Benchmark results for implementation cuda (measuring with {args.mode}):") + print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f")) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + + parser.add_argument("--num-heads", type=int, default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--num-blocks", type=int, default=128 * 128) + + parser.add_argument( + "--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="bfloat16", + ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8"], + default="auto", + ) + + parser.add_argument("--iters", type=int, default=200) + + parser.add_argument( + "--mode", + type=str, + choices=["cudagraph", "no_graph"], + default="cudagraph", + ) + + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6be1f3c3597c9d4922b6bba8ad4128fecfbd0a --- /dev/null +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +import time + +import torch +from tabulate import tabulate + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + create_kv_caches_with_random_flash, + set_random_seed, +) +from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash, +) + +logger = init_logger(__name__) + + +@torch.inference_mode() +def run_benchmark( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + 
num_blocks: int, + dtype: torch.dtype, + kv_cache_dtype: str, + kv_cache_layout: str, + num_iters: int, + implementation: str, + benchmark_mode: str, + device: str = "cuda", +) -> float: + """Return latency (seconds) for given num_tokens.""" + + if kv_cache_dtype == "fp8" and head_size % 16: + raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.") + + if implementation not in ("cuda", "triton"): + raise ValueError( + f"Unsupported implementation: {implementation}. " + "Only 'cuda' and 'triton' are supported." + ) + if implementation == "triton" and kv_cache_layout == "HND": + return float("nan") # Triton does not support HND layout yet. + + set_random_seed(42) + torch.set_default_device(device) + + # create random key / value tensors [T, H, D]. + key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device) + value = torch.randn_like(key) + + # prepare the slot mapping. + # each token is assigned a unique slot in the KV-cache. + num_slots = block_size * num_blocks + if num_tokens > num_slots: + raise ValueError("num_tokens cannot exceed the total number of cache slots") + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + key_caches, value_caches = create_kv_caches_with_random_flash( + num_blocks, + block_size, + 1, # num_layers + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + cache_layout=kv_cache_layout, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + # to free unused memory + del key_caches, value_caches + + # compute per-kernel scaling factors for fp8 conversion (if used). + k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) + + if implementation == "cuda": + function_under_test = lambda: ops.reshape_and_cache_flash( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + function_under_test = lambda: triton_reshape_and_cache_flash( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + if benchmark_mode == "cudagraph": + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + function_under_test() + torch.cuda.synchronize() + function_under_test = lambda: g.replay() + + def run_cuda_benchmark(n_iters: int) -> float: + nonlocal key, value, key_cache, value_cache, slot_mapping + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(n_iters): + function_under_test() + torch.cuda.synchronize() + end = time.perf_counter() + return (end - start) / n_iters + + # warm-up + run_cuda_benchmark(3) + + lat = run_cuda_benchmark(num_iters) + + # free tensors to mitigate OOM when sweeping + del key, value, key_cache, value_cache, slot_mapping + torch.cuda.empty_cache() + + return lat + + +def main(args): + rows = [] + for layout in ["NHD", "HND"]: + for exp in range(1, 17): + n_tok = 2**exp + lat = run_benchmark( + num_tokens=n_tok, + num_heads=args.num_heads, + head_size=args.head_size, + block_size=args.block_size, + num_blocks=args.num_blocks, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + kv_cache_dtype=args.kv_cache_dtype, + kv_cache_layout=layout, + num_iters=args.iters, + implementation=args.implementation, + benchmark_mode=args.mode, + device="cuda", + ) + rows.append([n_tok, layout, f"{lat * 1e6:.3f}"]) + + print( + 
f"Benchmark results for implementation {args.implementation}" + f" (measuring with {args.mode}):" + ) + print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"])) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + + parser.add_argument("--num-heads", type=int, default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--num-blocks", type=int, default=128 * 512) + + parser.add_argument( + "--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="bfloat16", + ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8"], + default="auto", + ) + + parser.add_argument("--iters", type=int, default=100) + + parser.add_argument( + "--implementation", + type=str, + choices=["cuda", "triton"], + default="cuda", + ) + + parser.add_argument( + "--mode", + type=str, + choices=["cudagraph", "no_graph"], + default="cudagraph", + ) + + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d7f5bcf9dada3b499e133ac7d7b262583fb615 --- /dev/null +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn + +from vllm import _custom_ops as vllm_ops +from vllm.triton_utils import triton + + +class HuggingFaceRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor | None = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor | None = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = 
output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor | None = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_vllm = rmsnorm_vllm( + x.clone(), weight, residual.clone() if residual is not None else None + ) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"vLLM output={output_vllm}") + + if torch.allclose( + output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 + ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + args={}, + ) + ) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + 
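+ # Note: do_bench with quantiles=[0.5, 0.2, 0.8] returns the median and the 20th/80th-percentile latencies in milliseconds; the *1000 above converts them to microseconds to match the "us" ylabel.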
return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument( + "--use-residual", action="store_true", help="Whether to use residual connection" + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual, + ) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1df3b2939abf2a7632c7148d6794bbc6b53167 --- /dev/null +++ b/benchmarks/kernels/benchmark_rope.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch + +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +batch_size_range = [2**i for i in range(0, 8, 2)] +seq_len_range = [2**i for i in range(6, 10, 1)] +num_heads_range = [32, 48] +configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range)) + + +def get_benchmark(head_size, rotary_dim, is_neox_style, device): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "num_heads"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "flashinfer", "vllm"], + line_names=["PyTorch", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}", + args={}, + ) + ) + @default_vllm_config() + def benchmark(batch_size, seq_len, num_heads, provider): + dtype = torch.bfloat16 + max_position = 8192 + rope_parameters = {"partial_rotary_factor": rotary_dim / head_size} + rope = get_rope(head_size, max_position, is_neox_style, rope_parameters) + rope = rope.to(dtype=dtype, device=device) + cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) + + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + query = torch.randn( + (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device + ) + key = torch.randn_like(query) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_native(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox_style, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + 
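+ # the "vllm" provider: forward_cuda dispatches to vLLM's fused rotary-embedding custom op (in contrast to forward_native above)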
lambda: rope.forward_cuda(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the rotary embedding kernels." + ) + parser.add_argument("--is-neox-style", type=bool, default=True) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--num-heads", type=int, default=8) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) + parser.add_argument( + "--dtype", type=str, choices=["bfloat16", "float"], default="float" + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" + ) + parser.add_argument("--save-path", type=str, default="./configs/rope/") + args = parser.parse_args() + + # Get the benchmark function + benchmark = get_benchmark( + args.head_size, args.rotary_dim, args.is_neox_style, args.device + ) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..3e23c4cac059c04cf6a4153e9830cbec2ace36f0 --- /dev/null +++ b/benchmarks/kernels/benchmark_shapes.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +WEIGHT_SHAPES = { + "ideal": [[4 * 256 * 32, 256 * 32]], + "mistralai/Mistral-7B-v0.1/TP1": [ + [4096, 6144], + [4096, 4096], + [4096, 28672], + [14336, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP2": [ + [4096, 3072], + [2048, 4096], + [4096, 14336], + [7168, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP4": [ + [4096, 1536], + [1024, 4096], + [4096, 7168], + [3584, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP1": [ + [4096, 12288], + [4096, 4096], + [4096, 22016], + [11008, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP2": [ + [4096, 6144], + [2048, 4096], + [4096, 11008], + [5504, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP4": [ + [4096, 3072], + [1024, 4096], + [4096, 5504], + [2752, 4096], + ], + "meta-llama/Llama-2-13b-hf/TP1": [ + [5120, 15360], + [5120, 5120], + [5120, 27648], + [13824, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP2": [ + [5120, 7680], + [2560, 5120], + [5120, 13824], + [6912, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP4": [ + [5120, 3840], + [1280, 5120], + [5120, 6912], + [3456, 5120], + ], + "meta-llama/Llama-2-70b-hf/TP1": [ + [8192, 10240], + [8192, 8192], + [8192, 57344], + [28672, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP2": [ + [8192, 5120], + [4096, 8192], + [8192, 28672], + [14336, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP4": [ + [8192, 2560], + [2048, 8192], + [8192, 14336], + [7168, 8192], + ], +} + +WEIGHT_SHAPES_MOE = { + "mistralai/Mixtral-8x7B-Instruct-v0.1": [ + [8, 2, 4096, 28672], + [8, 2, 14336, 4096], + ], + "deepseek-ai/DeepSeek-V2-Lite": [ + [64, 6, 2048, 1408], + ], + "ibm-granite/granite-3.0-1b-a400m": [ + [32, 8, 1024, 1024], + ], + "ibm-granite/granite-3.0-3b-a800m": [ + [40, 8, 1024, 1536], + ], +} diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py new file mode 100644 index 
0000000000000000000000000000000000000000..da32bc30cb2ae3b385b79c852334f1594a4fe52d --- /dev/null +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -0,0 +1,720 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Comprehensive 3-way SiLU Benchmark Suite + +This benchmark compares three SiLU implementations: +1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation +2. Triton Kernel - Triton-based implementation + +The suite generates detailed performance comparisons including: +- Memory bandwidth utilization +- Speedup ratios (baseline vs optimized implementations) +- Performance across different expert configurations and token distributions +""" + +from collections.abc import Callable + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + persistent_masked_m_silu_mul_quant, +) +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.torch_utils import set_random_seed + + +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + # Stride for counts (elements) + stride_counts_e, + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h + base_ys_offset = e * stride_ys_e + g * stride_ys_g + + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + 
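+ # alongside the quantized group written above, one float32 scale per (expert, token, group) is stored: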
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) + + +def silu_mul_fp8_quant_deep_gemm_triton( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens, + group_size: int = 128, + eps: float = 1e-10, + expert_offsets: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = (H + group_size - 1) // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, ( + "tokens_per_expert must be shape (E,)" + ) + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, + ) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. 
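+ # (grid = (E * G,) below: one program per (expert, H-group) pair)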
+ # A loop inside the kernel handles the token dim + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + + return y_q, y_s + + +# Parse generation strategies +strategies = ["random_imbalanced", "uniform", "max_t"] + + +def benchmark( + kernel: Callable, + E: int, + T: int, + H: int, + total_tokens: int, + num_parallel_tokens: int = 64, + G: int = 128, + runs: int = 200, + num_warmups: int = 20, + gen_strategy: str = "default", + iterations_per_run: int = 20, +): + def generate_data(seed_offset=0): + """Generate input data with given seed offset""" + set_random_seed(42 + seed_offset) + y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() + + if gen_strategy == "random_imbalanced": + + def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"): + mean = total_tokens // n_e + min_max = mean // ratio + e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean + e[0] = min_max + r = torch.rand(size=(E - 1,)) + r /= r.sum() + r *= total_tokens - min_max + r = r.round().long() + e[1:] = r.to(device=device) + return e + + tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda") + elif gen_strategy == "uniform": + r = torch.rand(size=(E,)) + r /= r.sum() + r *= total_tokens + r = r.round().long() + tokens_per_expert = r + elif gen_strategy == "max_t": + tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert.fill_(total_tokens / E) + elif gen_strategy == "first_t": + tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert[0] = min(T, total_tokens) + else: + raise ValueError(f"Unknown generation strategy: {gen_strategy}") + return y, tokens_per_expert + + dataset_count = 4 + # Pre-generate different input matrices for each iteration to avoid cache effects + data_sets = [generate_data(i) for i in range(dataset_count)] + + # Warmup + y, tokens_per_expert = data_sets[0] + for _ in range(num_warmups): + kernel( + y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G + ) + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + # Benchmark + latencies: list[float] = [] + for _ in range(runs): + torch.cuda.synchronize() + + start_event.record() + for i in range(iterations_per_run): + y, tokens_per_expert = data_sets[i % dataset_count] + kernel( + y, + tokens_per_expert, + num_parallel_tokens=num_parallel_tokens, + group_size=G, + ) + end_event.record() + end_event.synchronize() + + total_time_ms = start_event.elapsed_time(end_event) + per_iter_time_ms = total_time_ms / iterations_per_run + latencies.append(per_iter_time_ms) + + # Use median instead of average for better outlier handling + median_time_ms = np.median(latencies) + median_time_s = median_time_ms / 1000 + + # Calculate actual work done (using first dataset for consistency) + _, tokens_per_expert = data_sets[0] + actual_tokens = tokens_per_expert.sum().item() + actual_elements = actual_tokens * H + + # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops + ops_per_element = 8 + total_ops = 
actual_elements * ops_per_element + gflops = total_ops / median_time_s / 1e9 + + # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) + input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs + output_bytes = actual_tokens * H * 1 # H fp8 outputs + scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 + total_bytes = input_bytes + output_bytes + scale_bytes + memory_bw = total_bytes / median_time_s / 1e9 + + HOPPER_BANDWIDTH_TBPS = 3.35 + return ( + median_time_ms, + gflops, + memory_bw, + (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100, + ) + + +def create_comparison_plot( + ratios, silu_v2_times, triton_times, config_labels, strategy_name, id +): + fig, ax = plt.subplots(1, 1, figsize=(18, 6)) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.25 + + # Execution Time plot (lower is better) + ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue") + ax.bar( + x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green" + ) + + # Add speedup labels over each bar trio + for i in range(len(x)): + triton_v2_speedup = ratios[i][1] # triton/v2 + max_height = max(silu_v2_times[i], triton_times[i]) + + # Triton/V2 speedup + ax.text( + x[i] + width / 2, + max_height + max_height * 0.02, + f"{triton_v2_speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=8, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig, ax + + +def create_combined_plot(all_results): + num_strategies = len(all_results) + fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies)) + + if num_strategies == 1: + axes = [axes] + + for idx, ( + strategy_name, + all_ratios, + all_silu_v2_results, + all_triton_results, + config_labels, + config_x_axis, + ) in enumerate(all_results): + ax = axes[idx] + + # Flatten the nested results to get bandwidth percentages for plotting + silu_v2_bandwidths = [] + triton_bandwidths = [] + flat_ratios = [] + + for config_results in all_silu_v2_results: + for result in config_results: + silu_v2_bandwidths.append(result[3]) # bandwidth percentage + + for config_results in all_triton_results: + for result in config_results: + triton_bandwidths.append(result[3]) # bandwidth percentage + + for config_ratios in all_ratios: + for ratio in config_ratios: + flat_ratios.append(ratio) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.25 + + # Bandwidth utilization plot (higher is better) + ax.bar( + x, + silu_v2_bandwidths, + width, + label="SiLU V2 (CUDA)", + alpha=0.8, + color="blue", + ) + ax.bar( + x + width, + triton_bandwidths, + width, + label="Triton Kernel", + alpha=0.8, + color="green", + ) + + # Add speedup labels over each bar trio + for i in range(len(x)): + triton_v2_speedup = flat_ratios[i] # triton/v2 + max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i]) + + # Triton/V2 speedup + ax.text( + x[i] + width / 2, + max_height + max_height * 0.02, + f"{triton_v2_speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=8, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + 
ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + filename = "silu_benchmark_combined_3way.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +outer_dim = 7168 +configs = [ + # DeepSeekV3 Configs + # (1, 56, 7168), + (8, 1024, 7168), + # (32, 56, 7168), + # DeepSeekV3 Configs + (32, 1024, 7168), + # DeepSeekV3 Configs + (256, 1024, 7168), +] + +runs = 100 +num_warmups = 20 + +strategy_descriptions = { + "uniform": "Uniform Random", + "random_imbalanced": "Imbalanced Random", + "max_t": "Even Assignment", + "first_t": "experts[0] = T, experts[1:] = 0", +} + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"Testing strategies: {', '.join(strategies)}") +print(f"Configurations: {len(configs)} configs") + +all_results = [] + +# Run benchmarks for each strategy +for id, strategy in enumerate(strategies): + print(f"\n{'=' * 60}") + print(f"Testing strategy: {strategy_descriptions[strategy]}") + print(f"{'=' * 60}") + + # Collect benchmark data for all three algorithms + config_labels = [] + config_x_axis = [] + all_silu_v2_results = [] + all_triton_results = [] + all_ratios = [] + + for E, T, H in configs: + total_tokens_config = [] + for i in [8, 16, 32, 64, 128, 256, 512]: + if i <= T: + total_tokens_config.append(i * E) + config_x_axis.append(total_tokens_config) + + silu_v2_results = [] + triton_results = [] + ratios = [] + + for total_tokens in total_tokens_config: + config_label = f"E={E},T={T},H={H},TT={total_tokens}" + config_labels.append(config_label) + + # SiLU V2 (CUDA kernel) results + time_ms_silu_v2, gflops, gbps, perc = benchmark( + persistent_masked_m_silu_mul_quant, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc)) + + # Triton kernel results + time_ms_triton, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_triton, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + triton_results.append((time_ms_triton, gflops, gbps, perc)) + + # Calculate speedup ratios (triton baseline / implementation) + triton_v2_ratio = time_ms_triton / time_ms_silu_v2 + ratios.append(triton_v2_ratio) + + print( + f"Completed: {config_label}:" + f" V2: {time_ms_silu_v2:.3f}ms," + f" Triton: {time_ms_triton:.3f}ms" + ) + + all_silu_v2_results.append(silu_v2_results) + all_triton_results.append(triton_results) + all_ratios.append(ratios) + + # Store results for combined plotting + all_results.append( + ( + strategy_descriptions[strategy], + all_ratios, + all_silu_v2_results, + all_triton_results, + config_labels, + config_x_axis, + ) + ) + + # Print summary table for this strategy + print(f"\nSummary Table - {strategy_descriptions[strategy]}:") + print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}") + print("-" * 90) + + for i, (E, T, H) in enumerate(configs): + # Get the first result for each config (simplifying for summary) + v2_time = silu_v2_results[i][0] + triton_time = triton_results[i][0] + triton_v2_speedup = triton_time / v2_time + config_label = f"E={E:3d},T={T:4d},H={H:4d}" + print( + f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} " + f"{triton_v2_speedup:8.2f}x" + ) + + +def create_total_tokens_plot(all_results): + num_strategies = len(all_results) + num_configs = len(configs) + + fig, axs = plt.subplots( + num_strategies, num_configs * 2, figsize=(32, 8 * 
num_strategies) + ) + + # Add main title to the entire figure + fig.suptitle( + "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)", + fontsize=18, + fontweight="bold", + y=0.98, + ) + + # Handle single strategy case + if num_strategies == 1: + axs = axs.reshape(1, -1) + + # Handle single config case + if num_configs == 1: + axs = axs.reshape(-1, 2) + + for strategy_idx, result in enumerate(all_results): + ( + strategy_name, + all_ratios, + all_silu_v2_results, + all_triton_results, + config_labels, + config_x_axis, + ) = result + + for config_idx in range(num_configs): + # Speedup plot (left column) + ax_speedup = axs[strategy_idx, config_idx * 2] + # Bandwidth plot (right column) + ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1] + + E, T, H = configs[config_idx] + ratios = all_ratios[config_idx] + total_tokens_values = config_x_axis[config_idx] + + # Extract speedup ratios + triton_v2_ratios = [ratio for ratio in ratios] + + # Extract bandwidth percentages for all implementations + v2_bandwidth_percentages = [ + result[3] for result in all_silu_v2_results[config_idx] + ] + triton_bandwidth_percentages = [ + result[3] for result in all_triton_results[config_idx] + ] + + # Plot speedup ratios vs total tokens (left plot) + ax_speedup.plot( + total_tokens_values, + triton_v2_ratios, + "go-", + linewidth=3, + markersize=8, + label="Triton/V2 Speedup", + ) + ax_speedup.set_title( + f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.legend(prop={"weight": "bold"}) + ax_speedup.grid(True, alpha=0.3) + + # Plot bandwidth utilization (right plot) + ax_bandwidth.plot( + total_tokens_values, + v2_bandwidth_percentages, + "o-", + linewidth=3, + markersize=8, + label="SiLU V2", + color="blue", + ) + ax_bandwidth.plot( + total_tokens_values, + triton_bandwidth_percentages, + "o-", + linewidth=3, + markersize=8, + label="Triton", + color="green", + ) + ax_bandwidth.set_title( + f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_bandwidth.set_ylabel( + "% of Peak Bandwidth", fontweight="bold", fontsize=11 + ) + ax_bandwidth.legend(prop={"weight": "bold"}) + ax_bandwidth.grid(True, alpha=0.3) + + # Format x-axis labels for both plots + for ax in [ax_speedup, ax_bandwidth]: + ax.set_xticks(total_tokens_values) + ax.set_xticklabels( + [ + f"{tt // 1000}K" if tt >= 1000 else str(tt) + for tt in total_tokens_values + ], + fontweight="bold", + ) + # Make tick labels bold + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight("bold") + + # Add value labels on Triton/V2 speedup points + for x, y in zip(total_tokens_values, triton_v2_ratios): + ax_speedup.annotate( + f"{y:.2f}x", + (x, y), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3), + ) + + plt.tight_layout() + plt.subplots_adjust(top=0.93) # Make room for main title + filename = "silu_benchmark_total_tokens_3way.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +# Create comprehensive 3-way comparison plots +combined_plot_filename = create_combined_plot(all_results) 
+total_tokens_plot_filename = create_total_tokens_plot(all_results) + +print(f"\n{'=' * 80}") +print("3-Way Benchmark Suite Complete!") +print(f"Generated combined comparison plot: {combined_plot_filename}") +print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}") +print("Compared: SiLU V2 (CUDA), and Triton implementations") +print(f"{'=' * 80}") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0d6fbb9a470582773c0eb6fc605a210e180cfc --- /dev/null +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import os +from datetime import datetime + +import flashinfer +import torch + +from vllm.utils.math_utils import round_up + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@torch.no_grad() +def benchmark_decode( + dtype: torch.dtype, + quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, +): + torch.set_default_device("cuda") + torch.manual_seed(0) + + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) + + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / block_size) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query + + kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_seq_len + + seq_lens = kv_lens + max_seq_len = torch.max(seq_lens).item() + + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + 
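+ # Build FlashInfer's CSR-style paged-KV metadata: kv_indptr holds each sequence's offset into kv_indices (its list of cache blocks) and kv_last_page_lens records how full its final block is, e.g. 35 tokens with block_size=16 -> 3 blocks, kv_last_page_len=3.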
kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) + + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=True, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) + + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.Event(enable_timing=True) + end = torch.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + + def baseline_decode(): + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) + + def trtllm_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) + + baseline_mean, baseline_std = time_fn(baseline_decode) + trtllm_mean, trtllm_std = time_fn(trtllm_decode) + + # Calculate percentage speedup (positive means TRT is faster) + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean + + print( + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" + f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" + ) + + # Return results for CSV writing + return { + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), + "baseline_mean": baseline_mean, + "baseline_std": baseline_std.item(), + "speedup_percent": speedup_percent, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, + "num_kv_heads": num_kv_heads, + "head_size": head_size, + "max_seq_len": max_seq_len, + } + + +def write_results_to_csv(results, filename=None): + """Write benchmark results to CSV file.""" + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" + + fieldnames = [ + "batch_size", + "trtllm_mean", + "trtllm_std", + "baseline_mean", + "baseline_std", + "speedup_percent", + "q_dtype", + "kv_cache_dtype", + "output_dtype", + 
"block_size", + "num_kv_heads", + "head_size", + "max_seq_len", + ] + + file_exists = os.path.exists(filename) + + with open(filename, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + for result in results: + writer.writerow(result) + + print(f"Results written to {filename}") + + +if __name__ == "__main__": + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + all_results = [] + + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_decode( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) + + # Write all results to CSV + write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..84bde723abf7fa02090c783296092540571845da --- /dev/null +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import os +from datetime import datetime + +import flashinfer +import torch + +from vllm.utils.math_utils import round_up + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@torch.no_grad() +def benchmark_prefill( + dtype: torch.dtype, + quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], + batch_size: int, + max_seq_len: int, + num_heads: tuple[int, int] = (64, 8), + head_size: int = 128, + kv_layout: str = "HND", + block_size: int = 16, + warmup: int = 10, + trials: int = 20, +): + torch.set_default_device("cuda") + torch.manual_seed(0) + + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + max_q_len = max_kv_len = max_seq_len + + num_qo_heads, num_kv_heads = num_heads + assert num_qo_heads % num_kv_heads == 0 + + sm_scale = float(1.0 / (head_size**0.5)) + + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / block_size) + + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) 
+ elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) + q_lens[-1] = max_q_len + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ] + ) + + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn( + torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype + ) + if q_quant_dtype == FP8_DTYPE: + query, _ = to_float8(ref_query) + else: + query = ref_query + + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens[-1] = max_kv_len + + seq_lens = kv_lens + q_lens + max_seq_len = torch.max(seq_lens).item() + + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + if kv_quant_dtype == FP8_DTYPE: + kv_cache, _ = to_float8(ref_kv_cache) + else: + kv_cache = ref_kv_cache + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(batch_size): + seq_len = seq_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) + + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + causal=True, + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) + + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.Event(enable_timing=True) + end = torch.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + o_scale = 1.0 + o_sf_scale = None + output_baseline = torch.empty(ref_query.shape, dtype=dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + + def baseline_prefill(): + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) + + def trtllm_prefill(): + return flashinfer.prefill.trtllm_batch_context_with_kv_cache( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + 
seq_lens=seq_lens, + max_q_len=max_q_len, + max_kv_len=max_seq_len, + bmm1_scale=q_scale * k_scale * sm_scale, + bmm2_scale=v_scale / o_scale, + batch_size=batch_size, + cum_seq_lens_q=q_indptr, + cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, + out=output_trtllm, + ) + + baseline_mean, baseline_std = time_fn(baseline_prefill) + trtllm_mean, trtllm_std = time_fn(trtllm_prefill) + + # Calculate percentage speedup (positive means TRT is faster) + speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean + + print( + f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" + f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" + ) + + # Return results for CSV writing + return { + "batch_size": batch_size, + "trtllm_mean": trtllm_mean, + "trtllm_std": trtllm_std.item(), + "baseline_mean": baseline_mean, + "baseline_std": baseline_std.item(), + "speedup_percent": speedup_percent, + "q_dtype": str(q_quant_dtype), + "kv_cache_dtype": str(kv_quant_dtype), + "output_dtype": str(o_quant_dtype), + "block_size": block_size, + "num_kv_heads": num_kv_heads, + "head_size": head_size, + "max_seq_len": max_seq_len, + } + + +def write_results_to_csv(results, filename=None): + """Write benchmark results to CSV file.""" + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" + + fieldnames = [ + "batch_size", + "trtllm_mean", + "trtllm_std", + "baseline_mean", + "baseline_std", + "speedup_percent", + "q_dtype", + "kv_cache_dtype", + "output_dtype", + "block_size", + "num_kv_heads", + "head_size", + "max_seq_len", + ] + + file_exists = os.path.exists(filename) + + with open(filename, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + for result in results: + writer.writerow(result) + + print(f"Results written to {filename}") + + +if __name__ == "__main__": + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + all_results = [] + + dtype = torch.bfloat16 + quant_dtypes = [ + # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) + (None, None, None), + (FP8_DTYPE, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), + ] + + for quant_dtype in quant_dtypes: + q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype + q_quant_dtype = q_quant_dtype or dtype + kv_quant_dtype = kv_quant_dtype or dtype + o_quant_dtype = o_quant_dtype or dtype + + print( + f"Running benchmark for q_dtype = {q_quant_dtype}, " + f"kv_cache_dtype: {kv_quant_dtype}, " + f"output_dtype: {o_quant_dtype}" + ) + print( + "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" + "baseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in batch_sizes: + result = benchmark_prefill( + dtype=dtype, + quant_dtypes=quant_dtype, + batch_size=bs, + max_seq_len=max_seq_len, + ) + all_results.append(result) + + # Write all results to CSV + write_results_to_csv(all_results) diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..3a85c5c74d6932ab4403a04bb7a546a49e79314e --- /dev/null +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -0,0 +1,415 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from 
sglang quantization/tuning_block_wise_kernel.py + +import argparse +import json +import multiprocessing as mp +import os +import time +from datetime import datetime +from typing import Any + +import torch +from tqdm import tqdm + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _w8a8_triton_block_scaled_mm, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +mp.set_start_method("spawn", force=True) + +assert current_platform.is_cuda() or current_platform.is_rocm(), ( + "Only support tune w8a8 block fp8 kernel on CUDA/ROCm device." +) + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def w8a8_block_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + config: dict[str, Any], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with + block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. + It should be 2-dim, e.g., [128, 128]. + output_dtype: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + if A.dtype == torch.float8_e4m3fn: + kernel = _w8a8_triton_block_scaled_mm + else: + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") + + kernel[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_weight_shapes(tp_size): + # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. + # Modify them, if you tune for another different model. 
+ # cannot TP + total = [ + (512 + 64, 7168), + (2112, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + ] + # N can TP + n_tp = [ + (18432 * 2, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (24576, 1536), + (12288, 7168), + (4096, 7168), + ] + # K can TP + k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + for n_t in n_tp: + new_t = (n_t[0] // tp_size, n_t[1]) + weight_shapes.append(new_t) + for k_t in k_tp: + new_t = (k_t[0], k_t[1] // tp_size) + weight_shapes.append(new_t) + return weight_shapes + + +def benchmark_config( + A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 +): + def run(): + w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) + + torch.cuda.synchronize() + # JIT complication & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, block_size, out_dtype, search_space, input_type): + factor_for_scale = 1e-2 + + if input_type == "fp8": + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) + A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) + B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + else: + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + * factor_for_scale + ) + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + A, + B, + As, + Bs, + block_size, + config, + out_dtype, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. 
+ continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + block_n, + block_k, + configs, + save_path, + input_type="fp8", +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = current_platform.get_device_name().replace(" ", "_") + json_file_name = ( + f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," + f"block_shape=[{block_n},{block_k}].json" + ) + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu(args_dict): + """Run tuning on a specific GPU.""" + gpu_id = args_dict["gpu_id"] + batch_sizes = args_dict["batch_sizes"] + weight_shapes = args_dict["weight_shapes"] + args = args_dict["args"] + + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + block_n = args.block_n + block_k = args.block_k + out_dtype = DTYPE_MAP[args.out_dtype] + save_path = args.save_path + input_type = args.input_type + + search_space = get_configs_compute_bound() + search_space = [ + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 + ] + + start = time.time() + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + [block_n, block_k], + out_dtype, + search_space, + input_type, + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} + save_configs(N, K, block_n, block_k, best_configs, save_path, input_type) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def distribute_batch_sizes(batch_sizes, num_gpus): + """Distribute batch sizes across available GPUs.""" + batches_per_gpu = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) + + process_args = [] + for gpu_id in range(num_gpus): + process_args.append( + { + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": weight_shapes, # Each GPU processes all weight shapes + "args": args, + } + ) + + ctx = mp.get_context("spawn") + with ctx.Pool(num_gpus) as pool: + pool.map(tune_on_gpu, process_args) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description=""" +Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1: + 
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 +Then copy to model_executor/layers/quantization/utils/configs + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=8) + parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8") + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="float16", + ) + parser.add_argument("--block-n", type=int, default=128) + parser.add_argument("--block-k", type=int, default=128) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--save-path", type=str, default="./") + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/cpu/benchmark_cpu_attn.py b/benchmarks/kernels/cpu/benchmark_cpu_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d03b70a9f5034ab74efbfebda83d2f7e31bb4874 --- /dev/null +++ b/benchmarks/kernels/cpu/benchmark_cpu_attn.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +import time + +import numpy as np +import torch + +from vllm._custom_ops import ( + cpu_attention_with_kv_cache, + cpu_attn_get_scheduler_metadata, + cpu_attn_reshape_and_cache, +) +from vllm.platforms import CpuArchEnum, current_platform +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed +from vllm.v1.attention.backends.cpu_attn import CPUAttentionBackend, _get_attn_isa + + +def get_attn_isa( + block_size: int | None = None, + dtype: torch.dtype | None = None, +): + if block_size and dtype: + return _get_attn_isa(dtype, block_size) + else: + if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: + return "neon" + elif torch._C._cpu._is_amx_tile_supported(): + return "amx" + else: + return "vec" + + +# rand number generation takes too much time, cache rand tensors +@functools.lru_cache(maxsize=128, typed=False) +def tensor_cache( + elem_num: int, + dtype: torch.dtype, +) -> torch.Tensor: + tensor = torch.randn(elem_num, dtype=dtype) + return tensor + + +@torch.inference_mode() +def main( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: int = None, + dtype: torch.dtype = torch.bfloat16, + block_size: int = 128, + num_blocks: int = 4096, + use_sink: bool = False, + enable_kv_split: bool = False, + isa: str | None = None, + seed: int = 0, + iters: int = 20, +) -> None: + set_random_seed(seed) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) + scale = head_size**-0.5 + token_num = sum(query_lens) + + if isa is None: + isa = get_attn_isa(block_size, dtype) + + s_aux = ( + 15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None + ) + + query = tensor_cache( + elem_num=token_num * num_query_heads * head_size, + dtype=dtype, + ) + query = query.view( + token_num, + num_query_heads, + head_size, + ) + + key_value = tensor_cache( + elem_num=2 * num_blocks * num_kv_heads * block_size * head_size, + dtype=dtype, + ) + key_value = key_value.view( + 2, + num_blocks, + block_size, + num_kv_heads, + head_size, + 
) + key_cache, value_cache = key_value.unbind(0) + + # KV cache for CPU attention + packed_key_cache = torch.empty( + num_blocks, num_kv_heads, block_size, head_size, dtype=dtype + ) + packed_value_cache = torch.empty_like(packed_key_cache) + + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + # use reshape_and_cache to pack key_cache and value_cache + slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64) + cpu_attn_reshape_and_cache( + key=key_cache.view(-1, num_kv_heads, head_size), + value=value_cache.view(-1, num_kv_heads, head_size), + key_cache=packed_key_cache, + value_cache=packed_value_cache, + slot_mapping=slot_mapping, + isa=isa, + ) + + metadata = cpu_attn_get_scheduler_metadata( + num_reqs=num_seqs, + num_heads=num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + seq_lens=kv_lens_tensor, + dtype=dtype, + query_start_loc=cu_query_lens, + causal=True, + sliding_window_size=sliding_window if sliding_window is not None else -1, + isa=isa, + enable_kv_split=enable_kv_split, + ) + + out_with_split = torch.empty_like(query) + + def run_benchmark(iters: int) -> list[float]: + times = [] + for _ in range(iters): + start_time = time.perf_counter_ns() + cpu_attention_with_kv_cache( + query=query, + key_cache=packed_key_cache, + value_cache=packed_value_cache, + output=out_with_split, + query_start_loc=cu_query_lens, + seq_lens=kv_lens_tensor, + scale=scale, + causal=True, + alibi_slopes=None, + sliding_window=window_size, + block_table=block_tables, + softcap=0, + scheduler_metadata=metadata, + s_aux=s_aux, + ) + end_time = time.perf_counter_ns() + times.append((end_time - start_time) / 1e6) + return times + + # warmup + run_benchmark(5) + # benchmark + times = run_benchmark(iters) + + time_min = min(times) + time_max = max(times) + time_mean = np.mean(times) + time_std = np.std(times) + + print("\tmin (ms) = ", time_min) + print("\tmax (ms) = ", time_max) + print("\tmean (ms) = ", time_mean) + print("\tstd = ", time_std) + print("\tmedian (ms) = ", np.median(times)) + + +def generate_seq_lens( + batch_size: int, + q_len_min: int, + q_len_max: int, + kv_len_min: int, + kv_len_max: int, + seed: int = 0, +) -> list[tuple[int, int]]: + assert 1 <= q_len_min <= q_len_max + assert 1 <= kv_len_min <= kv_len_max + assert kv_len_max >= q_len_min + + g = torch.Generator(device="cpu").manual_seed(seed) + + def rint(lo: int, hi: int) -> int: + return torch.randint(lo, hi + 1, (1,), generator=g).item() + + seq_lens: list[tuple[int, int]] = [] + for _ in range(batch_size): + # ensure q <= kv + kv = rint(max(kv_len_min, q_len_min), kv_len_max) + q = rint(q_len_min, min(q_len_max, kv)) + seq_lens.append((q, kv)) + + return seq_lens + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--q-len-min", type=int, default=512) + parser.add_argument("--q-len-max", type=int, default=512) + parser.add_argument("--kv-len-min", type=int, default=512) + parser.add_argument("--kv-len-max", type=int, default=512) + parser.add_argument("--num-blocks", type=int, default=4096) + + parser.add_argument("--sliding-window", type=int, default=None) + 
parser.add_argument("--num-query-heads", type=int, default=32) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument( + "--head-size", + type=int, + choices=CPUAttentionBackend.get_supported_head_sizes(), + default=128, + ) + parser.add_argument("--enable-kv-split", action="store_true") + parser.add_argument("--block-size", type=int, choices=[32, 64, 128], default=128) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + parser.add_argument("--use-sink", action="store_true") + parser.add_argument( + "--isa", type=str, choices=["vec", "neon", "amx", "vec16"], default=None + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--iters", type=int, default=20) + + args = parser.parse_args() + print(args) + + seq_lens = generate_seq_lens( + args.batch_size, + args.q_len_min, + args.q_len_max, + args.kv_len_min, + args.kv_len_max, + args.seed, + ) + + print("batch (query len, kv len) = ", seq_lens) + + main( + seq_lens=seq_lens, + num_heads=(args.num_query_heads, args.num_kv_heads), + head_size=args.head_size, + sliding_window=args.sliding_window, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + block_size=args.block_size, + num_blocks=args.num_blocks, + use_sink=args.use_sink, + enable_kv_split=args.enable_kv_split, + isa=args.isa + if args.isa is not None + else get_attn_isa(args.block_size, STR_DTYPE_TO_TORCH_DTYPE[args.dtype]), + seed=args.seed, + iters=args.iters, + ) diff --git a/benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py b/benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..df6a9c60a7e06732e924574ef3d6382b4b52ec2a --- /dev/null +++ b/benchmarks/kernels/cpu/benchmark_cpu_fused_moe.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import sys +import time + +import numpy as np +import torch + +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.torch_utils import set_random_seed + +# Check if CPU MoE operations are available +try: + from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight +except (ImportError, AttributeError) as e: + print("ERROR: CPU fused MoE operations are not available on this platform.") + print("This benchmark requires x86 CPU with proper vLLM CPU extensions compiled.") + print( + "The cpu_fused_moe kernel is typically available on Linux x86_64 " + "with AVX2/AVX512." 
+ ) + print(f"Import error: {e}") + sys.exit(1) + +# ISA selection following test_cpu_fused_moe.py pattern +ISA_CHOICES = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"] + + +@torch.inference_mode() +def main( + batch_size: int, + expert_num: int, + hidden_size: int, + intermediate_size: int, + topk_num: int, + use_bias: bool = False, + dtype: torch.dtype = torch.bfloat16, + activation: str = "silu", + isa: str = "vec", + seed: int = 0, + iters: int = 20, +) -> None: + set_random_seed(seed) + # up_dim = 2 * intermediate_size for gate + up projection + up_dim = 2 * intermediate_size + + input_tensor = torch.randn((batch_size, hidden_size), dtype=dtype) / ( + 0.5 * hidden_size**0.5 + ) + + w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / ( + 0.5 * hidden_size**0.5 + ) + w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / ( + 0.5 * intermediate_size**0.5 + ) + + w13_bias = None + w2_bias = None + if use_bias: + w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5) + w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / ( + 0.5 * hidden_size**0.5 + ) + + router_logits = torch.randn((batch_size, expert_num), dtype=dtype) + score = torch.softmax(router_logits, dim=-1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(score, topk_num) + topk_ids = topk_ids.to(torch.int32) + + packed_w13 = cpu_prepack_moe_weight(w13, isa) + packed_w2 = cpu_prepack_moe_weight(w2, isa) + + def run_benchmark(iters: int) -> list[float]: + times = [] + for _ in range(iters): + start_time = time.perf_counter_ns() + _ = cpu_fused_moe( + input_tensor, + packed_w13, + packed_w2, + w13_bias, + w2_bias, + topk_weights, + topk_ids, + activation, + isa, + ) + end_time = time.perf_counter_ns() + times.append((end_time - start_time) / 1e6) + return times + + # warmup + run_benchmark(5) + # benchmark + times = run_benchmark(iters) + + if not times: + print("No iterations to measure. 
Set --iters > 0.") + return + + time_min = min(times) + time_max = max(times) + time_mean = np.mean(times) + time_std = np.std(times) + + print("\tmin (ms) = ", time_min) + print("\tmax (ms) = ", time_max) + print("\tmean (ms) = ", time_mean) + print("\tstd = ", time_std) + print("\tmedian (ms) = ", np.median(times)) + + # Calculate throughput metrics + # FLOPs estimation: 2 * batch * topk * (hidden * up_dim + intermediate * hidden) + flops_per_token = ( + 2 * topk_num * (hidden_size * up_dim + intermediate_size * hidden_size) + ) + total_flops = batch_size * flops_per_token + tflops = total_flops / (time_mean * 1e-3) / 1e12 + print(f"\tthroughput (TFLOP/s) = {tflops:.4f}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the CPU fused MoE kernel.") + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--expert-num", type=int, default=8) + parser.add_argument("--hidden-size", type=int, default=2880) + parser.add_argument("--intermediate-size", type=int, default=2880) + parser.add_argument( + "--topk-num", + type=int, + default=None, + help="Number of experts to route each token to (default: expert_num // 2)", + ) + parser.add_argument("--use-bias", action="store_true") + parser.add_argument( + "--activation", + type=str, + choices=["silu", "swigluoai"], + default="silu", + help="Activation function", + ) + parser.add_argument( + "--isa", + type=str, + choices=ISA_CHOICES, + default=ISA_CHOICES[0], + help=f"ISA to use (available: {ISA_CHOICES})", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--iters", type=int, default=20) + + args = parser.parse_args() + + # Default topk_num to expert_num // 2, minimum 1 + topk_num = ( + args.topk_num if args.topk_num is not None else max(args.expert_num // 2, 1) + ) + + print(args) + + main( + batch_size=args.batch_size, + expert_num=args.expert_num, + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + topk_num=topk_num, + use_bias=args.use_bias, + dtype=torch.bfloat16, # Following test_cpu_fused_moe.py + activation=args.activation, + isa=args.isa, + seed=args.seed, + iters=args.iters, + ) diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a28c6956be0e95e992a8da83f4dad577860fdc1a --- /dev/null +++ b/benchmarks/kernels/deepgemm/README.md @@ -0,0 +1,129 @@ +# DeepSeek DeepGEMM Kernels Benchmark + +This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels. + +Currently, this just includes dense GEMMs and only works on Hopper GPUs. + +## Setup + +You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory: + +```bash +git clone --recursive https://github.com/deepseek-ai/DeepGEMM +cd DeepGEMM +python setup.py install +uv pip install -e . +``` + +## Usage + +```console +python benchmark_fp8_block_dense_gemm.py +INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda. +===== STARTING FP8 GEMM BENCHMARK ===== +PyTorch version: 2.5.1+cu124 +CUDA version: 12.4 +Triton version: 3.1.0 +Using device: NVIDIA H100 80GB HBM3 +WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! 
Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. +WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. +INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. + +===== PERFORMANCE COMPARISON ===== + +DeepGEMM Implementation: ++------+-------+-------+-----------+--------+--------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | ++------+-------+-------+-----------+--------+--------+ +| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 | +| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 | +| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 | +| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 | +| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 | +| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 | +| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 | +| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 | +| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 | +| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 | +| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 | +| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 | +| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 | +| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 | +| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 | +| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 | ++------+-------+-------+-----------+--------+--------+ + +vLLM Triton Implementation: ++------+-------+-------+-----------+--------+--------+--------------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | ++------+-------+-------+-----------+--------+--------+--------------+ +| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster | +| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower | +| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower | +| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower | +| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower | +| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower | +| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster | +| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster | +| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower | +| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster | +| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower | +| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 
| 0.30x slower | +| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower | +| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower | +| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower | +| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower | ++------+-------+-------+-----------+--------+--------+--------------+ + +vLLM CUTLASS Implementation: ++------+-------+-------+-----------+--------+--------+--------------+--------------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton | ++------+-------+-------+-----------+--------+--------+--------------+--------------+ +| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster | +| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster | +| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster | +| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster | +| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster | +| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster | +| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster | +| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster | +| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster | +| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster | +| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster | +| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster | +| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster | +| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster | +| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster | +| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster | ++------+-------+-------+-----------+--------+--------+--------------+--------------+ + +===== AVERAGE PERFORMANCE ===== ++----------------+------------+----------+---------------+ +| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) | ++----------------+------------+----------+---------------+ +| DeepGEMM | 310.98 | 1052.10 | 0.11 | +| vLLM Triton | 144.30 | 715.60 | 0.23 | +| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 | ++----------------+------------+----------+---------------+ + +===== AVERAGE SPEEDUPS ===== ++-----------------------------+--------------+ +| Comparison | Speedup | ++-----------------------------+--------------+ +| DeepGEMM vs vLLM Triton | 1.71x faster | +| DeepGEMM vs vLLM CUTLASS | 0.94x slower | +| vLLM CUTLASS vs vLLM Triton | 1.84x faster | ++-----------------------------+--------------+ + +===== ACCURACY COMPARISON ===== ++----------------+-----------------------+ +| Implementation | Avg Diff vs Reference | ++----------------+-----------------------+ +| DeepGEMM | 0.000684 | +| vLLM Triton | 0.000684 | +| vLLM CUTLASS | 0.000684 | ++----------------+-----------------------+ +``` diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5a85526a151e56e680e95fc1d8599c4a335002cd --- /dev/null +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +import time + +import torch + +from vllm import _custom_ops as ops +from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, + w8a8_triton_block_scaled_mm, +) +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import ( + calc_diff, + fp8_gemm_nt, + per_block_cast_to_fp8, +) + + +def benchmark_shape( + m: int, + n: int, + k: int, + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False, +) -> dict: + """Benchmark all implementations for a specific (m, n, k) shape.""" + if verbose: + print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") + + # Create test tensors + A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + + # Reference result in BF16 + torch.cuda.synchronize() + C_ref = A @ B.t() + + # Pre-quantize B for all implementations + # (weights can be pre-quantized offline) + B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True) + B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True) + + # Block size configuration + block_size = [128, 128] + + # Pre-quantize A for all implementations + A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( + A, block_size[1], column_major_scales=True, tma_aligned_scales=True + ) + C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) + A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + A, block_size[1], column_major_scales=True + ) + + # === DeepGEMM Implementation === + def deepgemm_gemm(): + fp8_gemm_nt( + (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm + ) + return C_deepgemm + + # === vLLM Triton Implementation === + def vllm_triton_gemm(): + return w8a8_triton_block_scaled_mm( + A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16, + ) + + # === vLLM CUTLASS Implementation === + def vllm_cutlass_gemm(): + return ops.cutlass_scaled_mm( + A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16, + ) + + # Run correctness check first + if verbose: + print("Running correctness check...") + C_deepgemm = deepgemm_gemm() + C_vllm_triton = vllm_triton_gemm() + C_vllm_cutlass = vllm_cutlass_gemm() + + deepgemm_diff = calc_diff(C_deepgemm, C_ref) + vllm_triton_diff = calc_diff(C_vllm_triton, C_ref) + vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref) + + if verbose: + print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") + print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") + print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") + print( + "vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}" + ) + print( + "vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}" + ) + + # Benchmark implementations + implementations = { + "DeepGEMM": deepgemm_gemm, + "vLLM Triton": vllm_triton_gemm, + "vLLM CUTLASS": vllm_cutlass_gemm, + } + + benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}} + + for name, func in implementations.items(): + # Warmup + for _ in range(warmup): + func() + torch.cuda.synchronize() + + # Timing loop + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + func() + torch.cuda.synchronize() + end = time.time() + + # Calculate timing and TFLOPS + avg_time_ms = (end - start) / repeat * 1000 + avg_time_us = avg_time_ms * 1000 + tflops = 2 * m * n 
* k / (avg_time_ms * 1e-3) / 1e12 + gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3) + + benchmark_results["implementations"][name] = { + "time_ms": avg_time_ms, + "time_us": avg_time_us, + "tflops": tflops, + "gb_s": gb_s, + "diff": { + "DeepGEMM": 0.0 + if name == "DeepGEMM" + else calc_diff(func(), C_deepgemm), + "Reference": deepgemm_diff + if name == "DeepGEMM" + else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff), + }, + } + + if verbose: + print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s") + + # Calculate speedups + baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] + for name, data in benchmark_results["implementations"].items(): + if name != "DeepGEMM": + speedup = baseline / data["time_ms"] + benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup + if verbose: + print( + f"DeepGEMM is {1 / speedup:.2f}x " + f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}" + ) + + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"] + cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time + benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = ( + cutlass_vs_triton + ) + if verbose: + print( + f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " + f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton" + ) + + return benchmark_results + + +def format_table_row(values, widths): + """Format a row with specified column widths.""" + return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |" + + +def print_table(headers, rows, title=None): + """Print a table with headers and rows.""" + if title: + print(f"\n{title}") + + # Calculate column widths based on headers and data + widths = [ + max(len(str(h)), max(len(str(row[i])) for row in rows)) + for i, h in enumerate(headers) + ] + + # Create separator line + separator = "+-" + "-+-".join("-" * w for w in widths) + "-+" + + # Print table + print(separator) + print(format_table_row(headers, widths)) + print(separator) + for row in rows: + print(format_table_row(row, widths)) + print(separator) + + +def format_speedup(value): + """Format speedup value with indicator if it's faster or slower.""" + return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}" + + +def run_benchmarks(verbose: bool = False): + """Run benchmarks for a set of common shapes.""" + print("===== STARTING FP8 GEMM BENCHMARK =====") + + # Make sure we're using the GPU + if not torch.cuda.is_available(): + print("CUDA not available! 
Tests require GPU.") + return + + # Print system information + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + print(f"Triton version: {triton.__version__}") + print(f"Using device: {torch.cuda.get_device_name()}") + + # Enable TF32 for better performance + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Set seeds for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Define benchmark shapes (m, n, k) + shapes = [ + (8, 4096, 7168), + (8, 7168, 18432), + (8, 18432, 7168), + (64, 4096, 7168), + (64, 7168, 18432), + (64, 18432, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (128, 4096, 7168), + (128, 7168, 18432), + (128, 18432, 7168), + (1024, 4096, 7168), + (1024, 18432, 7168), + (2048, 4096, 7168), + (4096, 4096, 7168), + ] + shapes = [ + # (64, 2112, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (64, 4096, 7168), + (64, 7168, 2048), + # (128, 2112, 7168), + (128, 24576, 1536), + (128, 32768, 512), + (128, 7168, 16384), + (128, 4096, 7168), + (128, 7168, 2048), + # (4096, 2112, 7168), + (4096, 24576, 1536), + (4096, 32768, 512), + (4096, 7168, 16384), + (4096, 4096, 7168), + (4096, 7168, 2048), + ] + + all_results = [] + for m, n, k in shapes: + result = benchmark_shape(m, n, k, verbose=verbose) + all_results.append(result) + + # Print results in a nicely formatted table + print("\n===== PERFORMANCE COMPARISON =====") + + # Print DeepGEMM table + deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"] + deepgemm_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["DeepGEMM"] + deepgemm_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + ] + ) + + print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:") + + # Print vLLM Triton table + triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"] + triton_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM Triton"] + speedup = impl_data.get("speedup_vs_deepgemm", 1.0) + triton_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(speedup), + ] + ) + + print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:") + + # Print vLLM CUTLASS table + cutlass_headers = [ + "m", + "n", + "k", + "Time (μs)", + "TFLOPS", + "GB/s", + "vs DeepGEMM", + "vs Triton", + ] + cutlass_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM CUTLASS"] + vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) + vs_triton = impl_data.get("speedup_vs_triton", 1.0) + cutlass_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton), + ] + ) + + print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:") + + # Calculate and print averages + print("\n===== AVERAGE PERFORMANCE =====") + + implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] + avg_metrics = { + impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations + } + + for result in all_results: + for impl in 
implementations: + impl_data = result["implementations"][impl] + avg_metrics[impl]["tflops"] += impl_data["tflops"] + avg_metrics[impl]["gb_s"] += impl_data["gb_s"] + avg_metrics[impl]["time_ms"] += impl_data["time_ms"] + + num_shapes = len(all_results) + avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"] + avg_rows = [] + + for impl in implementations: + avg_tflops = avg_metrics[impl]["tflops"] / num_shapes + avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes + avg_time = avg_metrics[impl]["time_ms"] / num_shapes + avg_rows.append( + [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"] + ) + + print_table(avg_headers, avg_rows) + + # Calculate average speedups + avg_speedups = { + "DeepGEMM vs vLLM Triton": 0, + "DeepGEMM vs vLLM CUTLASS": 0, + "vLLM CUTLASS vs vLLM Triton": 0, + } + + for result in all_results: + deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] + vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"] + + avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups["vLLM CUTLASS vs vLLM Triton"] += ( + vllm_triton_time / vllm_cutlass_time + ) + + print("\n===== AVERAGE SPEEDUPS =====") + speedup_headers = ["Comparison", "Speedup"] + speedup_rows = [] + for comparison, total in avg_speedups.items(): + avg_speedup = total / num_shapes + status = "faster" if avg_speedup > 1 else "slower" + speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"]) + + print_table(speedup_headers, speedup_rows) + + # Average accuracy comparison + print("\n===== ACCURACY COMPARISON =====") + avg_diff = {impl: 0 for impl in implementations} + + for result in all_results: + for impl in implementations: + avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"] + + diff_headers = ["Implementation", "Avg Diff vs Reference"] + diff_rows = [] + for impl in implementations: + diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"]) + + print_table(diff_headers, diff_rows) + + +if __name__ == "__main__": + run_benchmarks(verbose=False) diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..6964a3d3e0824d6ec93d6dff012b79cc56f7433e --- /dev/null +++ b/benchmarks/kernels/graph_machete_bench.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +import pickle +from collections import defaultdict + +import matplotlib.pyplot as plt +import pandas as pd +import regex as re +import seaborn as sns +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.utils.argparse_utils import FlexibleArgumentParser + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the latency of processing a single batch of " + "requests till completion." 
+ ) + parser.add_argument("filename", type=str) + + args = parser.parse_args() + + with open(args.filename, "rb") as f: + data = pickle.load(f) + raw_results: list[TMeasurement] = data["results"] + + results = defaultdict(lambda: list()) + for v in raw_results: + result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) + if result is not None: + KN = result.group(1) + else: + raise Exception("MKN not found") + result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) + if result is not None: + M = result.group(1) + else: + raise Exception("MKN not found") + + kernel = v.task_spec.description + results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median}) + + rows = int(math.ceil(len(results) / 2)) + fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) + axs = axs.flatten() + for axs_idx, (shape, data) in enumerate(results.items()): + plt.sca(axs[axs_idx]) + df = pd.DataFrame(data) + sns.lineplot( + data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2", + ) + plt.title(f"Shape: {shape}") + plt.ylabel("time (median, s)") + plt.tight_layout() + plt.savefig("graph_machete_bench.pdf") diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1411a4a0b5ab886adfb744e685d150151ab10023 --- /dev/null +++ b/benchmarks/kernels/requirements.txt @@ -0,0 +1 @@ +pandas \ No newline at end of file diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a9af811bbe9ca9e0b0f66c493f27bcc890dc3515 --- /dev/null +++ b/benchmarks/kernels/utils.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from collections.abc import Callable, Iterable +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + + +@dataclasses.dataclass +class CudaGraphBenchParams: + num_ops_in_cuda_graph: int + + +@dataclasses.dataclass +class ArgPool: + """ + When some argument of the benchmarking function is annotated with this type, + the benchmarking class (BenchMM) will collapse the argument to a pick a + single value from the given list of values, during function invocation. + For every invocation during a benchmarking run, it will choose a + different value from the list. 
+ """ + + values: Iterable[Any] + + def __getitem__(self, index): + return self.values[index] + + +class Bench: + class ArgsIterator: + def __init__(self, args_list, kwargs_list): + assert len(args_list) == len(kwargs_list) + self.args_list = args_list + self.kwargs_list = kwargs_list + self.n = len(self.args_list) + self.idx = 0 + + def __next__(self): + while True: + yield (self.args_list[self.idx], self.kwargs_list[self.idx]) + self.idx += 1 + self.idx = self.idx % self.n + + def reset(self): + self.idx = 0 + + @property + def n_args(self): + return self.n + + def __init__( + self, + cuda_graph_params: CudaGraphBenchParams | None, + label: str, + sub_label: str, + description: str, + fn: Callable, + *args, + **kwargs, + ): + self.cuda_graph_params = cuda_graph_params + self.use_cuda_graph = self.cuda_graph_params is not None + self.label = label + self.sub_label = sub_label + self.description = description + self.fn = fn + + # Process args + self._args = args + self._kwargs = kwargs + self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list) + + # Cudagraph runner + self.g = None + if self.use_cuda_graph: + self.g = self.get_cuda_graph_runner() + + # benchmark run params + self.min_run_time = 1 + + def collapse_argpool(self, *args, **kwargs): + argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [ + arg for arg in kwargs.values() if isinstance(arg, ArgPool) + ] + if len(argpool_args) == 0: + return [args], [kwargs] + + # Make sure all argpools are of the same size + argpool_size = len(argpool_args[0].values) + assert all([argpool_size == len(arg.values) for arg in argpool_args]) + + # create copies of the args + args_list = [] + kwargs_list = [] + for _ in range(argpool_size): + args_list.append(args) + kwargs_list.append(kwargs.copy()) + + for i in range(argpool_size): + # collapse args; Just pick the ith value + args_list[i] = tuple( + [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]] + ) + + # collapse kwargs + kwargs_i = kwargs_list[i] + arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)] + for k in arg_pool_keys: + # again just pick the ith value + kwargs_i[k] = kwargs_i[k][i] + kwargs_list[i] = kwargs_i + + return args_list, kwargs_list + + def get_cuda_graph_runner(self): + assert self.use_cuda_graph + assert self.args_iterator is not None + + num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph + + # warmup + args_it = self.args_iterator.__next__() + for _ in range(2): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + return g + + def run_cudagrah(self) -> TMeasurement: + assert self.use_cuda_graph + globals = {"g": self.g} + + return TBenchmark.Timer( + stmt="g.replay()", + globals=globals, + label=( + f"{self.label}" + f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops" + ), + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run_eager(self) -> TMeasurement: + setup = None + stmt = None + globals = None + + has_arg_pool = self.args_iterator.n_args > 1 + if has_arg_pool: + setup = """ + args_iterator.reset() + args_it = args_iterator.__next__() + """ + 
stmt = """ + args, kwargs = next(args_it) + fn(*args, **kwargs) + """ + globals = {"fn": self.fn, "args_iterator": self.args_iterator} + else: + # no arg pool. Just use the args and kwargs directly + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + args, kwargs = next(args_it) + + setup = "" + stmt = """ + fn(*args, **kwargs) + """ + globals = {"fn": self.fn, "args": args, "kwargs": kwargs} + + return TBenchmark.Timer( + stmt=stmt, + setup=setup, + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run(self) -> TMeasurement: + timer = None + if self.use_cuda_graph: # noqa SIM108 + timer = self.run_cudagrah() + else: + timer = self.run_eager() + if not timer.meets_confidence() or timer.has_warnings: + print("Doesn't meet confidence - re-running bench ...") + return self.run() + return timer + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + print(f"exc type {exc_type}") + print(f"exc value {exc_value}") + print(f"exc traceback {traceback}") diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..9a057990bda5f64deada11b0beb56c0207570de5 --- /dev/null +++ b/benchmarks/kernels/weight_shapes.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "meta-llama/Llama-3.1-405b-hf": [ + ([16384, 18432], 1), + ([16384, 16384], 0), + ([16384, 106496], 1), + ([53248, 16384], 0), + ], + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 
0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], +} diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa3fa0513e8f2221378ecf09531aed4f5b99b3a4 --- /dev/null +++ b/benchmarks/multi_turn/README.md @@ -0,0 +1,178 @@ +# Benchmark KV Cache Offloading with Multi-Turn Conversations + +The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt` + +First start serving your model + +```bash +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ + +vllm serve $MODEL_PATH --served-model-name Llama +``` + +The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface). + +## Synthetic Multi-Turn Conversations + +Download the following text file (used for generation of synthetic conversations) + +```bash +wget https://www.gutenberg.org/ebooks/1184.txt.utf-8 +mv 1184.txt.utf-8 pg1184.txt +``` + +The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`). + +But you may use other text files if you prefer (using this specific file is not required). + +Then run the benchmarking script + +```bash +export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ + +python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \ +--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6 +``` + +You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.). + +If successful, you will see the following output + +```bash +---------------------------------------------------------------------------------------------------- +Statistics summary: +runtime_sec = 215.810 +requests_per_sec = 0.769 +---------------------------------------------------------------------------------------------------- + count mean std min 25% 50% 75% 90% 99% max +ttft_ms 166.0 78.22 67.63 45.91 59.94 62.26 64.43 69.66 353.18 567.54 +tpot_ms 166.0 25.37 0.57 24.40 25.07 25.31 25.50 25.84 27.50 28.05 +latency_ms 166.0 2591.07 326.90 1998.53 2341.62 2573.01 2860.10 3003.50 3268.46 3862.94 +input_num_turns 166.0 7.43 4.57 1.00 3.00 7.00 11.00 13.00 17.00 17.00 +input_num_tokens 166.0 2006.20 893.56 522.00 1247.75 2019.00 2718.00 3233.00 3736.45 3899.00 +output_num_tokens 166.0 100.01 11.80 80.00 91.00 99.00 109.75 116.00 120.00 120.00 +output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 115.00 119.00 119.00 +---------------------------------------------------------------------------------------------------- +``` + +If you run with `--warmup-step`, the summary will also include `warmup_runtime_sec` +and `total_runtime_incl_warmup_sec` (while `runtime_sec` continues to reflect the +benchmark-only runtime so the reported throughput stays comparable). + +### JSON configuration file for synthetic conversations generation + +The input flag `--input-file` is used to determine the input conversations for the benchmark.
+When the input is a JSON file with the field `"filetype": "generate_conversations"`, the tool will generate synthetic multi-turn conversations (user questions and assistant answers).
+
+The file `generate_multi_turn.json` is an example of such a configuration file.
+
+The file must contain the sections `prompt_input` and `prompt_output`.
+
+The `prompt_input` section must contain `num_turns`, `prefix_num_tokens`, and `num_tokens`:
+
+* `num_turns` - Number of total turns in the conversation (both user & assistant).
+The final value will always be rounded to an even number so each user turn has a reply. +* `prefix_num_tokens` - Tokens added at the start of only the **first user turn** in a conversation (unique per conversation). +* `num_tokens` - Total token length of each **user** message (one turn). + +The `prompt_output` section must contain `num_tokens`: + +* `num_tokens` - Total token length of each **assistant** message (one turn). + +### Random distributions for synthetic conversations generation + +When creating an input JSON file (such as `generate_multi_turn.json`),
+every numeric field (such as `num_turns` or `num_tokens`) requires a distribution.
+The distribution determines how to randomly sample values for the field. + +The available distributions are listed below. + +**Note:** The optional `max` field (for lognormal, zipf, and poisson) can be used to cap sampled values at an upper bound.
+Can be used to make sure that the total number of tokens in every request does not exceed `--max-model-len`. + +#### constant + +```json +{ + "distribution": "constant", + "value": 500 +} +``` + +* `value` - the fixed integer value (always returns the same number). + +#### uniform + +```json +{ + "distribution": "uniform", + "min": 12, + "max": 18 +} +``` + +* `min` - minimum value (inclusive). +* `max` - maximum value (inclusive), should be equal or larger than min. + +#### lognormal + +```json +{ + "distribution": "lognormal", + "average": 1000, + "max": 5000 +} +``` + +You can parameterize the lognormal distribution in one of two ways: + +Using the average and optional median ratio: + +* `average` - target average value of the distribution. +* `median_ratio` - the ratio of the median to the average; controls the skewness. Must be in the range (0, 1). + +Using the parameters of the underlying normal distribution: + +* `mean` - mean of the underlying normal distribution. +* `sigma` - standard deviation of the underlying normal distribution. + +#### zipf + +```json +{ + "distribution": "zipf", + "alpha": 1.2, + "max": 100 +} +``` + +* `alpha` - skew parameter (> 1). Larger values produce stronger skew toward smaller integers. + +#### poisson + +```json +{ + "distribution": "poisson", + "alpha": 10, + "max": 50 +} +``` + +* `alpha` - expected value (λ). Also the variance of the distribution. + +## ShareGPT Conversations + +To run with the ShareGPT data, download the following ShareGPT dataset: +`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json` + +Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py` + +```bash +python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128 +``` + +The script will convert the ShareGPT dataset to a dataset with the standard user/assistant roles. + +The flag `--max-items=128` is used to sample 128 conversations from the original dataset (change as needed). + +Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`. 
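+
+For example, using the same server and `MODEL_PATH` as in the synthetic benchmark above:
+
+```bash
+python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
+--input-file sharegpt_conv_128.json --num-clients 2
+```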
diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb8a2f386a9715c06617ec6afafc79dca3cec2f --- /dev/null +++ b/benchmarks/multi_turn/bench_dataset.py @@ -0,0 +1,600 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from statistics import mean +from typing import Any, NamedTuple + +import numpy as np # type: ignore +import pandas as pd # type: ignore +from bench_utils import ( + TEXT_SEPARATOR, + Color, + logger, +) +from tqdm import tqdm +from transformers import AutoTokenizer # type: ignore + +# Conversation ID is a string (e.g: "UzTK34D") +ConvId = str + +# A list of dicts (dicts with keys "id" and "messages") +ShareGptConversations = list[dict[str, Any]] + +# A list of dicts (dicts with keys "role" and "content") +MessagesList = list[dict[str, str]] + +# Map conversation ID to conversation messages +ConversationsMap = list[ConvId, MessagesList] + + +class Distribution(ABC): + @abstractmethod + def sample(self, size: int = 1) -> np.ndarray: + pass + + +class UniformDistribution(Distribution): + def __init__( + self, + min_val: int | float, + max_val: int | float, + is_integer: bool = True, + ) -> None: + self.min_val = min_val + self.max_val = max_val + self.is_integer = is_integer + + def sample(self, size: int = 1) -> np.ndarray: + if self.is_integer: + return np.random.randint( + int(self.min_val), int(self.max_val + 1), size=size + ) + else: + return np.random.uniform(self.min_val, self.max_val, size=size) + + def __repr__(self) -> str: + return f"UniformDistribution[{self.min_val}, {self.max_val}]" + + +class ConstantDistribution(Distribution): + def __init__(self, value: int | float) -> None: + self.value = value + self.max_val = value + + def sample(self, size: int = 1) -> np.ndarray: + return np.full(shape=size, fill_value=self.value) + + def __repr__(self) -> str: + return f"Constant[{self.value}]" + + +class ZipfDistribution(Distribution): + def __init__(self, alpha: float, max_val: int | None = None) -> None: + self.alpha = alpha + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.zipf(self.alpha, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + return samples + + def __repr__(self) -> str: + return f"ZipfDistribution[{self.alpha}]" + + +class PoissonDistribution(Distribution): + def __init__(self, alpha: float, max_val: int | None = None) -> None: + self.alpha = alpha + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.poisson(self.alpha, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + return samples + + def __repr__(self) -> str: + return f"PoissonDistribution[{self.alpha}]" + + +class LognormalDistribution(Distribution): + def __init__( + self, + mean: float | None = None, + sigma: float | None = None, + average: int | None = None, + median_ratio: float | None = None, + max_val: int | None = None, + ) -> None: + self.average = average + self.median_ratio = median_ratio + self.max_val = max_val + + if average is not None: + if average < 1: + raise ValueError("Lognormal average must be positive") + + if mean or sigma: + raise ValueError( + "When using lognormal average, you can't provide mean/sigma" + ) + + if self.median_ratio is None: + # Default value that provides relatively wide range of values + self.median_ratio 
= 0.85 + + # Calculate mean/sigma of np.random.lognormal based on the average + mean, sigma = self._generate_lognormal_by_median( + target_average=self.average, median_ratio=self.median_ratio + ) + else: + if mean is None or sigma is None: + raise ValueError( + "Must provide both mean and sigma if average is not used" + ) + + if mean <= 0 or sigma < 0: + raise ValueError( + "Lognormal mean must be positive and sigma must be non-negative" + ) + + # Mean and standard deviation of the underlying normal distribution + # Based on numpy.random.lognormal + self.mean = mean + self.sigma = sigma + + @staticmethod + def _generate_lognormal_by_median( + target_average: int, median_ratio: float + ) -> tuple[float, float]: + """ + Compute (mu, sigma) for a lognormal distribution given: + - a target average (mean of the distribution) + - a ratio of median / mean (controls skewness), assume mean > median + + Background: + If Z ~ Normal(mu, sigma^2), then X = exp(Z) ~ LogNormal(mu, sigma). + * mean(X) = exp(mu + sigma^2 / 2) + * median(X) = exp(mu) + + So: + median / mean = exp(mu) / exp(mu + sigma^2 / 2) + = exp(-sigma^2 / 2) + + Rearranging: + sigma^2 = 2 * ln(mean / median) + mu = ln(median) + + This gives a unique (mu, sigma) for any valid mean and median. + """ + # Check input validity: median must be smaller than mean + if median_ratio <= 0 or median_ratio >= 1: + raise ValueError("median_ratio must be in range (0, 1)") + + target_median = target_average * median_ratio + + # Solve sigma^2 = 2 * ln(mean / median) + sigma = np.sqrt(2 * np.log(target_average / target_median)) + mu = np.log(target_median) + + return mu, sigma + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size) + + if self.average is not None: + # Scale to average + samples *= self.average / samples.mean() + + if self.max_val: + samples = np.minimum(samples, self.max_val) + + return np.round(samples).astype(int) + + def __repr__(self) -> str: + if self.average: + return ( + f"LognormalDistribution[{self.average}, " + f"{self.median_ratio}, {self.max_val}]" + ) + return f"LognormalDistribution[{self.mean}, {self.sigma}, {self.max_val}]" + + +class GenConvArgs(NamedTuple): + num_conversations: int + text_files: list[str] + input_num_turns: Distribution + input_common_prefix_num_tokens: Distribution + input_prefix_num_tokens: Distribution + input_num_tokens: Distribution + output_num_tokens: Distribution + print_stats: bool + + +def verify_field_exists( + conf: dict, field_name: str, section: str, subsection: str +) -> None: + if field_name not in conf: + raise ValueError( + f"Missing field '{field_name}' in {section=} and {subsection=}" + ) + + +def get_random_distribution( + conf: dict, section: str, subsection: str, optional: bool = False +) -> Distribution: + # section can be "prompt_input" or "prompt_output" (both required) + conf = conf[section] + + if optional and subsection not in conf: + # Optional subsection, if not found assume the value is always 0 + return ConstantDistribution(0) + + # subsection can be "num_turns", "num_tokens" or "prefix_num_tokens" + if subsection not in conf: + raise ValueError(f"Missing subsection {subsection} in section {section}") + + conf = conf[subsection] + + distribution = conf.get("distribution") + if distribution is None: + raise ValueError( + f"Missing field 'distribution' in {section=} and {subsection=}" + ) + + if distribution == "constant": + verify_field_exists(conf, "value", section, subsection) + return 
ConstantDistribution(conf["value"]) + + elif distribution == "zipf": + verify_field_exists(conf, "alpha", section, subsection) + max_val = conf.get("max", None) + return ZipfDistribution(conf["alpha"], max_val=max_val) + + elif distribution == "poisson": + verify_field_exists(conf, "alpha", section, subsection) + max_val = conf.get("max", None) + return PoissonDistribution(conf["alpha"], max_val=max_val) + + elif distribution == "lognormal": + max_val = conf.get("max", None) + + if "average" in conf: + # Infer lognormal mean/sigma (numpy) from input average + median_ratio = conf.get("median_ratio", None) + return LognormalDistribution( + average=conf["average"], median_ratio=median_ratio, max_val=max_val + ) + + # Use mean/sigma directly (for full control over the distribution) + verify_field_exists(conf, "mean", section, subsection) + verify_field_exists(conf, "sigma", section, subsection) + return LognormalDistribution( + mean=conf["mean"], sigma=conf["sigma"], max_val=max_val + ) + + elif distribution == "uniform": + verify_field_exists(conf, "min", section, subsection) + verify_field_exists(conf, "max", section, subsection) + + min_value = conf["min"] + max_value = conf["max"] + + assert min_value > 0 + assert min_value <= max_value + + is_integer = isinstance(min_value, int) and isinstance(max_value, int) + return UniformDistribution(min_value, max_value, is_integer) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + +def parse_input_json_file(conf: dict) -> GenConvArgs: + # Validate the input file + assert isinstance(conf, dict) + required_fields = [ + "filetype", + "num_conversations", + "text_files", + "prompt_input", + "prompt_output", + ] + for field in required_fields: + assert field in conf, f"Missing field {field} in input {conf}" + + assert conf["filetype"] == "generate_conversations" + + assert conf["num_conversations"] > 0, "num_conversations should be larger than zero" + + text_files = conf["text_files"] + + assert isinstance(text_files, list), "Field 'text_files' should be a list" + assert len(text_files) > 0, ( + "Field 'text_files' should be a list with at least one file" + ) + + # Parse the parameters for the prompt input/output workload + input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns") + input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens") + input_common_prefix_num_tokens = get_random_distribution( + conf, "prompt_input", "common_prefix_num_tokens", optional=True + ) + input_prefix_num_tokens = get_random_distribution( + conf, "prompt_input", "prefix_num_tokens" + ) + output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens") + + print_stats: bool = conf.get("print_stats", False) + assert isinstance(print_stats, bool), ( + "Field 'print_stats' should be either 'true' or 'false'" + ) + + args = GenConvArgs( + num_conversations=conf["num_conversations"], + text_files=text_files, + input_num_turns=input_num_turns, + input_common_prefix_num_tokens=input_common_prefix_num_tokens, + input_prefix_num_tokens=input_prefix_num_tokens, + input_num_tokens=input_num_tokens, + output_num_tokens=output_num_tokens, + print_stats=print_stats, + ) + return args + + +def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None: + # Collect statistics + conv_stats: list[dict[Any, Any]] = [] + req_stats: list[int] = [] + + print("\nCollecting statistics...") + for messages in conversations.values(): + # messages is a list of dicts + user_tokens: list[int] = [] + 
assistant_tokens: list[int] = [] + request_tokens: list[int] = [] + + req_tokens = 0 + for m in messages: + content = m["content"] + num_tokens = len(tokenizer(content).input_ids) + + if m["role"] == "user": + user_tokens.append(num_tokens) + # New user prompt including all chat history + req_tokens += num_tokens + request_tokens.append(req_tokens) + + elif m["role"] == "assistant": + assistant_tokens.append(num_tokens) + # Update assistant answer + # (will be part of chat history for the next user prompt) + req_tokens += num_tokens + + item_stats = { + "conversation_turns": len(messages), + "user_tokens": mean(user_tokens), + "assistant_tokens": mean(assistant_tokens), + } + + conv_stats.append(item_stats) + req_stats.extend(request_tokens) + + # Print statistics + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99] + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + df = pd.DataFrame(conv_stats) + print(df.describe(percentiles=percentiles).transpose()) + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Request statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + df = pd.DataFrame(req_stats, columns=["request_tokens"]) + print(df.describe(percentiles=percentiles).transpose()) + print(TEXT_SEPARATOR) + + +def generate_conversations( + args: GenConvArgs, tokenizer: AutoTokenizer +) -> ConversationsMap: + # Text for all user prompts + # (text from the input text files will be appended to this line) + base_prompt_text = "Please rewrite the following text and add more content: " + base_prompt_token_count = len( + tokenizer.encode(base_prompt_text, add_special_tokens=False) + ) + + logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}") + logger.info(args) + + list_of_tokens = [] + + for filename in args.text_files: + # Load text file that will be used to generate prompts + with open(filename) as file: + data = file.read() + tokens_in_file = tokenizer.encode(data, add_special_tokens=False) + list_of_tokens.extend(tokens_in_file) + logger.info( + f"Loaded {len(tokens_in_file)} tokens from file {filename}, " + f"total tokens so far: {len(list_of_tokens)}" + ) + + conversations: ConversationsMap = {} + conv_id = 0 + + # Generate number of turns for every conversation + turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations) + + # Turn count should be at least 2 (one user prompt and one assistant answer) + turn_count = np.maximum(turn_count, 2) + + # Round up to an even number (every user prompt should have an answer) + turn_count = turn_count + (turn_count % 2) + + # Generate number of prefix tokens for every conversation + conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample( + args.num_conversations + ) + + # Used to reduce shared text between conversations + # (jump/skip over text sections between conversations) + base_offset = 0 + + # Common prefix size for all conversations (only 1 sample required) + common_prefix_text = "" + common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0] + if common_prefix_tokens > 0: + # Using "." at the end to separate sentences + common_prefix_text = ( + tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "." 
+ ) + base_offset += common_prefix_tokens + + for conv_id in tqdm( + range(args.num_conversations), + total=args.num_conversations, + desc="Generating conversations", + unit="conv", + ): + # Generate a single conversation + messages: MessagesList = [] + + nturns = turn_count[conv_id] + + # User prompt token count per turn (with lower limit) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int) + input_token_count = np.maximum(input_token_count, base_prompt_token_count) + + # Assistant answer token count per turn (with lower limit) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype( + int + ) + output_token_count = np.maximum(output_token_count, 1) + + user_turn = True + for turn_id in range(nturns): + if user_turn: + role = "user" + num_tokens = input_token_count[turn_id] + + # Generate the user prompt, + # use a unique prefix (the conv_id) for each conversation + # (to avoid shared prefix between conversations) + content = f"{conv_id} is a nice number... " + + if len(common_prefix_text) > 0 and turn_id == 0: + content = common_prefix_text + content + + # Update the number of tokens left for the content + num_tokens -= len(tokenizer.encode(content, add_special_tokens=False)) + + if turn_id == 0: + prefix_num_tokens = conv_prefix_tokens[conv_id] + if prefix_num_tokens > 0: + # Add prefix text (context) to the first turn + start_offset = base_offset + end_offset = start_offset + prefix_num_tokens + assert len(list_of_tokens) > end_offset, ( + "Not enough input text to generate " + f"{prefix_num_tokens} tokens for the " + f"prefix text ({start_offset=}, {end_offset=})" + ) + + content += f"{conv_id}, " + tokenizer.decode( + list_of_tokens[start_offset:end_offset] + ) + base_offset += prefix_num_tokens + + # Add the actual user prompt/question after the prefix text + content += base_prompt_text + num_tokens -= base_prompt_token_count + + if num_tokens > 0: + # Add text from the input file (to reach the desired token count) + start_offset = base_offset + turn_id * input_token_count.max() + end_offset = start_offset + num_tokens + assert len(list_of_tokens) > end_offset, ( + f"Not enough input text to generate {num_tokens} tokens " + f"for the prompt ({start_offset=}, {end_offset=})" + ) + + # Convert tokens back to text + content += tokenizer.decode(list_of_tokens[start_offset:end_offset]) + else: + role = "assistant" + # This content will not be used as input to the LLM server + # (actual answers will be used instead). + # Content is only required to determine the min_tokens/max_tokens + # (inputs to the LLM server). 
+ num_tokens = output_token_count[turn_id] + assert len(list_of_tokens) > num_tokens, ( + f"Not enough input text to generate {num_tokens} " + "tokens for assistant content" + ) + content = tokenizer.decode(list_of_tokens[:num_tokens]) + + # Append the user/assistant message to the list of messages + messages.append({"role": role, "content": content}) + user_turn = not user_turn + + # Add the new conversation + conversations[f"CONV_ID_{conv_id}"] = messages + + # Increase base offset for the next conversation + base_offset += nturns + + if args.print_stats: + print_conv_stats(conversations, tokenizer) + + return conversations + + +def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap: + conversations: ConversationsMap = {} + + for item in input_list: + conv_id: str = item["id"] + assert isinstance(conv_id, str) + + assert conv_id not in conversations, ( + f"Conversation ID {conv_id} found more than once in the input" + ) + + messages: MessagesList = item["messages"] + assert isinstance(messages, list), ( + f"Conversation messages should be a list (ID: {conv_id})" + ) + assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})" + + conversations[conv_id] = messages + + logger.info(f"Using {len(conversations)} unique conversations (IDs)") + assert len(conversations) == len(input_list) + + # Print statistics about the selected conversations + stats: list[dict[str, Any]] = [] + for conv_data in conversations.values(): + stats.append({"num_turns": len(conv_data)}) + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999] + conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles) + print(conv_stats.transpose()) + print(TEXT_SEPARATOR) + + return conversations + + +def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations: + output: ShareGptConversations = [] + for conv_id, conv_data in input_dict.items(): + new_item = {"id": conv_id, "messages": conv_data} + output.append(new_item) + + return output diff --git a/benchmarks/multi_turn/bench_utils.py b/benchmarks/multi_turn/bench_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e959a4be711c9cc0ed7f2981927d12799cbf9c7f --- /dev/null +++ b/benchmarks/multi_turn/bench_utils.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from enum import Enum + + +class Color(Enum): + RED = "\033[91m" + GREEN = "\033[92m" + BLUE = "\033[94m" + PURPLE = "\033[95m" + CYAN = "\033[96m" + YELLOW = "\033[93m" + RESET = "\033[0m" + + def __str__(self): + return self.value + + +TEXT_SEPARATOR = "-" * 100 + +# Configure the logger +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] - %(message)s", + datefmt="%d-%m-%Y %H:%M:%S", +) +logger = logging.getLogger(__name__) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py new file mode 100644 index 0000000000000000000000000000000000000000..e23f6b923f1b9fa4835c7274d6fa825c90aad225 --- /dev/null +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -0,0 +1,1666 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import asyncio +import json +import logging +import multiprocessing as mp +import os +import random +import 
time +from collections import Counter, deque +from datetime import datetime +from enum import Enum +from http import HTTPStatus +from statistics import mean +from typing import NamedTuple + +import aiohttp # type: ignore +import numpy as np # type: ignore +import pandas as pd # type: ignore +from bench_dataset import ( + ConversationsMap, + ConvId, + GenConvArgs, + MessagesList, + ShareGptConversations, + conversations_dict_to_list, + conversations_list_to_dict, + generate_conversations, + parse_input_json_file, +) +from bench_utils import TEXT_SEPARATOR, Color, logger +from transformers import AutoTokenizer # type: ignore + +NUM_TOKENS_FROM_DATASET = 0 +TERM_SIGNAL = None + + +class ConversationSampling(str, Enum): + ROUND_ROBIN = "round_robin" + RANDOM = "random" + + def __str__(self): + return self.value + + +class ClientArgs(NamedTuple): + seed: int + max_num_requests: int | None + skip_first_turn: bool + max_turns: int | None + max_active_conversations: int + verbose: bool + print_content: bool + verify_output: bool + conversation_sampling: ConversationSampling + request_rate: float + max_retries: int + + +class RequestArgs(NamedTuple): + chat_url: str + model: str + stream: bool + limit_min_tokens: int # Use negative value for no limit + limit_max_tokens: int # Use negative value for no limit + timeout_sec: int + + +class BenchmarkArgs(NamedTuple): + url: str + num_clients: int + early_stop: bool + + +class ServerResponse(NamedTuple): + valid: bool + ttft_ms: float # time to first chunk + tpot_ms: float # time per output chunk (one or more tokens) + latency_ms: float + start_time_ms: float + first_chunk: str # first chunk of the content + content: str # includes the first_chunk + num_chunks: int + + def __str__(self) -> str: + return f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}" # noqa: E501 + + +class RequestStats(NamedTuple): + ttft_ms: float + tpot_ms: float + latency_ms: float + start_time_ms: float + input_num_turns: int + input_num_tokens: int + output_num_tokens: int + output_num_chunks: int + output_num_first_chunk_tokens: int + approx_cached_percent: float + conversation_id: str + client_id: int + + def __str__(self) -> str: + return ( + f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}, input_num_tokens {self.input_num_tokens}, " # noqa: E501 + f"output_num_tokens {self.output_num_tokens} ({self.output_num_chunks} chunks, {self.output_num_first_chunk_tokens} tokens in first chunk), " # noqa: E501 + f"approx_cached_percent {self.approx_cached_percent:.2f}%" + ) + + +class MetricStats: + def __init__(self) -> None: + self.min: float | None = None + self.max: float | None = None + self.avg: float | None = None + self.sum = 0.0 + self.count = 0 + + def update(self, value: float) -> None: + if self.min is None: + self.min = value + else: + self.min = min(self.min, value) + + if self.max is None: + self.max = value + else: + self.max = max(self.max, value) + + self.sum += value + self.count += 1 + self.avg = self.sum / self.count + + def __repr__(self) -> str: + if self.count == 0: + return "no data" + return f"avg: {self.avg:>10.3f}, min: {self.min:>10.3f}, max: {self.max:>10.3f}" + + +class MovingAverage: + def __init__(self, window_size: int) -> None: + self.window_size = window_size + self.window = np.zeros(window_size) + self.index = 0 + self.sum = 0.0 + self.count = 0 + self.avg: float | None = None + + def update(self, new_value: float) -> None: + if self.count < self.window_size: + # 
Filling up the window + self.sum += new_value + self.window[self.count] = new_value + self.count += 1 + else: + # Window is full, start replacing old values + old_value = self.window[self.index] + self.sum = self.sum - old_value + new_value + self.window[self.index] = new_value + self.index = (self.index + 1) % self.window_size + + self.avg = self.sum / self.count + + def __repr__(self) -> str: + if self.count == 0: + return "no data" + return f"avg: {self.avg:>10.3f} ({self.count} samples)" + + +class DebugStats: + def __init__(self, logger: logging.Logger, window_size: int) -> None: + self.logger = logger + self.metrics: dict[str, MovingAverage | MetricStats] = { + "moving_avg_ttft_ms": MovingAverage(window_size), + "moving_avg_tpot_ms": MovingAverage(window_size), + "ttft_ms": MetricStats(), + "tpot_ms": MetricStats(), + "latency_ms": MetricStats(), + "input_num_turns": MetricStats(), + "input_num_tokens": MetricStats(), + "output_num_tokens": MetricStats(), + } + + def update(self, data: RequestStats) -> None: + self.metrics["ttft_ms"].update(data.ttft_ms) + self.metrics["moving_avg_ttft_ms"].update(data.ttft_ms) + self.metrics["tpot_ms"].update(data.tpot_ms) + self.metrics["moving_avg_tpot_ms"].update(data.tpot_ms) + self.metrics["latency_ms"].update(data.latency_ms) + self.metrics["input_num_turns"].update(data.input_num_turns) + self.metrics["input_num_tokens"].update(data.input_num_tokens) + self.metrics["output_num_tokens"].update(data.output_num_tokens) + + def print(self) -> None: + self.logger.info("-" * 50) + for k, v in self.metrics.items(): + kv_info = f"[{k:25}] {v}" + self.logger.info(kv_info) + self.logger.info("-" * 50) + + +def nanosec_to_millisec(value: float) -> float: + return value / 1000000.0 + + +def nanosec_to_sec(value: float) -> float: + return value / 1000000000.0 + + +async def send_request( + session: aiohttp.ClientSession, + messages: list[dict[str, str]], + chat_url: str, + model: str, + stream: bool = True, + min_tokens: int | None = None, + max_tokens: int | None = None, + timeout_sec: int = 120, +) -> ServerResponse: + payload = { + "model": model, + "messages": messages, + "seed": 0, + "temperature": 0.0, + } + + if stream: + payload["stream"] = True + payload["stream_options"] = {"include_usage": False} + + if min_tokens is not None: + payload["min_tokens"] = min_tokens + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + + headers = {"Content-Type": "application/json"} + + # Calculate the timeout for the request + if max_tokens is not None: + # Assume TPOT of 200ms and use max_tokens to determine timeout + token_based_timeout = int(max_tokens * 0.2) + if token_based_timeout > timeout_sec: + timeout_sec = token_based_timeout + logger.info( + "Using timeout of %ds based on max_tokens %d", + timeout_sec, + max_tokens, + ) + timeout = aiohttp.ClientTimeout(total=timeout_sec) + + valid_response = True + ttft: float | None = None + chunk_delay: list[int] = [] + latency: float | None = None + first_chunk = "" + generated_text = "" + + start_time: int = time.perf_counter_ns() + most_recent_timestamp: int = start_time + + async with session.post( + url=chat_url, json=payload, headers=headers, timeout=timeout + ) as response: + http_status = HTTPStatus(response.status) + if http_status == HTTPStatus.OK: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk == "[DONE]": + # End of stream + latency = 
time.perf_counter_ns() - start_time + elif stream is False: + data = json.loads(chunk) + message = data["choices"][0]["message"] + assert message["role"] == "assistant" + generated_text += message["content"] + else: + timestamp: int = time.perf_counter_ns() + data = json.loads(chunk) + + # Delta is the new content/text/data + delta = data["choices"][0]["delta"] + if delta.get("content", None): + if ttft is None: + # First token + first_token_time = time.perf_counter_ns() + ttft = first_token_time - start_time + first_chunk = delta["content"] + else: + # Decoding phase + chunk_delay.append(timestamp - most_recent_timestamp) + + generated_text += delta["content"] + + most_recent_timestamp = timestamp + else: + valid_response = False + content = await response.text() + logger.warning( + f"{Color.YELLOW}Received HTTP status {http_status.value} " + f"({http_status.phrase}): {content}{Color.RESET}" + ) + + if latency is None: + latency = -1.0 + if valid_response: + # Streaming is disabled, latency was not set + latency = time.perf_counter_ns() - start_time + + if ttft is None: + # The response was a single chunk + ttft = latency + + # Each chunk may include more than one token + tpot: float = mean(chunk_delay) if len(chunk_delay) > 0 else 0.0 + num_chunks: int = len(chunk_delay) + + sr = ServerResponse( + valid=valid_response, + ttft_ms=nanosec_to_millisec(ttft) if ttft > 0.0 else -1.0, + tpot_ms=nanosec_to_millisec(tpot), + latency_ms=nanosec_to_millisec(latency), + start_time_ms=nanosec_to_millisec(start_time), + first_chunk=first_chunk, + content=generated_text, + num_chunks=num_chunks, + ) + return sr + + +def get_short_string(input: str) -> str: + n = 20 + if len(input) < 400: + return input + + return f"{input[:n]}...{input[-n:]}" + + +def get_token_count(tokenizer: AutoTokenizer, text: str) -> int: + return len(tokenizer(text, add_special_tokens=False).input_ids) + + +def get_messages_token_count( + tokenizer: AutoTokenizer, messages: list[dict[str, str]] +) -> int: + token_count = 0 + for m in messages: + token_count += get_token_count(tokenizer, m["content"]) + + return token_count + + +async def send_turn( + session: aiohttp.ClientSession, + client_id: int, + conv_id: str, + conversation_messages: MessagesList, + messages_to_use: int, + tokenizer: AutoTokenizer, + req_args: RequestArgs, + verbose: bool, + verify_output: bool, +) -> RequestStats | None: + assert messages_to_use > 0 + assert messages_to_use <= len(conversation_messages) + + messages = conversation_messages[:messages_to_use] + + # Index of the next message (the role should be "user") + index = messages_to_use - 1 + + # Verify that the message has only two keys, "role" and "content" + assert len(messages[index].keys()) == 2 + assert "role" in messages[index] and "content" in messages[index] + assert messages[index]["role"] == "user", ( + f"Failed on conversation ID {conv_id}, message role should be user" + ) + + if verbose: + print( + f"{Color.CYAN}Messages (conversation ID {conv_id}," + f" {len(messages)} turns):{Color.RESET}", + messages, + ) + + # None means that there is no upper/lower limit for the output token count + min_tokens = None if req_args.limit_min_tokens < 0 else req_args.limit_min_tokens + max_tokens = None if req_args.limit_max_tokens < 0 else req_args.limit_max_tokens + + if len(conversation_messages) > messages_to_use: + # The conversation contains an assistant answer for the next user prompt + if ( + min_tokens == NUM_TOKENS_FROM_DATASET + or max_tokens == NUM_TOKENS_FROM_DATASET + ): + # Compute 
number of tokens in the answer (from the input conversation) + assistant_answer = conversation_messages[messages_to_use] + answer_num_tokens = get_token_count(tokenizer, assistant_answer["content"]) + assert assistant_answer["role"] == "assistant" + + if min_tokens == NUM_TOKENS_FROM_DATASET: + min_tokens = max(1, answer_num_tokens) + + if max_tokens == NUM_TOKENS_FROM_DATASET: + max_tokens = max(1, answer_num_tokens) + + # Send the current conversation to LLM and get a response + response: ServerResponse = await send_request( + session, + messages, + req_args.chat_url, + req_args.model, + req_args.stream, + min_tokens, + max_tokens, + req_args.timeout_sec, + ) + + if response.valid is False: + # Request failed + return None + + # Compute number of tokens in input / output + input_num_tokens = get_messages_token_count(tokenizer, messages) + + # Num tokens in the user's last question + question_num_tokens = get_token_count(tokenizer, messages[index]["content"]) + + # Num tokens in the history/context of the question + assert input_num_tokens >= question_num_tokens + history_num_tokens = input_num_tokens - question_num_tokens + + # Num tokens in the LLM's answer (first chunk and full answer) + first_chunk_tokens = get_token_count(tokenizer, response.first_chunk) + + output_content = response.content + output_num_tokens = get_token_count(tokenizer, output_content) + + # Prefix caching approximated cached percent + approx_cached_percent = ( + 100.0 * (history_num_tokens / input_num_tokens) if input_num_tokens > 0 else 0.0 + ) + + # Compute the correct TTFT and TPOT (based on tokens and not chunks). + # Required because multiple output tokens may be bundled in a single chunk. + if output_num_tokens > 1 and output_num_tokens > first_chunk_tokens: + # More than one token and more than one chunk in the output + decode_ms = response.latency_ms - response.ttft_ms + decode_num_tokens = output_num_tokens - first_chunk_tokens + tpot_ms = decode_ms / decode_num_tokens + else: + # In this case: output_num_tokens == first_chunk_tokens + # Output was a single chunk (output_num_tokens > 1) + # or even a single token (output_num_tokens == 1) + tpot_ms = 0.0 + + if first_chunk_tokens > 1: + # First chunk had multiple tokens, adjust TTFT for a single token + delta_ms = (first_chunk_tokens - 1) * tpot_ms + ttft_ms = max(0.1, response.ttft_ms - delta_ms) + else: + # First chunk had only one token + ttft_ms = response.ttft_ms + + rs = RequestStats( + ttft_ms=ttft_ms, + tpot_ms=tpot_ms, + latency_ms=response.latency_ms, + start_time_ms=response.start_time_ms, + input_num_turns=len(messages), + input_num_tokens=input_num_tokens, + output_num_tokens=output_num_tokens, + output_num_chunks=response.num_chunks, + output_num_first_chunk_tokens=first_chunk_tokens, + approx_cached_percent=approx_cached_percent, + conversation_id=conv_id, + client_id=client_id, + ) + + if verbose: + print( + f"\n{Color.YELLOW}Response ({output_num_tokens} tokens):{Color.RESET}", + output_content, + ) + print(f"{Color.YELLOW}Response metrics: {rs}{Color.RESET}") + print("-" * 70) + + # Save the LLM's answer (will be used as part of the context for the next user turn) + answer_index = messages_to_use + if len(conversation_messages) > answer_index: + assert conversation_messages[answer_index]["role"] == "assistant", ( + f"Failed on conversation ID {conv_id}, message role should be assistant" + ) + + orig_content = conversation_messages[answer_index]["content"] + if verify_output: + # Compare the new answer to the answer from the input file + 
debug_info = ( + f"LLM/dataset answers do not match ({conv_id}):" + f"\n'{get_short_string(output_content)}' (len: {len(output_content)})," + f"\n'{get_short_string(orig_content)}' (len: {len(orig_content)})" + ) + if orig_content != output_content: + raise ValueError(debug_info) + + # Update the answer + conversation_messages[answer_index]["content"] = output_content + else: + # A user prompt that has no answer, add the answer as a new message + new_answer = {"role": "assistant", "content": output_content} + conversation_messages.append(new_answer) + + return rs + + +async def poisson_sleep(request_rate: float, verbose: bool = False) -> None: + # Generate a random time interval from the Poisson distribution + assert request_rate > 0 + + interval = np.random.exponential(1.0 / request_rate) + if verbose: + logger.info(f"Sleeping for {interval:.3f} seconds...") + await asyncio.sleep(interval) + + +async def exponential_backoff_sleep( + attempt_cnt: int, + base_rate: float = 1.0, + backoff_factor: float = 2.0, + jitter_fraction: float = 0.10, + verbose: bool = False, +) -> None: + # Sleep with exponential backoff and jitter after a failed request. + backoff_delay = base_rate * (backoff_factor**attempt_cnt) + jittered_delay = backoff_delay * ( + 1 + np.random.uniform(-jitter_fraction, jitter_fraction) + ) + + if verbose: + logger.info(f"Backoff for {jittered_delay:.3f} seconds...") + + await asyncio.sleep(jittered_delay) + + +async def client_main( + args: ClientArgs, + req_args: RequestArgs, + client_id: int, + tokenizer: AutoTokenizer, + stop_event: mp.Event, # type: ignore + task_queue: mp.Queue, + result_queue: mp.Queue, + conv_queue: mp.Queue, +) -> None: + logger.info( + f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501 + ) + + # Set unique seed per client (each client runs in its own process) + # Add 1 to ensure no client uses the same seed as the main process + client_seed = args.seed + client_id + 1 + random.seed(client_seed) + np.random.seed(client_seed) + + # Active conversations + active_convs: ConversationsMap = {} + conv_id_queue: deque = deque(maxlen=args.max_active_conversations) + + # Keep track of how many messages have been used for each conversation + turns_count: Counter = Counter() + num_successes = 0 + num_failures = 0 + + # Track the timestamp (time.perf_counter()) + # of the last turn per conversation (only for debug) + time_of_last_turn: dict[ConvId, float] = {} + + # Flag that indicates that there are no new tasks (conversations) for the client + task_queue_empty = False + + async with aiohttp.ClientSession() as session: + # Print progress + + while task_queue_empty is False: + result = None + + if ( + args.max_num_requests + and num_successes + num_failures == args.max_num_requests + ): + logger.info( + f"{Color.YELLOW}Client {client_id} reached " + f"request limit{Color.RESET}" + ) + break + + if stop_event.is_set(): # type: ignore + logger.info( + f"{Color.YELLOW}Client {client_id} received " + f"a termination signal{Color.RESET}" + ) + break + + while ( + len(active_convs) < args.max_active_conversations + and task_queue_empty is False + ): + # Get a new conversation from the task queue + conv_id, messages = task_queue.get() + + if conv_id is TERM_SIGNAL: + task_queue_empty = True + break + + if args.skip_first_turn: + # Skip the first turn (both user and assistant), + # relevant if warmup was enabled. 
+ # Default turns_count[conv_id] will be zero if conv_id + # was never inserted/updated in turns_count. + turns_count[conv_id] += 2 + + if turns_count[conv_id] < len(messages): + # Add new conversation + active_convs[conv_id] = messages + conv_id_queue.append(conv_id) + + if args.verbose: + logger.info( + f"{Color.GREEN}Client {client_id} will use conversation ID {conv_id} (active conversations {len(active_convs)}){Color.RESET}" # noqa: E501 + ) + + elif args.verbose: + # No more messages (conversation finished during the warmup) + logger.info( + f"{Color.YELLOW}Client {client_id} will not use conversation ID {conv_id} (all {len(messages)} messages already sent){Color.RESET}" # noqa: E501 + ) + + if len(active_convs) == 0 or task_queue_empty: + logger.info( + f"{Color.YELLOW}Client {client_id} has no more work{Color.RESET}" + ) + break + + # Pick an active conversation for the next request + if args.conversation_sampling == ConversationSampling.ROUND_ROBIN: + conv_id = conv_id_queue.pop() + else: + # ConversationSampling.RANDOM + active_ids = list(active_convs.keys()) + conv_id = random.choice(active_ids) + + messages = active_convs[conv_id] + assert isinstance(messages, list) and len(messages) > 0 + + # Update the amount of messages to use + turns_count[conv_id] += 1 + current_turn = turns_count[conv_id] + + assert current_turn < len(messages), ( + f"Turn number {current_turn} is invalid for conversation ID {conv_id}" + f" that has only {len(messages)} messages" + ) + + if args.verbose: + curr_time_sec: float = time.perf_counter() + time_since_last_turn: str | float = "N/A" + if conv_id in time_of_last_turn: + time_since_last_turn = round( + curr_time_sec - time_of_last_turn[conv_id], 3 + ) + logger.info( + f"Client {client_id} using conversation ID {conv_id} (turn: {current_turn}, time since last turn [sec]: {time_since_last_turn})" # noqa: E501 + ) + time_of_last_turn[conv_id] = curr_time_sec + + success = False + for attempt_cnt in range(args.max_retries + 1): + try: + exception = False + result = await send_turn( + session, + client_id, + conv_id, + messages, + current_turn, + tokenizer, + req_args, + args.print_content, + args.verify_output, + ) + if result is not None: + result_queue.put(result) + success = True + break + else: + logger.warning( + f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + except asyncio.exceptions.TimeoutError: + exception = True + logger.error( + "%sClient %d - Timeout during conversation ID %s (turn: %d). " + "Base timeout is %ss (set with --request-timeout-sec), but the " + "effective timeout may be longer based on max_tokens. 
If this " + "is unexpected, consider increasing the timeout or checking " + "model performance.%s", + Color.RED, + client_id, + conv_id, + current_turn, + req_args.timeout_sec, + Color.RESET, + ) + except Exception: + exception = True + logger.exception( + f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + + # Sleep before retry if not last attempt + if not success and attempt_cnt < args.max_retries: + await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose) + + if not success: + num_failures += 1 + # Remove the conversation (should not be used again) + active_convs.pop(conv_id) + if exception: + break # Exit gracefully instead of raising an error + + else: + num_successes += 1 + + # Update the turns counter to include the LLM response + # The LLM response will be used as context for the next user turn + turns_count[conv_id] += 1 + + max_turns = len(messages) + if args.max_turns is not None: + # Limit the number of turns in the conversation + max_turns = min(args.max_turns, max_turns) + + if turns_count[conv_id] >= max_turns: + # Conversation has no more turns (no longer active) + # save the updated conversation (with the LLM server's answer) + conv_queue.put((conv_id, active_convs.pop(conv_id))) + if args.verbose: + logger.info( + f"{Color.GREEN}Client {client_id} finished " + f"conversation ID {conv_id}{Color.RESET}" + ) + else: + # Conversation is not finished, insert it at the back of the queue + conv_id_queue.appendleft(conv_id) + + # Sleep between requests (if lambda is positive) + if args.request_rate > 0: + await poisson_sleep(args.request_rate, args.verbose) + + # Send indication that the client is done + conv_queue.put((TERM_SIGNAL, TERM_SIGNAL)) + + logger.info( + f"{Color.CYAN}Client {client_id} is done " + f"({num_successes=}, {num_failures=}){Color.RESET}" + ) + + +def worker_function( + client_id: int, + tokenizer: AutoTokenizer, + client_args: ClientArgs, + req_args: RequestArgs, + stop_event: mp.Event, # type: ignore + task_queue: mp.Queue, + result_queue: mp.Queue, + conv_queue: mp.Queue, +) -> None: + asyncio.run( + client_main( + client_args, + req_args, + client_id, + tokenizer, + stop_event, + task_queue, + result_queue, + conv_queue, + ) + ) + + +def get_client_config( + args: argparse.Namespace, input_conv: ConversationsMap +) -> tuple[ClientArgs, RequestArgs]: + if args.num_clients < 1: + raise ValueError("Number of clients must be a positive number") + + if len(input_conv) < args.num_clients: + raise ValueError( + "Number of conversations must be equal or larger than the number of clients" + ) + + max_req_per_client: int | None = None + if args.max_num_requests is not None: + # Max number of requests per client + req_per_client = args.max_num_requests // args.num_clients + if req_per_client < 1: + raise ValueError("Number of requests should be at least one per client") + max_req_per_client = req_per_client + + max_active_conversations = args.max_active_conversations + if max_active_conversations is None: + # Each client will have only one active conversation at a time + max_active_conversations = args.num_clients + + if max_active_conversations > len(input_conv): + raise ValueError( + f"Max active conversations {max_active_conversations} " + "must be equal or less than the total number of conversations" + ) + + # Max number of active conversations per client + max_active_conv_per_client = max_active_conversations // args.num_clients + if max_active_conv_per_client < 1: + 
raise ValueError( + f"Max active conversations {max_active_conversations} " + "must be equal or greater than the number of clients" + ) + + # Skip the first user turn (as part of the warmup) + skip_first_turn = args.warmup_step + + # Common arguments for all clients + client_args = ClientArgs( + seed=args.seed, + max_num_requests=max_req_per_client, + skip_first_turn=skip_first_turn, + max_turns=args.max_turns, + max_active_conversations=max_active_conv_per_client, + verbose=args.verbose, + print_content=args.print_content, + verify_output=args.verify_output, + conversation_sampling=args.conversation_sampling, + request_rate=args.request_rate, + max_retries=args.max_retries, + ) + + if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: + if args.limit_min_tokens < 1 or args.limit_max_tokens < 1: + raise ValueError( + "Invalid min/max tokens limits (both limits should be provided)" + ) + if args.limit_min_tokens > args.limit_max_tokens: + raise ValueError( + "Invalid min/max tokens limits (min should not be larger than max)" + ) + + if args.request_timeout_sec <= 0: + raise ValueError("Request timeout must be a positive number") + + # Arguments for API requests + chat_url = f"{args.url}/v1/chat/completions" + model_name = args.served_model_name if args.served_model_name else args.model + + req_args = RequestArgs( + chat_url=chat_url, + model=model_name, + stream=not args.no_stream, + limit_min_tokens=args.limit_min_tokens, + limit_max_tokens=args.limit_max_tokens, + timeout_sec=args.request_timeout_sec, + ) + + return client_args, req_args + + +async def main_mp( + client_args: ClientArgs, + req_args: RequestArgs, + bench_args: BenchmarkArgs, + tokenizer: AutoTokenizer, + input_conv: ConversationsMap, +) -> tuple[ConversationsMap, list[RequestStats]]: + # An event that will trigger graceful termination of all the clients + stop_event = mp.Event() + + # Queue for input conversations (from the input file/dataset) + task_queue: mp.Queue = mp.Queue() + + # Queue for client measurements (TTFT, TPOT, etc. 
for each request) + result_queue: mp.Queue = mp.Queue() + + # Queue for output conversations (with the LLM answers, sent by the server) + conv_queue: mp.Queue = mp.Queue() + output_conv: ConversationsMap = {} + client_metrics: list[RequestStats] = [] + + # Start all clients + start_time = time.perf_counter_ns() + logger.info(f"{Color.GREEN}Starting {bench_args.num_clients} clients{Color.RESET}") + + clients = [] + for client_id in range(bench_args.num_clients): + client = mp.Process( + name=f"client_{client_id}", + target=worker_function, + args=( + client_id, + tokenizer, + client_args, + req_args, + stop_event, + task_queue, + result_queue, + conv_queue, + ), + ) + clients.append(client) + client.start() + + # Submit all the input conversations as tasks for the clients + for conv_id, messages in input_conv.items(): + task_queue.put((conv_id, messages)) + + # Add termination signals for clients + for _ in range(bench_args.num_clients): + task_queue.put((TERM_SIGNAL, TERM_SIGNAL)) + + # Collect the updated conversations from all clients + num_clients_finished = 0 + total_convs = len(input_conv) + + debug_stats = DebugStats(logger, min(15 * bench_args.num_clients, 500)) + + while num_clients_finished < bench_args.num_clients: + # Collect updated conversation + conv_id, messages = conv_queue.get() + + # Collect results (measurements) + while not result_queue.empty(): + new_data = result_queue.get() + client_metrics.append(new_data) + debug_stats.update(new_data) + + if conv_id is TERM_SIGNAL: + num_clients_finished += 1 + logger.info( + f"{Color.CYAN}{num_clients_finished} out of " + f"{bench_args.num_clients} clients finished{Color.RESET}" + ) + + if bench_args.early_stop and not stop_event.is_set(): + # Once one client finished, stop all other clients. + # there is no reason to continue the benchmark with fewer clients. + logger.info( + f"{Color.YELLOW}Sending termination signal to clients{Color.RESET}" + ) + stop_event.set() + else: + output_conv[conv_id] = messages + + finished_convs = len(output_conv) + percent = finished_convs / total_convs + + # Tuned to control the print rate (can be changed if required) + print_cycle = max(3, int(bench_args.num_clients / 4)) + + if finished_convs % print_cycle == 0: + runtime_sec = nanosec_to_sec(time.perf_counter_ns() - start_time) + logger.info( + f"{Color.CYAN}Finished {finished_convs} out of {total_convs} conversations ({percent:.0%}), " # noqa: E501 + f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 + ) + + rps: str | float = round(len(client_metrics) / runtime_sec, 3) + if len(client_metrics) < (5 * bench_args.num_clients): + # Do not estimate the RPS if the number of samples is very low + # (threshold can be tuned if needed) + rps = "N/A" + + runtime_left_sec: str | float = round( + (runtime_sec / finished_convs) * (total_convs - finished_convs), 3 + ) + if percent < 0.05: + # If less than 5% of the conversations were not finished, + # the estimation will probably be very inaccurate + # (threshold can be tuned if needed). + runtime_left_sec = "N/A" + + logger.info( + f"{Color.CYAN}Estimated req/sec {rps}, estimated runtime left {runtime_left_sec} sec{Color.RESET}" # noqa: E501 + ) + debug_stats.print() + + logger.info( + f"{Color.CYAN}All {bench_args.num_clients} clients finished{Color.RESET}" + ) + + # At this point all the clients finished, + # collect results (TTFT, TPOT, etc.) from all the clients. 
+ # This needs to happen before calling join on the clients + # (result_queue should be emptied). + while not result_queue.empty(): + client_metrics.append(result_queue.get()) + + logger.info(f"Collected {len(client_metrics)} samples from all the clients") + + # Wait for all clients to finish + for client in clients: + logger.info( + f"{Color.CYAN}Waiting for client {client.name} " + f"(is alive: {client.is_alive()}){Color.RESET}" + ) + + client.join(timeout=req_args.timeout_sec + 1) + + if client.is_alive(): + logger.warning( + f"{Color.YELLOW}Client {client.name} will be terminated{Color.RESET}" + ) + client.terminate() + + exitcode = client.exitcode + if exitcode != 0: + logger.error( + f"{Color.RED}Client {client.name} exited " + f"with exit code {exitcode}{Color.RESET}" + ) + + logger.info( + f"All {bench_args.num_clients} clients exited (successfully " + f"finished {len(output_conv)} out of {total_convs} conversations)" + ) + + # Queues should be closed, required to avoid hang at interpreter shutdown + unfinished_tasks = 0 + while not task_queue.empty(): + task_queue.get() + unfinished_tasks += 1 + + if unfinished_tasks > 0: + # Can happen if not all tasks (conversations) have finished. + # May happen if --max-num-requests was used, + # or if an error occurred in one of the clients. + logger.debug(f"Discarding {unfinished_tasks} unfinished tasks") + + task_queue.close() + task_queue.join_thread() + + result_queue.close() + result_queue.join_thread() + + conv_queue.close() + conv_queue.join_thread() + + return output_conv, client_metrics + + +def get_filename_with_timestamp(label: str, extension: str) -> str: + time_now = datetime.now() + timestamp = time_now.strftime("%d-%m-%Y_%H-%M-%S") + filename = f"{label}__{timestamp}.{extension}" + return filename + + +def process_statistics( + client_metrics: list[RequestStats], + warmup_percentages: list[float], + test_params: dict, + verbose: bool, + gen_conv_args: GenConvArgs | None = None, + excel_output: bool = False, + warmup_runtime_sec: float | None = None, +) -> None: + if len(client_metrics) == 0: + logger.info("No samples to process") + return + + logger.info(f"Processing {len(client_metrics)} samples...") + + raw_data = pd.DataFrame(client_metrics) + + if verbose: + # Calculate the time between user turns in each conversation (in a new column) + raw_data = raw_data.sort_values(by=["conversation_id", "start_time_ms"]) + raw_data["time_between_user_turns_sec"] = raw_data.groupby("conversation_id")[ + "start_time_ms" + ].diff() + + # Convert milliseconds to seconds + raw_data["time_between_user_turns_sec"] = ( + raw_data["time_between_user_turns_sec"] / 1000.0 + ) + + # Final raw data should be sorted by time + raw_data = raw_data.sort_values(by=["start_time_ms"]) + raw_data["end_time_ms"] = raw_data["start_time_ms"] + raw_data["latency_ms"] + + percentiles = [0.25, 0.5, 0.75, 0.9] + + # Add more percentiles if there are enough samples + if len(raw_data) >= 100: + percentiles.append(0.99) + + if len(raw_data) >= 1000: + percentiles.append(0.999) + + if len(raw_data) >= 10000: + percentiles.append(0.9999) + + # Set precision for numbers in the output text (the dataframes) + pd.set_option("display.precision", 2) + + # Exclude parameters from RequestStats + exclude = [ + "start_time_ms", + "end_time_ms", + "output_num_first_chunk_tokens", + "approx_cached_percent", + "conversation_id", + "client_id", + ] + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Parameters:{Color.RESET}") + for k, v in test_params.items(): + print(f"{k}={v}") + + 
# conversations generation parameters + if gen_conv_args is not None: + gen_params = { + "text_files": ", ".join(gen_conv_args.text_files), + "input_num_turns": str(gen_conv_args.input_num_turns), + "input_common_prefix_num_tokens": str( + gen_conv_args.input_common_prefix_num_tokens + ), + "input_prefix_num_tokens": str(gen_conv_args.input_prefix_num_tokens), + "input_num_tokens": str(gen_conv_args.input_num_tokens), + "output_num_tokens": str(gen_conv_args.output_num_tokens), + } + + print(f"{Color.YELLOW}Conversations Generation Parameters:{Color.RESET}") + for k, v in gen_params.items(): + print(f"{k}={v}") + + print(TEXT_SEPARATOR) + + params_list = [] + df_list = [] + for percent in warmup_percentages: + # Select samples from the end (tail) of the dataframe + warmup_count = int(percent * len(raw_data)) + tail_count = len(raw_data) - warmup_count + if tail_count == 0: + # No reason to process if the count of samples is zero + break + + df = raw_data.tail(tail_count) + + # Runtime is the diff between the end of the last request + # and the start of the first request + runtime_sec = df["end_time_ms"].iloc[-1] - df["start_time_ms"].iloc[0] + + # Convert milliseconds to seconds + runtime_sec = runtime_sec / 1000.0 + requests_per_sec = float(len(df)) / runtime_sec + params = { + "runtime_sec": runtime_sec, + "requests_per_sec": requests_per_sec, + } + if warmup_runtime_sec is not None: + params["warmup_runtime_sec"] = warmup_runtime_sec + params["total_runtime_incl_warmup_sec"] = runtime_sec + warmup_runtime_sec + + # Generate a summary of relevant metrics (and drop irrelevant data) + df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose() + + # List for Excel file + params_list.append(params) + df_list.append(df) + + # Print the statistics summary + if percent > 0 or len(warmup_percentages) > 1: + print( + f"{Color.YELLOW}Statistics summary " + f"(assuming {percent:.0%} warmup samples):{Color.RESET}" + ) + else: + print(f"{Color.YELLOW}Statistics summary:{Color.RESET}") + + for k, v in params.items(): + if isinstance(v, float): + print(f"{k} = {v:.3f}") + else: + print(f"{k} = {v}") + print(TEXT_SEPARATOR) + print(df) + print(TEXT_SEPARATOR) + + if excel_output: + prefix = f"statistics_{test_params['num_clients']}_clients" + filename = get_filename_with_timestamp(prefix, "xlsx") + + with pd.ExcelWriter(filename, engine="xlsxwriter") as writer: + startrow = 0 + test_params_df = pd.DataFrame([test_params]) + test_params_df.to_excel( + writer, sheet_name="Summary", index=False, startrow=startrow + ) + startrow += len(test_params_df) + 3 + + if gen_conv_args is not None: + gen_params_df = pd.DataFrame([gen_params]) + gen_params_df.to_excel( + writer, sheet_name="Summary", index=False, startrow=(startrow - 1) + ) + startrow += len(gen_params_df) + 3 + + for params, df_stats in zip(params_list, df_list): + df_params = pd.DataFrame([params]) + df_params.to_excel( + writer, sheet_name="Summary", index=False, startrow=startrow + ) + startrow += len(df_params) + 2 + df_stats.to_excel( + writer, sheet_name="Summary", index=True, startrow=startrow + ) + startrow += len(df_stats) + 3 + + raw_data.to_excel(writer, sheet_name="Raw data", index=False, startrow=0) + + logger.info( + f"{Color.GREEN}Client metrics exported to file: {filename}{Color.RESET}" + ) + + +async def get_server_info(url: str) -> None: + logger.info(f"{Color.BLUE}Collecting information from server: {url}{Color.RESET}") + async with aiohttp.ClientSession() as session: + # Get server version (not mandatory, 
"version" endpoint may not exist) + url_version = f"{url}/version" + async with session.get(url_version) as response: + if HTTPStatus(response.status) == HTTPStatus.OK: + text = await response.text() + logger.info(f"{Color.BLUE}Server version: {text}{Color.RESET}") + + # Get available models + url_models = f"{url}/v1/models" + async with session.get(url_models) as response: + if HTTPStatus(response.status) == HTTPStatus.OK: + text = await response.text() + logger.info(f"{Color.BLUE}Models:{Color.RESET}") + models_data = json.loads(text) + models_list = models_data["data"] + for model in models_list: + model_id = model["id"] + max_model_len = model.get("max_model_len", "N/A") + logger.info( + f"{Color.BLUE}\t{model_id=}, {max_model_len=}{Color.RESET}" + ) + else: + logger.info(f"{Color.RED}Failed to get models{Color.RESET}") + + +async def main() -> None: + parser = argparse.ArgumentParser( + prog="Benchmark serving with multi-turn conversations", + description="Benchmark online inference using REST API", + ) + parser.add_argument("--version", action="version", version="%(prog)s 1.0") + + parser.add_argument( + "-i", + "--input-file", + type=str, + required=True, + help="Input JSON file with ShareGPT conversations or " + "configuration file for generation of synthetic conversations", + ) + parser.add_argument( + "-o", + "--output-file", + type=str, + default=None, + help="Output JSON file containing conversations with updated assistant answers", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Seed for random number generators (default: 0)", + ) + + parser.add_argument( + "-m", "--model", type=str, required=True, help="Path of the LLM model" + ) + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the `--model` argument. ", + ) + + parser.add_argument( + "-u", + "--url", + type=str, + default="http://localhost:8000", + help="Base URL for the LLM API server", + ) + + parser.add_argument( + "-p", + "--num-clients", + type=int, + default=1, + help="Number of clients that will send requests in parallel", + ) + parser.add_argument( + "-k", + "--max-active-conversations", + type=int, + default=None, + help="Max number of active conversations at a time (for all clients)", + ) + parser.add_argument( + "-n", + "--max-num-requests", + type=int, + default=None, + help="Max number of requests to send (total for all clients)", + ) + + parser.add_argument( + "--warmup-step", + default=False, + action="store_true", + help="Run a warmup step (using only the first turn of every conversation), " + "measurements will not be included in the final benchmark results", + ) + + parser.add_argument( + "--max-turns", + type=int, + default=None, + help="Maximum number of turns/messages per conversation, " + "includes both user and assistant messages " + "(a positive number, e.g: 2, 4, 6, etc.), disabled by default", + ) + parser.add_argument( + "--no-early-stop", + default=False, + action="store_true", + help="By default, the benchmark will stop if at least one client exits." + " Use this flag to disable this behavior", + ) + + parser.add_argument( + "--limit-max-tokens", + type=int, + default=NUM_TOKENS_FROM_DATASET, + help="Set max_tokens for the output token count of each request " + "(must also set --limit-min-tokens). " + "Overrides output token count from the input dataset. 
" + "Use a negative value to disable this limit.", + ) + parser.add_argument( + "--limit-min-tokens", + type=int, + default=NUM_TOKENS_FROM_DATASET, + help="Set min_tokens for the output token count of each request " + "(must also set --limit-max-tokens). " + "Overrides output token count from the input dataset. " + "Use a negative value to disable this limit.", + ) + + parser.add_argument( + "--request-rate", + type=float, + default=0, + help="Expected request rate (Poisson process) per client in requests/sec." + "Set to 0 for no delay between requests.", + ) + parser.add_argument( + "--max-retries", + type=int, + default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")), + help="Maximum number of retry attempts for timed-out requests. " + "Default is 0 (no retries). " + "Set to higher values to retry failed requests and maintain " + "fair workload distribution. " + "Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.", + ) + parser.add_argument( + "--conversation-sampling", + type=ConversationSampling, + choices=list(ConversationSampling), + default=ConversationSampling.ROUND_ROBIN, + help=( + "Strategy for selecting which conversation to use for the next request. " + "Options: 'round_robin' (cycle through conversations), " + "'random' (pick randomly)." + ), + ) + parser.add_argument( + "--verify-output", + default=False, + action="store_true", + help="Verify the LLM output (compare to the answers in the input JSON file)", + ) + parser.add_argument( + "--request-timeout-sec", + type=int, + default=120, + help="Timeout in seconds for each API request (default: 120). " + "Automatically increased if max tokens imply longer decoding.", + ) + + parser.add_argument( + "--no-stream", + default=False, + action="store_true", + help="Disable stream/streaming mode (set 'stream' to False in the API request)", + ) + + parser.add_argument( + "-e", + "--excel-output", + default=False, + action="store_true", + help="Export summary to Excel file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + default=False, + action="store_true", + help="Enable verbose output", + ) + parser.add_argument( + "--print-content", + default=False, + action="store_true", + help="Print the user prompts and the server's answers", + ) + + parser.add_argument( + "--warmup-percentages", + type=str, + default="0%", + help="Ignore the first X samples as warmup (X is a percentage)." + " A comma separated list of percentages can be used " + "(for example: --warmup-percentages=0%%,50%%)", + ) + + args = parser.parse_args() + + logger.info(args) + + logger.info(f"{Color.GREEN}Input parameters:{Color.RESET}") + logger.info(f"url={args.url}") + logger.info(f"model={args.model}") + logger.info(f"num_clients={args.num_clients}") + + if args.verify_output: + logger.info(f"{Color.PURPLE}Verify is enabled{Color.RESET}") + + # Calculate the amount of samples to filter (as warmup samples/measurements). 
+
+    try:
+        warmup_percentages: list[float] = [0.0]
+        if not args.warmup_step:
+            # Warmup percentages only apply when the warmup step is not used
+            # (a separate warmup step already excludes its measurements)
+            warmup_strings: list[str] = args.warmup_percentages.split(",")
+            warmup_strings = [x.replace("%", "") for x in warmup_strings]
+            warmup_percentages = [float(x) / 100 for x in warmup_strings]
+
+            # Check for valid range (0 to 1)
+            for p in warmup_percentages:
+                assert p >= 0.0 and p < 1.0
+
+            # Sort from low to high warmup percentage
+            warmup_percentages.sort()
+
+            logger.info(
+                f"Warmup percentages (percentage of samples): {warmup_percentages}"
+            )
+
+    except Exception:
+        raise ValueError(
+            f"Invalid --warmup-percentages={args.warmup_percentages}"
+        ) from None
+
+    # Set global seeds for main process
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    logger.info("Loading tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    await get_server_info(args.url)
+
+    # Load the input file (either conversations or a configuration file)
+    logger.info(f"Reading input file: {args.input_file}")
+    with open(args.input_file) as f:
+        input_data = json.load(f)
+
+    gen_conv_args = None
+    if isinstance(input_data, list):
+        # The conversations are stored as a list of dicts
+        logger.info(f"Found {len(input_data)} items in the input file")
+
+        # Convert the list to a ConversationsMap
+        conversations = conversations_list_to_dict(input_data)
+
+    elif isinstance(input_data, dict):
+        # The input file is a configuration file
+        # (type is determined by the field 'filetype')
+        if "filetype" not in input_data:
+            raise Exception(
+                f"Input file {args.input_file} is invalid (missing 'filetype')"
+            )
+
+        logger.info(f"Using input file with filetype: {input_data['filetype']}")
+
+        gen_conv_args = parse_input_json_file(input_data)
+
+        # Disable warning from "huggingface/tokenizers"
+        # (when using python multiprocessing and tokenizers)
+        os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+        # Generate synthetic conversations
+        conversations = generate_conversations(gen_conv_args, tokenizer)
+
+    else:
+        raise Exception(f"Input file {args.input_file} is invalid")
+
+    if args.max_turns is not None:
+        if args.max_turns < 1:
+            raise ValueError("Max turns must be a positive number")
+        logger.info(
+            f"{Color.PURPLE}Max turns per conversation "
+            f"is limited to {args.max_turns}{Color.RESET}"
+        )
+
+    # Create benchmark configurations
+    client_args, req_args = get_client_config(args, conversations)
+
+    bench_args = BenchmarkArgs(
+        url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
+    )
+
+    warmup_runtime_sec: float | None = None
+
+    # Warm-up step
+    if args.warmup_step:
+        # Only send a single user prompt from every conversation.
+        # max_active_conversations must be 1,
+        # otherwise the clients may exit after sending a single request
+        # (because the task queue is empty).
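+        # skip_first_turn=False together with max_turns=1 limits the warmup
+        # to exactly one request (the first user turn) per conversation.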
+ warmup_client_args = client_args._replace( + skip_first_turn=False, max_turns=1, max_active_conversations=1 + ) + + # Early stop should be disabled, + # all clients should finish their work before exiting + warmup_bench_args = bench_args._replace(early_stop=False) + + logger.info("%sWarmup start%s", Color.PURPLE, Color.RESET) + warmup_start_ns = time.perf_counter_ns() + conversations, _ = await main_mp( + warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations + ) + warmup_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - warmup_start_ns) + logger.info( + "%sWarmup runtime: %.3f sec (%.3f ms)%s", + Color.PURPLE, + warmup_runtime_sec, + warmup_runtime_sec * 1000, + Color.RESET, + ) + logger.info("%sWarmup done%s", Color.PURPLE, Color.RESET) + + # Run the benchmark + benchmark_start_ns = time.perf_counter_ns() + client_convs, client_metrics = await main_mp( + client_args, req_args, bench_args, tokenizer, conversations + ) + benchmark_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - benchmark_start_ns) + + # Calculate requests per second + requests_per_sec = len(client_metrics) / benchmark_runtime_sec + benchmark_runtime_ms = benchmark_runtime_sec * 1000.0 + logger.info( + "%sAll clients finished, benchmark runtime: %.3f sec (%.3f ms), " + "requests per second: %.3f%s", + Color.GREEN, + benchmark_runtime_sec, + benchmark_runtime_ms, + requests_per_sec, + Color.RESET, + ) + if warmup_runtime_sec is not None: + total_runtime_sec = benchmark_runtime_sec + warmup_runtime_sec + logger.info( + "%sWarmup runtime: %.3f sec (%.3f ms)%s", + Color.GREEN, + warmup_runtime_sec, + warmup_runtime_sec * 1000, + Color.RESET, + ) + logger.info( + "%sTotal runtime (including warmup): %.3f sec (%.3f ms)%s", + Color.GREEN, + total_runtime_sec, + total_runtime_sec * 1000, + Color.RESET, + ) + + # Benchmark parameters + params = { + "model": args.model, + "num_clients": args.num_clients, + "num_conversations": len(conversations), + "active_conversations": args.max_active_conversations, + "seed": args.seed, + } + + if args.limit_min_tokens > 0: + params["min_tokens"] = args.limit_min_tokens + + if args.limit_max_tokens > 0: + params["max_tokens"] = args.limit_max_tokens + + # Process and print statistics (and save excel file with the statistics) + process_statistics( + client_metrics, + test_params=params, + warmup_percentages=warmup_percentages, + verbose=args.verbose, + gen_conv_args=gen_conv_args, + excel_output=args.excel_output, + warmup_runtime_sec=warmup_runtime_sec, + ) + + if args.output_file is not None: + # Write a JSON file with the updated conversations + # The "assistant" content will contain the answers from the tested LLM + output_data: ShareGptConversations = conversations_dict_to_list(client_convs) + logger.info( + f"{Color.GREEN}Writing conversations file: {args.output_file}{Color.RESET}" + ) + with open(args.output_file, "w") as f: + json.dump(output_data, f, indent=4) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/multi_turn/convert_sharegpt_to_openai.py b/benchmarks/multi_turn/convert_sharegpt_to_openai.py new file mode 100644 index 0000000000000000000000000000000000000000..fccab4d0ce21ad69980736710eee7e9814485974 --- /dev/null +++ b/benchmarks/multi_turn/convert_sharegpt_to_openai.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Download dataset from: 
+https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json + +Convert to OpenAI API: +export INPUT_FILE=sharegpt_20230401_clean_lang_split.json +python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128 +""" + +import argparse +import json +import random +from statistics import mean +from typing import Any + +import pandas as pd # type: ignore +import tqdm # type: ignore +from transformers import AutoTokenizer # type: ignore + + +def has_non_english_chars(text: str) -> bool: + return not text.isascii() + + +def content_is_valid( + content: str, min_content_len: int | None, max_content_len: int | None +) -> bool: + if min_content_len and len(content) < min_content_len: + return False + + if max_content_len and len(content) > max_content_len: + return False + + return has_non_english_chars(content) + + +def print_stats( + conversations: "list[dict[Any, Any]]", tokenizer: AutoTokenizer | None = None +) -> None: + # Collect statistics + stats = [] + + print("\nCollecting statistics...") + for item in tqdm.tqdm(conversations): + # item has "id" and "messages" + messages = item["messages"] + + user_turns = 0 + assistant_turns = 0 + user_words = 0 + assistant_words = 0 + conv_chars = 0 + + user_tokens: list[int] = [] + assistant_tokens: list[int] = [] + + for m in messages: + content = m["content"] + conv_chars += len(content) + content_num_words = content.count(" ") + 1 + + num_tokens = 0 + if tokenizer: + num_tokens = len(tokenizer(m["content"]).input_ids) + + if m["role"] == "user": + user_turns += 1 + user_words += content_num_words + if tokenizer: + user_tokens.append(num_tokens) + + elif m["role"] == "assistant": + assistant_turns += 1 + assistant_words += content_num_words + if tokenizer: + assistant_tokens.append(num_tokens) + + # assert user_turns == assistant_turns, \ + # f"Invalid conversation ID {item['id']}" + + conv_words = user_words + assistant_words + item_stats = { + "user_turns": user_turns, + "assistant_turns": assistant_turns, + "user_words": user_words, + "assistant_words": assistant_words, + "conv_turns": len(messages), + "conv_words": conv_words, + "conv_characters": conv_chars, + } + + if len(user_tokens) > 0: + item_stats["user_tokens"] = int(mean(user_tokens)) + + if len(assistant_tokens) > 0: + item_stats["assistant_tokens"] = int(mean(assistant_tokens)) + + stats.append(item_stats) + + print("\nStatistics:") + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999] + df = pd.DataFrame(stats) + print(df.describe(percentiles=percentiles).transpose()) + + +def convert_sharegpt_to_openai( + seed: int, + input_file: str, + output_file: str, + max_items: int | None, + min_content_len: int | None = None, + max_content_len: int | None = None, + min_turns: int | None = None, + max_turns: int | None = None, + model: str | None = None, +) -> None: + if min_turns and max_turns: + assert min_turns <= max_turns + + if min_content_len and max_content_len: + # Verify that min is not larger than max if both were given + assert min_content_len <= max_content_len + + print( + f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=}," + f" {max_content_len=}, {min_turns=}, {max_turns=}\n" + ) + + random.seed(seed) + + tokenizer = None + if model is not None: + print(f"Loading tokenizer from: {model}") + tokenizer = AutoTokenizer.from_pretrained(model) + + # Read the ShareGPT JSON file + print(f"Reading file: {input_file}") + with open(input_file, encoding="utf-8") as f: + # Should be a list of dicts + # 
Each dict should have "id" (string) and "conversations" (list of dicts) + sharegpt_data = json.load(f) + + assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts" + + print(f"Total items in input file: {len(sharegpt_data):,}") + + print(f"Shuffling dataset with seed {seed}") + random.shuffle(sharegpt_data) + + # Map conversation ID to the all the messages + conversation_parts: dict[str, list[Any]] = {} + + for item in tqdm.tqdm(sharegpt_data): + assert "id" in item, "Missing key 'id'" + assert "conversations" in item, "Missing key 'conversations'" + + # Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.) + conv_id, _ = item["id"].split("_") + new_turns = item["conversations"] + + if conv_id not in conversation_parts: + # Start new conversation + conversation_parts[conv_id] = [] + elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0: + prev_turns = conversation_parts[conv_id][-1] + if prev_turns[-1]["from"] == new_turns[0]["from"]: + new_turns = new_turns[1:] + + if len(new_turns) > 0: + # We assume that parts are in order in the ShareGPT dataset + conversation_parts[conv_id].append(new_turns) + + dataset: list[dict[str, Any]] = [] + for conv_id, conv_parts in conversation_parts.items(): + new_item = {"id": conv_id} + + conversations: list[dict[str, str]] = [] + + # Merge all parts + for conv_part in conv_parts: + conversations.extend(conv_part) + + if len(conversations) > 0: + new_item["conversations"] = conversations + dataset.append(new_item) + + print(f"Total unique conversations (IDs) in input file: {len(dataset):,}") + + # Final output data + final_openai_dataset: list[dict] = [] + + # Filter conversations from the ShareGPT dataset and convert to OpenAI format + for item in tqdm.tqdm(dataset): + messages: list[dict] = [] + + assert "id" in item, "Missing key 'id'" + assert "conversations" in item, "Missing key 'conversations'" + + conv_id = item["id"] + conversations = item["conversations"] + + if min_turns is not None and len(conversations) < min_turns: + # Skip short conversations + continue + + # Convert each message in the conversation, up to max_turns if specified + for i, turn in enumerate(conversations): + assert "from" in turn and "value" in turn, ( + f"Invalid conversation ID {conv_id} - missing 'from' or 'value'" + ) + + role = None + turn_from = turn["from"] + + if turn_from in {"human", "user"}: + role = "user" + elif turn_from in {"gpt", "bing", "chatgpt", "bard"}: + role = "assistant" + elif turn_from == "system": + role = "system" + + assert role is not None, ( + f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid" + ) + + if i == 0 and role != "user": + # If the first message is from assistant (gpt), skip it. + # this happens when the conversation is a follow-up + # to a previous conversation (from the same user). 
+ continue + + if max_turns is not None and i >= max_turns: + break + + # Convert message to OpenAI format (with "role" and "content") + content = turn["value"] + messages.append({"role": role, "content": content}) + + # Add the converted conversation to the OpenAI format + if len(messages) > 0: + valid_messages = True + + # First turn should always be from the user + user_turn = True + + for m in messages: + # Make sure that turns alternate between user and assistant + if (user_turn and m["role"] != "user") or ( + not user_turn and m["role"] != "assistant" + ): + valid_messages = False + break + + user_turn = not user_turn + + content = m["content"] + valid_messages = content_is_valid( + content, min_content_len, max_content_len + ) + if not valid_messages: + break + + if valid_messages is True: + final_openai_dataset.append({"id": conv_id, "messages": messages}) + + assert len(final_openai_dataset) > 0, "Final number of conversations is zero" + + print_stats(final_openai_dataset) + + print_stats_again = False + if max_items is not None and len(final_openai_dataset) > max_items: + print(f"\n\nSampling {max_items} items from the dataset...") + print_stats_again = True + final_openai_dataset = random.sample(final_openai_dataset, max_items) + + if print_stats_again: + # Print stats after the dataset changed + print_stats(final_openai_dataset, tokenizer) + + # Write the converted data to a new JSON file + final_size = len(final_openai_dataset) + print(f"\nTotal conversations converted (after filtering): {final_size:,}") + print(f"\nWriting file: {output_file}") + with open(output_file, "w", encoding="utf-8") as f: + json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert ShareGPT dataset to OpenAI API format" + ) + parser.add_argument("input_file", help="Path to the input ShareGPT JSON file") + parser.add_argument( + "output_file", help="Path to the output OpenAI format JSON file" + ) + parser.add_argument( + "--seed", type=int, default=0, help="Seed for random number generators" + ) + parser.add_argument( + "--max-items", + type=int, + default=None, + help="Maximum number of items in the output file", + ) + parser.add_argument( + "--min-turns", + type=int, + default=None, + help="Minimum number of turns per conversation", + ) + parser.add_argument( + "--max-turns", + type=int, + default=None, + help="Maximum number of turns per conversation", + ) + parser.add_argument( + "--min-content-len", + type=int, + default=None, + help="Min number of characters in the messages' content", + ) + parser.add_argument( + "--max-content-len", + type=int, + default=None, + help="Max number of characters in the messages' content", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LLM model, only the tokenizer will be used", + ) + + args = parser.parse_args() + + convert_sharegpt_to_openai( + args.seed, + args.input_file, + args.output_file, + args.max_items, + args.min_content_len, + args.max_content_len, + args.min_turns, + args.max_turns, + args.model, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bae656a5c5c4bd4f2bb544b32138cd31fd692c87 --- /dev/null +++ b/benchmarks/multi_turn/requirements.txt @@ -0,0 +1,6 @@ +numpy>=1.24 +pandas>=2.0.0 +aiohttp>=3.10 +transformers>=4.46 +xlsxwriter>=3.2.1 +tqdm>=4.66 diff --git 
a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py
new file mode 100644
index 0000000000000000000000000000000000000000..178599952d5c4e81145c73676b8ab0d4eaef6aa6
--- /dev/null
+++ b/benchmarks/overheads/benchmark_hashing.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import cProfile
+import pstats
+
+from vllm import LLM, SamplingParams
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+# A very long prompt, total number of tokens is about 15k.
+LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
+LONG_PROMPT = " ".join(LONG_PROMPT)
+
+
+def main(args):
+    llm = LLM(
+        model=args.model,
+        enforce_eager=True,
+        enable_prefix_caching=True,
+        tensor_parallel_size=args.tensor_parallel_size,
+    )
+
+    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+    profiler = cProfile.Profile()
+
+    print("------warm up------")
+    for i in range(3):
+        output = llm.generate(LONG_PROMPT, sampling_params)
+        print(output[0].outputs[0].text)
+
+    print("------start generating------")
+    for i in range(3):
+        profiler.runctx(
+            "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
+        )
+
+    # Analyze the runtime of the hashing function
+    stats = pstats.Stats(profiler)
+    stats.sort_stats("cumulative")
+    total_time = 0
+    total_calls = 0
+    for func in stats.stats:
+        if "hash_of_block" in func[2]:
+            total_time = stats.stats[func][3]
+            total_calls = stats.stats[func][0]
+    percentage = (total_time / stats.total_tt) * 100
+    print(
+        f"Hashing took {total_time:.2f} seconds, {percentage:.2f}% of the total runtime."
+    )
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description="Benchmark the performance of the hashing function in "
+        "automatic prefix caching."
+ ) + parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k") + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--enable-prefix-caching", action="store_true", help="enable prefix caching" + ) + args = parser.parse_args() + main(args) diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh new file mode 100644 index 0000000000000000000000000000000000000000..bc40ed83f438c69212feda8207f63fa000100121 --- /dev/null +++ b/benchmarks/run_structured_output_benchmark.sh @@ -0,0 +1,131 @@ +#!/bin/bash + +# default values +MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"} +BACKEND=${BACKEND:-"vllm"} +DATASET=${DATASET:-"xgrammar_bench"} +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"} +PORT=${PORT:-8000} +STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1} +TOTAL_SECONDS=${TOTAL_SECONDS:-90} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300} +TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"} + +usage() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --model MODEL Model to benchmark (default: $MODEL)" + echo " --backend BACKEND Backend to use (default: $BACKEND)" + echo " --dataset DATASET Dataset to use (default: $DATASET)" + echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)" + echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)" + echo " --port PORT Port to use (default: $PORT)" + echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)" + echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)" + echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)" + echo " -h, --help Show this help message and exit" + exit 0 +} + +# parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --backend) + BACKEND="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --max-new-tokens) + MAX_NEW_TOKENS="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --structured-output-ratio) + STRUCTURED_OUTPUT_RATIO="$2" + shift 2 + ;; + --tokenizer-mode) + TOKENIZER_MODE="$2" + shift 2 + ;; + --total-seconds) + TOTAL_SECONDS="$2" + shift 2 + ;; + -h|--help) + usage + ;; + *) + printf "Unknown argument: %s\n" "$1" + usage + ;; + esac +done + +# Create output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +# Define QPS values to test +QPS_VALUES=(25 20 15 10 5 1) + +# Common parameters +COMMON_PARAMS=( + --backend "$BACKEND" + --model "$MODEL" + --dataset "$DATASET" + --structured-output-ratio "$STRUCTURED_OUTPUT_RATIO" + --save-results + --result-dir "$OUTPUT_DIR" + --output-len "$MAX_NEW_TOKENS" + --port "$PORT" + --tokenizer-mode "$TOKENIZER_MODE" +) + +echo "Starting structured output benchmark with model: $MODEL" +echo "Backend: $BACKEND" +echo "Dataset: $DATASET" +echo "Results will be saved to: $OUTPUT_DIR" +echo "----------------------------------------" + +# Run benchmarks with different QPS values +for qps in "${QPS_VALUES[@]}"; do + echo "Running benchmark with QPS: $qps" + + # Get git hash and branch for the filename + GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown") + GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") + + # Construct filename 
for this run + FILENAME="${BACKEND}_${qps}qps_$(basename "$MODEL")_${DATASET}_${GIT_HASH}_${GIT_BRANCH}.json" + + NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc) + NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part + echo "Running benchmark with $NUM_PROMPTS prompts" + + # Run the benchmark + python "$SCRIPT_DIR/benchmark_serving_structured_output.py" "${COMMON_PARAMS[@]}" \ + --request-rate "$qps" \ + --result-filename "$FILENAME" \ + --num-prompts "$NUM_PROMPTS" + + echo "Completed benchmark with QPS: $qps" + echo "----------------------------------------" +done + +echo "All benchmarks completed!" +echo "Results saved to: $OUTPUT_DIR" diff --git a/benchmarks/sonnet.txt b/benchmarks/sonnet.txt new file mode 100644 index 0000000000000000000000000000000000000000..34c444e8ce8e2dc701ec80931401c57014ae0bd1 --- /dev/null +++ b/benchmarks/sonnet.txt @@ -0,0 +1,518 @@ +FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall beseige thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? +Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. 
+Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Lo! in the orient when the gracious light +Lifts up his burning head, each under eye +Doth homage to his new-appearing sight, +Serving with looks his sacred majesty; +And having climb'd the steep-up heavenly hill, +Resembling strong youth in his middle age, +yet mortal looks adore his beauty still, +Attending on his golden pilgrimage; +But when from highmost pitch, with weary car, +Like feeble age, he reeleth from the day, +The eyes, 'fore duteous, now converted are +From his low tract and look another way: +So thou, thyself out-going in thy noon, +Unlook'd on diest, unless thou get a son. +Music to hear, why hear'st thou music sadly? +Sweets with sweets war not, joy delights in joy. +Why lovest thou that which thou receivest not gladly, +Or else receivest with pleasure thine annoy? +If the true concord of well-tuned sounds, +By unions married, do offend thine ear, +They do but sweetly chide thee, who confounds +In singleness the parts that thou shouldst bear. +Mark how one string, sweet husband to another, +Strikes each in each by mutual ordering, +Resembling sire and child and happy mother +Who all in one, one pleasing note do sing: +Whose speechless song, being many, seeming one, +Sings this to thee: 'thou single wilt prove none.' +Is it for fear to wet a widow's eye +That thou consumest thyself in single life? +Ah! if thou issueless shalt hap to die. +The world will wail thee, like a makeless wife; +The world will be thy widow and still weep +That thou no form of thee hast left behind, +When every private widow well may keep +By children's eyes her husband's shape in mind. +Look, what an unthrift in the world doth spend +Shifts but his place, for still the world enjoys it; +But beauty's waste hath in the world an end, +And kept unused, the user so destroys it. +No love toward others in that bosom sits +That on himself such murderous shame commits. +For shame! deny that thou bear'st love to any, +Who for thyself art so unprovident. +Grant, if thou wilt, thou art beloved of many, +But that thou none lovest is most evident; +For thou art so possess'd with murderous hate +That 'gainst thyself thou stick'st not to conspire. +Seeking that beauteous roof to ruinate +Which to repair should be thy chief desire. 
+O, change thy thought, that I may change my mind! +Shall hate be fairer lodged than gentle love? +Be, as thy presence is, gracious and kind, +Or to thyself at least kind-hearted prove: +Make thee another self, for love of me, +That beauty still may live in thine or thee. +As fast as thou shalt wane, so fast thou growest +In one of thine, from that which thou departest; +And that fresh blood which youngly thou bestowest +Thou mayst call thine when thou from youth convertest. +Herein lives wisdom, beauty and increase: +Without this, folly, age and cold decay: +If all were minded so, the times should cease +And threescore year would make the world away. +Let those whom Nature hath not made for store, +Harsh featureless and rude, barrenly perish: +Look, whom she best endow'd she gave the more; +Which bounteous gift thou shouldst in bounty cherish: +She carved thee for her seal, and meant thereby +Thou shouldst print more, not let that copy die. +When I do count the clock that tells the time, +And see the brave day sunk in hideous night; +When I behold the violet past prime, +And sable curls all silver'd o'er with white; +When lofty trees I see barren of leaves +Which erst from heat did canopy the herd, +And summer's green all girded up in sheaves +Borne on the bier with white and bristly beard, +Then of thy beauty do I question make, +That thou among the wastes of time must go, +Since sweets and beauties do themselves forsake +And die as fast as they see others grow; +And nothing 'gainst Time's scythe can make defence +Save breed, to brave him when he takes thee hence. +O, that you were yourself! but, love, you are +No longer yours than you yourself here live: +Against this coming end you should prepare, +And your sweet semblance to some other give. +So should that beauty which you hold in lease +Find no determination: then you were +Yourself again after yourself's decease, +When your sweet issue your sweet form should bear. +Who lets so fair a house fall to decay, +Which husbandry in honour might uphold +Against the stormy gusts of winter's day +And barren rage of death's eternal cold? +O, none but unthrifts! Dear my love, you know +You had a father: let your son say so. +Not from the stars do I my judgment pluck; +And yet methinks I have astronomy, +But not to tell of good or evil luck, +Of plagues, of dearths, or seasons' quality; +Nor can I fortune to brief minutes tell, +Pointing to each his thunder, rain and wind, +Or say with princes if it shall go well, +By oft predict that I in heaven find: +But from thine eyes my knowledge I derive, +And, constant stars, in them I read such art +As truth and beauty shall together thrive, +If from thyself to store thou wouldst convert; +Or else of thee this I prognosticate: +Thy end is truth's and beauty's doom and date. +When I consider every thing that grows +Holds in perfection but a little moment, +That this huge stage presenteth nought but shows +Whereon the stars in secret influence comment; +When I perceive that men as plants increase, +Cheered and cheque'd even by the self-same sky, +Vaunt in their youthful sap, at height decrease, +And wear their brave state out of memory; +Then the conceit of this inconstant stay +Sets you most rich in youth before my sight, +Where wasteful Time debateth with Decay, +To change your day of youth to sullied night; +And all in war with Time for love of you, +As he takes from you, I engraft you new. +But wherefore do not you a mightier way +Make war upon this bloody tyrant, Time? 
+And fortify yourself in your decay +With means more blessed than my barren rhyme? +Now stand you on the top of happy hours, +And many maiden gardens yet unset +With virtuous wish would bear your living flowers, +Much liker than your painted counterfeit: +So should the lines of life that life repair, +Which this, Time's pencil, or my pupil pen, +Neither in inward worth nor outward fair, +Can make you live yourself in eyes of men. +To give away yourself keeps yourself still, +And you must live, drawn by your own sweet skill. +Who will believe my verse in time to come, +If it were fill'd with your most high deserts? +Though yet, heaven knows, it is but as a tomb +Which hides your life and shows not half your parts. +If I could write the beauty of your eyes +And in fresh numbers number all your graces, +The age to come would say 'This poet lies: +Such heavenly touches ne'er touch'd earthly faces.' +So should my papers yellow'd with their age +Be scorn'd like old men of less truth than tongue, +And your true rights be term'd a poet's rage +And stretched metre of an antique song: +But were some child of yours alive that time, +You should live twice; in it and in my rhyme. +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Devouring Time, blunt thou the lion's paws, +And make the earth devour her own sweet brood; +Pluck the keen teeth from the fierce tiger's jaws, +And burn the long-lived phoenix in her blood; +Make glad and sorry seasons as thou fleets, +And do whate'er thou wilt, swift-footed Time, +To the wide world and all her fading sweets; +But I forbid thee one most heinous crime: +O, carve not with thy hours my love's fair brow, +Nor draw no lines there with thine antique pen; +Him in thy course untainted do allow +For beauty's pattern to succeeding men. +Yet, do thy worst, old Time: despite thy wrong, +My love shall in my verse ever live young. +A woman's face with Nature's own hand painted +Hast thou, the master-mistress of my passion; +A woman's gentle heart, but not acquainted +With shifting change, as is false women's fashion; +An eye more bright than theirs, less false in rolling, +Gilding the object whereupon it gazeth; +A man in hue, all 'hues' in his controlling, +Much steals men's eyes and women's souls amazeth. +And for a woman wert thou first created; +Till Nature, as she wrought thee, fell a-doting, +And by addition me of thee defeated, +By adding one thing to my purpose nothing. +But since she prick'd thee out for women's pleasure, +Mine be thy love and thy love's use their treasure. +So is it not with me as with that Muse +Stirr'd by a painted beauty to his verse, +Who heaven itself for ornament doth use +And every fair with his fair doth rehearse +Making a couplement of proud compare, +With sun and moon, with earth and sea's rich gems, +With April's first-born flowers, and all things rare +That heaven's air in this huge rondure hems. 
+O' let me, true in love, but truly write, +And then believe me, my love is as fair +As any mother's child, though not so bright +As those gold candles fix'd in heaven's air: +Let them say more than like of hearsay well; +I will not praise that purpose not to sell. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +As an unperfect actor on the stage +Who with his fear is put besides his part, +Or some fierce thing replete with too much rage, +Whose strength's abundance weakens his own heart. +So I, for fear of trust, forget to say +The perfect ceremony of love's rite, +And in mine own love's strength seem to decay, +O'ercharged with burden of mine own love's might. +O, let my books be then the eloquence +And dumb presagers of my speaking breast, +Who plead for love and look for recompense +More than that tongue that more hath more express'd. +O, learn to read what silent love hath writ: +To hear with eyes belongs to love's fine wit. +Mine eye hath play'd the painter and hath stell'd +Thy beauty's form in table of my heart; +My body is the frame wherein 'tis held, +And perspective it is the painter's art. +For through the painter must you see his skill, +To find where your true image pictured lies; +Which in my bosom's shop is hanging still, +That hath his windows glazed with thine eyes. +Now see what good turns eyes for eyes have done: +Mine eyes have drawn thy shape, and thine for me +Are windows to my breast, where-through the sun +Delights to peep, to gaze therein on thee; +Yet eyes this cunning want to grace their art; +They draw but what they see, know not the heart. +Let those who are in favour with their stars +Of public honour and proud titles boast, +Whilst I, whom fortune of such triumph bars, +Unlook'd for joy in that I honour most. +Great princes' favourites their fair leaves spread +But as the marigold at the sun's eye, +And in themselves their pride lies buried, +For at a frown they in their glory die. +The painful warrior famoused for fight, +After a thousand victories once foil'd, +Is from the book of honour razed quite, +And all the rest forgot for which he toil'd: +Then happy I, that love and am beloved +Where I may not remove nor be removed. +Lord of my love, to whom in vassalage +Thy merit hath my duty strongly knit, +To thee I send this written embassage, +To witness duty, not to show my wit: +Duty so great, which wit so poor as mine +May make seem bare, in wanting words to show it, +But that I hope some good conceit of thine +In thy soul's thought, all naked, will bestow it; +Till whatsoever star that guides my moving +Points on me graciously with fair aspect +And puts apparel on my tatter'd loving, +To show me worthy of thy sweet respect: +Then may I dare to boast how I do love thee; +Till then not show my head where thou mayst prove me. 
+Weary with toil, I haste me to my bed, +The dear repose for limbs with travel tired; +But then begins a journey in my head, +To work my mind, when body's work's expired: +For then my thoughts, from far where I abide, +Intend a zealous pilgrimage to thee, +And keep my drooping eyelids open wide, +Looking on darkness which the blind do see +Save that my soul's imaginary sight +Presents thy shadow to my sightless view, +Which, like a jewel hung in ghastly night, +Makes black night beauteous and her old face new. +Lo! thus, by day my limbs, by night my mind, +For thee and for myself no quiet find. +How can I then return in happy plight, +That am debarr'd the benefit of rest? +When day's oppression is not eased by night, +But day by night, and night by day, oppress'd? +And each, though enemies to either's reign, +Do in consent shake hands to torture me; +The one by toil, the other to complain +How far I toil, still farther off from thee. +I tell the day, to please them thou art bright +And dost him grace when clouds do blot the heaven: +So flatter I the swart-complexion'd night, +When sparkling stars twire not thou gild'st the even. +But day doth daily draw my sorrows longer +And night doth nightly make grief's strength seem stronger. +When, in disgrace with fortune and men's eyes, +I all alone beweep my outcast state +And trouble deal heaven with my bootless cries +And look upon myself and curse my fate, +Wishing me like to one more rich in hope, +Featured like him, like him with friends possess'd, +Desiring this man's art and that man's scope, +With what I most enjoy contented least; +Yet in these thoughts myself almost despising, +Haply I think on thee, and then my state, +Like to the lark at break of day arising +From sullen earth, sings hymns at heaven's gate; +For thy sweet love remember'd such wealth brings +That then I scorn to change my state with kings. +When to the sessions of sweet silent thought +I summon up remembrance of things past, +I sigh the lack of many a thing I sought, +And with old woes new wail my dear time's waste: +Then can I drown an eye, unused to flow, +For precious friends hid in death's dateless night, +And weep afresh love's long since cancell'd woe, +And moan the expense of many a vanish'd sight: +Then can I grieve at grievances foregone, +And heavily from woe to woe tell o'er +The sad account of fore-bemoaned moan, +Which I new pay as if not paid before. +But if the while I think on thee, dear friend, +All losses are restored and sorrows end. +Thy bosom is endeared with all hearts, +Which I by lacking have supposed dead, +And there reigns love and all love's loving parts, +And all those friends which I thought buried. +How many a holy and obsequious tear +Hath dear religious love stol'n from mine eye +As interest of the dead, which now appear +But things removed that hidden in thee lie! +Thou art the grave where buried love doth live, +Hung with the trophies of my lovers gone, +Who all their parts of me to thee did give; +That due of many now is thine alone: +Their images I loved I view in thee, +And thou, all they, hast all the all of me. +If thou survive my well-contented day, +When that churl Death my bones with dust shall cover, +And shalt by fortune once more re-survey +These poor rude lines of thy deceased lover, +Compare them with the bettering of the time, +And though they be outstripp'd by every pen, +Reserve them for my love, not for their rhyme, +Exceeded by the height of happier men. 
+O, then vouchsafe me but this loving thought: +'Had my friend's Muse grown with this growing age, +A dearer birth than this his love had brought, +To march in ranks of better equipage: +But since he died and poets better prove, +Theirs for their style I'll read, his for his love.' +Full many a glorious morning have I seen +Flatter the mountain-tops with sovereign eye, +Kissing with golden face the meadows green, +Gilding pale streams with heavenly alchemy; +Anon permit the basest clouds to ride +With ugly rack on his celestial face, +And from the forlorn world his visage hide, +Stealing unseen to west with this disgrace: +Even so my sun one early morn did shine +With all triumphant splendor on my brow; +But out, alack! he was but one hour mine; +The region cloud hath mask'd him from me now. +Yet him for this my love no whit disdaineth; +Suns of the world may stain when heaven's sun staineth. +Why didst thou promise such a beauteous day, +And make me travel forth without my cloak, +To let base clouds o'ertake me in my way, +Hiding thy bravery in their rotten smoke? +'Tis not enough that through the cloud thou break, +To dry the rain on my storm-beaten face, +For no man well of such a salve can speak +That heals the wound and cures not the disgrace: +Nor can thy shame give physic to my grief; +Though thou repent, yet I have still the loss: +The offender's sorrow lends but weak relief +To him that bears the strong offence's cross. +Ah! but those tears are pearl which thy love sheds, +And they are rich and ransom all ill deeds. +No more be grieved at that which thou hast done: +Roses have thorns, and silver fountains mud; +Clouds and eclipses stain both moon and sun, +And loathsome canker lives in sweetest bud. +All men make faults, and even I in this, +Authorizing thy trespass with compare, +Myself corrupting, salving thy amiss, +Excusing thy sins more than thy sins are; +For to thy sensual fault I bring in sense-- +Thy adverse party is thy advocate-- +And 'gainst myself a lawful plea commence: +Such civil war is in my love and hate +That I an accessary needs must be +To that sweet thief which sourly robs from me. +Let me confess that we two must be twain, +Although our undivided loves are one: +So shall those blots that do with me remain +Without thy help by me be borne alone. +In our two loves there is but one respect, +Though in our lives a separable spite, +Which though it alter not love's sole effect, +Yet doth it steal sweet hours from love's delight. +I may not evermore acknowledge thee, +Lest my bewailed guilt should do thee shame, +Nor thou with public kindness honour me, +Unless thou take that honour from thy name: +But do not so; I love thee in such sort +As, thou being mine, mine is thy good report. +As a decrepit father takes delight +To see his active child do deeds of youth, +So I, made lame by fortune's dearest spite, +Take all my comfort of thy worth and truth. +For whether beauty, birth, or wealth, or wit, +Or any of these all, or all, or more, +Entitled in thy parts do crowned sit, +I make my love engrafted to this store: +So then I am not lame, poor, nor despised, +Whilst that this shadow doth such substance give +That I in thy abundance am sufficed +And by a part of all thy glory live. +Look, what is best, that best I wish in thee: +This wish I have; then ten times happy me! 
\ No newline at end of file diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dde8cc20751b295e85818400193a8cc5e8169a2b --- /dev/null +++ b/cmake/cpu_extension.cmake @@ -0,0 +1,427 @@ +include(FetchContent) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(MACOSX_FOUND TRUE) +endif() + + +# +# Define environment variables for special configurations +# +set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86}) +set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16}) + +include_directories("${CMAKE_SOURCE_DIR}/csrc") + +set (ENABLE_NUMA TRUE) + +# +# Check the compile flags +# +if(MACOSX_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-DVLLM_CPU_EXTENSION") +else() + list(APPEND CXX_COMPILE_FLAGS + "-fopenmp" + "-DVLLM_CPU_EXTENSION") +endif() + +if (NOT MACOSX_FOUND) + execute_process(COMMAND cat /proc/cpuinfo + RESULT_VARIABLE CPUINFO_RET + OUTPUT_VARIABLE CPUINFO) + if (NOT CPUINFO_RET EQUAL 0) + message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") + endif() +endif() + + +function (find_isa CPUINFO TARGET OUT) + string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) + if(NOT ISA_FOUND EQUAL -1) + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + + +function(check_sysctl TARGET OUT) + execute_process(COMMAND sysctl -n "${TARGET}" + RESULT_VARIABLE SYSCTL_RET + OUTPUT_VARIABLE SYSCTL_INFO + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(SYSCTL_RET EQUAL 0 AND + (SYSCTL_INFO STREQUAL "1" OR SYSCTL_INFO GREATER 0)) + set(${OUT} ON PARENT_SCOPE) + else() + set(${OUT} OFF PARENT_SCOPE) + endif() +endfunction() + +if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + message(STATUS "Apple Silicon Detected") + set(APPLE_SILICON_FOUND TRUE) + set(ENABLE_NUMA OFF) + check_sysctl(hw.optional.neon ASIMD_FOUND) + check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) +else() + find_isa(${CPUINFO} "Power11" POWER11_FOUND) + find_isa(${CPUINFO} "POWER10" POWER10_FOUND) + find_isa(${CPUINFO} "POWER9" POWER9_FOUND) + find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support + find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support + find_isa(${CPUINFO} "S390" S390_FOUND) + find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support + + # Support cross-compilation by allowing override via environment variables + if (ENABLE_ARM_BF16) + set(ARM_BF16_FOUND ON) + message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable") + endif() +endif() + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA) + set(ENABLE_X86_ISA ON) + if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)) + message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3") + endif() + list(APPEND CXX_COMPILE_FLAGS "-mf16c") + list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS}) + list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS}) + list(APPEND CXX_COMPILE_FLAGS_AVX512 + "-mavx512f" + "-mavx512vl" + "-mavx512bw" + "-mavx512dq" + "-mavx512bf16" + "-mavx512vnni" + "-mamx-bf16" + "-mamx-tile") + list(APPEND CXX_COMPILE_FLAGS_AVX2 + "-mavx2") +elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) + message(STATUS "PowerPC detected") + if (POWER9_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-mvsx" + "-mcpu=power9" + "-mtune=power9") + elseif (POWER10_FOUND OR 
POWER11_FOUND) + list(APPEND CXX_COMPILE_FLAGS + "-mvsx" + "-mcpu=power10" + "-mtune=power10") + endif() + +elseif (ASIMD_FOUND) + message(STATUS "ARMv8 or later architecture detected") + if(ARM_BF16_FOUND) + message(STATUS "BF16 extension detected") + set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16") + add_compile_definitions(ARM_BF16_SUPPORT) + else() + message(WARNING "BF16 functionality is not available") + set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") + endif() + list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS}) +elseif (S390_FOUND) + message(STATUS "S390 detected") + # Check for S390 VXE support + list(APPEND CXX_COMPILE_FLAGS + "-mvx" + "-mzvector" + "-march=native" + "-mtune=native") +elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64") + if(RVV_FOUND) + message(FAIL_ERROR "Can't support rvv now.") + else() + list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc") + endif() +else() + message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") +endif() + + +# Build oneDNN for GEMM kernels +if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) + # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 + # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN + set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "") + if(ASIMD_FOUND) + # Set number of parallel build processes + include(ProcessorCount) + ProcessorCount(NPROC) + if(NOT NPROC) + set(NPROC 4) + endif() + # locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0) + # and create a local shim dir with it + vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR) + + find_library(OPEN_MP + NAMES gomp + PATHS ${VLLM_TORCH_GOMP_SHIM_DIR} + NO_DEFAULT_PATH + REQUIRED + ) + # Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch + if (OPEN_MP) + set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}") + endif() + + # Fetch and populate ACL + if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}") + message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}") + else() + message(STATUS "Downloading Arm Compute Library (ACL) from GitHub") + FetchContent_Populate(arm_compute + SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild" + SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src" + GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git + GIT_TAG v52.6.0 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + ) + set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}") + set(ACL_LIB_DIR "$ENV{ACL_ROOT_DIR}/build") + endif() + + # Build ACL with CMake + set(_cmake_config_cmd + ${CMAKE_COMMAND} -G Ninja -B build + -DARM_COMPUTE_BUILD_SHARED_LIB=OFF + -DCMAKE_BUILD_TYPE=Release + -DARM_COMPUTE_ARCH=armv8.2-a + -DARM_COMPUTE_ENABLE_ASSERTS=OFF + -DARM_COMPUTE_ENABLE_CPPTHREADS=OFF + -DARM_COMPUTE_ENABLE_OPENMP=ON + -DARM_COMPUTE_ENABLE_WERROR=OFF + -DARM_COMPUTE_BUILD_EXAMPLES=OFF + -DARM_COMPUTE_BUILD_TESTING=OFF) + set(_cmake_build_cmd + ${CMAKE_COMMAND} --build build -- -j${NPROC} + ) + + execute_process( + COMMAND ${_cmake_config_cmd} + WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}" + ) + execute_process( + COMMAND ${_cmake_build_cmd} + WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}" + RESULT_VARIABLE _acl_rc + ) + + if(NOT _acl_rc EQUAL 0) + message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).") + endif() + message(STATUS "Arm Compute Library (ACL) built successfully.") + + # 
VLLM/oneDNN settings for ACL + set(ONEDNN_AARCH64_USE_ACL ON CACHE BOOL "" FORCE) + add_compile_definitions(VLLM_USE_ACL) + endif() + + set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.") + + if(FETCHCONTENT_SOURCE_DIR_ONEDNN) + message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}") + FetchContent_Declare( + oneDNN + SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN} + ) + else() + message(STATUS "Downloading oneDNN from GitHub") + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.10 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + endif() + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "ON") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "ON") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON") + set(ONEDNN_VERBOSE "ON") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + # TODO: Refactor this + if (ENABLE_X86_ISA) + # Note: only enable oneDNN for AVX512 + list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}) + else() + list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS}) + endif() + + set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) + set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size + FetchContent_MakeAvailable(oneDNN) + set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE}) + add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") + target_include_directories( + dnnl_ext + PUBLIC ${oneDNN_SOURCE_DIR}/include + PUBLIC ${oneDNN_BINARY_DIR}/include + PRIVATE ${oneDNN_SOURCE_DIR}/src + ) + target_link_libraries(dnnl_ext dnnl torch) + target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC) + list(APPEND LIBS dnnl_ext) + set(USE_ONEDNN ON) +else() + set(USE_ONEDNN OFF) +endif() + +# TODO: Refactor this +if (ENABLE_X86_ISA) + message(STATUS "CPU extension (AVX512) compile flags: ${CXX_COMPILE_FLAGS_AVX512}") + message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}") +else() + message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") +endif() + +if(ENABLE_NUMA) + list(APPEND LIBS numa) +else() + message(STATUS "NUMA is disabled") + add_compile_definitions(-DVLLM_NUMA_DISABLED) +endif() + +# +# Generate CPU attention dispatch header +# +message(STATUS "Generating CPU attention dispatch header") +execute_process( + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/csrc/cpu/generate_cpu_attn_dispatch.py + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/cpu + RESULT_VARIABLE GEN_RESULT +) +if(NOT GEN_RESULT EQUAL 0) + message(FATAL_ERROR "Failed to generate CPU attention dispatch header") +endif() + +# +# _C extension +# +set(VLLM_EXT_SRC + "csrc/cpu/activation.cpp" + "csrc/cpu/utils.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/mla_decode.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/torch_bindings.cpp") + +if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) + set(VLLM_EXT_SRC + "csrc/cpu/shm.cpp" + ${VLLM_EXT_SRC}) +endif() + +if(USE_ONEDNN) + set(VLLM_EXT_SRC + "csrc/cpu/dnnl_kernels.cpp" + ${VLLM_EXT_SRC}) +endif() + +if (ENABLE_X86_ISA) + set(VLLM_EXT_SRC_AVX512 + "csrc/cpu/sgl-kernels/gemm.cpp" + "csrc/cpu/sgl-kernels/gemm_int8.cpp" + "csrc/cpu/sgl-kernels/gemm_fp8.cpp" + 
"csrc/cpu/sgl-kernels/moe.cpp" + "csrc/cpu/sgl-kernels/moe_int8.cpp" + "csrc/cpu/sgl-kernels/moe_fp8.cpp" + "csrc/cpu/shm.cpp" + "csrc/cpu/cpu_wna16.cpp" + "csrc/cpu/cpu_fused_moe.cpp" + "csrc/cpu/utils.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/dnnl_kernels.cpp" + "csrc/cpu/torch_bindings.cpp" + # TODO: Remove these files + "csrc/cpu/activation.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/mla_decode.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + + set(VLLM_EXT_SRC_AVX2 + "csrc/cpu/utils.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/torch_bindings.cpp" + # TODO: Remove these files + "csrc/cpu/activation.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/mla_decode.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + + message(STATUS "CPU extension (AVX512) source files: ${VLLM_EXT_SRC_AVX512}") + message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}") + + define_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC_AVX512} + LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512} + USE_SABI 3 + WITH_SOABI + ) + + # For SGL kernels + target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AVX512") + # For AMX kernels + target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16") + + define_extension_target( + _C_AVX2 + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC_AVX2} + LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2} + USE_SABI 3 + WITH_SOABI + ) +else() + message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") + # + # Define extension targets + # + define_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + USE_SABI 3 + WITH_SOABI + ) +endif() + +message(STATUS "Enabling C extension.") diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0f16b9161fa3ca17faaad664b344d4a5d623f12e --- /dev/null +++ b/cmake/external_projects/flashmla.cmake @@ -0,0 +1,186 @@ +include(FetchContent) + +# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory +# instead of downloading. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{FLASH_MLA_SRC_DIR}) + set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR}) +endif() + +if(FLASH_MLA_SRC_DIR) + FetchContent_Declare( + flashmla + SOURCE_DIR ${FLASH_MLA_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + flashmla + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA + GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +endif() + + +FetchContent_MakeAvailable(flashmla) +message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") + +# Vendor FlashMLA interface into vLLM with torch-ops shim. 
+set(FLASHMLA_VENDOR_DIR "${CMAKE_SOURCE_DIR}/vllm/third_party/flashmla") +file(MAKE_DIRECTORY "${FLASHMLA_VENDOR_DIR}") +file(READ "${flashmla_SOURCE_DIR}/flash_mla/flash_mla_interface.py" + FLASHMLA_INTERFACE_CONTENT) +string(REPLACE "import flash_mla.cuda as flash_mla_cuda" + "import vllm._flashmla_C\nflash_mla_cuda = torch.ops._flashmla_C" + FLASHMLA_INTERFACE_CONTENT + "${FLASHMLA_INTERFACE_CONTENT}") +file(WRITE "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py" + "${FLASHMLA_INTERFACE_CONTENT}") + +# Install the generated flash_mla_interface.py to the wheel +# Use COMPONENT _flashmla_C to ensure it's installed with the C extension +install(FILES "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py" + DESTINATION vllm/third_party/flashmla/ + COMPONENT _flashmla_C) + +# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. +# Only build FlashMLA kernels if we are building for something compatible with +# sm90a + +set(SUPPORT_ARCHS) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3) + list(APPEND SUPPORT_ARCHS "9.0a") +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9) + # CUDA 12.9 has introduced "Family-Specific Architecture Features" + # this supports all compute_10x family + list(APPEND SUPPORT_ARCHS "10.0f") +elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + list(APPEND SUPPORT_ARCHS "10.0a") +endif() + + +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}") +if(FLASH_MLA_ARCHS) + message(STATUS "FlashMLA CUDA architectures: ${FLASH_MLA_ARCHS}") + set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS}) + list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math") + + set(FlashMLA_SOURCES + ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp + + # Misc kernels for decoding + ${flashmla_SOURCE_DIR}/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu + ${flashmla_SOURCE_DIR}/csrc/smxx/decode/combine/combine.cu + + # sm90 dense decode + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/instantiations/fp16.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/instantiations/bf16.cu + + # sm90 sparse decode + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu + + # sm90 sparse prefill + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu + + # sm100 dense prefill & backward + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu + + # sm100 sparse prefill + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu + 
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu + + # sm100 sparse decode + ${flashmla_SOURCE_DIR}/csrc/sm100/decode/head64/instantiations/v32.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/decode/head64/instantiations/model1.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu + ) + + set(FlashMLA_Extension_SOURCES + ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu + ) + + set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/kerutils/include + ${flashmla_SOURCE_DIR}/csrc/sm90 + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) + + set(FlashMLA_Extension_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/ + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) + + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_Extension_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + + define_extension_target( + _flashmla_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_SOURCES} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} + USE_SABI 3 + WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. + # Also enable C++20 for the FlashMLA sources (required for std::span, requires, etc.) + target_compile_options(_flashmla_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API> + $<$:-std=c++20> + $<$:-std=c++20>) + + define_extension_target( + _flashmla_extension_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_Extension_SOURCES} + COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES} + USE_SABI 3 + WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. 
+ target_compile_options(_flashmla_extension_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) +else() + message(STATUS "FlashMLA will not compile: unsupported CUDA architecture ${CUDA_ARCHS}") + # Create empty targets for setup.py on unsupported systems + add_custom_target(_flashmla_C) + add_custom_target(_flashmla_extension_C) +endif() + diff --git a/cmake/external_projects/qutlass.cmake b/cmake/external_projects/qutlass.cmake new file mode 100644 index 0000000000000000000000000000000000000000..84bb1b00c1bba0fecb96ad2193587d9e52967040 --- /dev/null +++ b/cmake/external_projects/qutlass.cmake @@ -0,0 +1,102 @@ +include(FetchContent) + +set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory") + +if(DEFINED ENV{QUTLASS_SRC_DIR}) + set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR}) +endif() + +if(QUTLASS_SRC_DIR) + FetchContent_Declare( + qutlass + SOURCE_DIR ${QUTLASS_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + qutlass + GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git + GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +endif() + +FetchContent_Populate(qutlass) + +if(NOT qutlass_SOURCE_DIR) + message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.") +endif() +message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") + +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0f" "${CUDA_ARCHS}") +else() + cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a;10.3a" "${CUDA_ARCHS}") +endif() + +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS) + + if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)") + set(QUTLASS_TARGET_CC 100) + elseif(QUTLASS_ARCHS MATCHES "12\\.0a") + set(QUTLASS_TARGET_CC 120) + else() + message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") + endif() + + set(QUTLASS_SOURCES + ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu + ) + + set(QUTLASS_INCLUDES + ${qutlass_SOURCE_DIR} + ${qutlass_SOURCE_DIR}/qutlass + ${qutlass_SOURCE_DIR}/qutlass/csrc/include + ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions + ) + + if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}") + elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include") + message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).") + else() + message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. 
" + "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include") + endif() + + set_gencode_flags_for_srcs( + SRCS "${QUTLASS_SOURCES}" + CUDA_ARCHS "${QUTLASS_ARCHS}" + ) + + target_sources(_C PRIVATE ${QUTLASS_SOURCES}) + target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES}) + target_compile_definitions(_C PRIVATE + QUTLASS_DISABLE_PYBIND=1 + TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC} + ) + + set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS + $<$:--expt-relaxed-constexpr --use_fast_math -O3> + ) + +else() + if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8") + message(STATUS + "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") + else() + message(STATUS + "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " + "CUDA_ARCHS='${CUDA_ARCHS}'.") + endif() +endif() diff --git a/cmake/external_projects/triton_kernels.cmake b/cmake/external_projects/triton_kernels.cmake new file mode 100644 index 0000000000000000000000000000000000000000..1d8b9779c8f72624d3dc72508436dc14eb72d2dd --- /dev/null +++ b/cmake/external_projects/triton_kernels.cmake @@ -0,0 +1,53 @@ +# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels + +set(DEFAULT_TRITON_KERNELS_TAG "v3.6.0") + +# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to +# be directly set to the triton_kernels python directory. +if (DEFINED ENV{TRITON_KERNELS_SRC_DIR}) + message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}") + FetchContent_Declare( + triton_kernels + SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR} + ) + +else() + set(TRITON_GIT "https://github.com/triton-lang/triton.git") + message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}") + FetchContent_Declare( + triton_kernels + # TODO (varun) : Fetch just the triton_kernels directory from Triton + GIT_REPOSITORY https://github.com/triton-lang/triton.git + GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG} + GIT_PROGRESS TRUE + SOURCE_SUBDIR python/triton_kernels/triton_kernels + ) +endif() + +# Fetch content +FetchContent_MakeAvailable(triton_kernels) + +if (NOT triton_kernels_SOURCE_DIR) + message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR") +endif() + +if (DEFINED ENV{TRITON_KERNELS_SRC_DIR}) + set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/") +else() + set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/") +endif() + +message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}") + +add_custom_target(triton_kernels) + +# Ensure the vllm/third_party directory exists before installation +install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")") + +## Copy .py files to install directory. 
+install(DIRECTORY + ${TRITON_KERNELS_PYTHON_DIR} + DESTINATION + vllm/third_party/triton_kernels/ + COMPONENT triton_kernels + FILES_MATCHING PATTERN "*.py") diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dd184e38eb5ec0e88df349c6007a258b2333429c --- /dev/null +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -0,0 +1,104 @@ +# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() + +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. +# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3), +# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files). +# If no component is specified, vllm-flash-attn is still installed. + +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) +endif() + +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2 + GIT_PROGRESS TRUE + # Don't share the vllm-flash-attn build between build types + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +endif() + +# Make sure vllm-flash-attn install rules are nested under vllm/ +# ALL_COMPONENTS ensures the save/modify/restore runs exactly once regardless +# of how many components are being installed, avoiding double-append of /vllm/. 
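+# At install time the rules below run in order: the current CMAKE_INSTALL_PREFIX
+# is saved, /vllm/ is appended so every install rule declared by the fetched
+# vllm-flash-attn project lands under vllm/, and after FetchContent_MakeAvailable
+# the original prefix is restored for the remaining vLLM install rules.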
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS) +install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS) + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Restore the install prefix after FA's install rules +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + +# Install shared Python files for both FA2 and FA3 components +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + # Ensure the vllm/vllm_flash_attn directory exists before installation + install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" + COMPONENT ${_FA_COMPONENT}) + + # Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py + # which are source-controlled in vllm) + install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm/vllm_flash_attn + COMPONENT ${_FA_COMPONENT} + FILES_MATCHING PATTERN "*.py" + PATTERN "__init__.py" EXCLUDE + PATTERN "flash_attn_interface.py" EXCLUDE + ) + +endforeach() + +# +# FA4 CuteDSL component +# This is a Python-only component that copies the flash_attn/cute directory +# and transforms imports to match our package structure. +# +add_custom_target(_vllm_fa4_cutedsl_C) + +# Copy flash_attn/cute directory (needed for FA4) and transform imports +# The cute directory uses flash_attn.cute imports internally, which we replace +# with vllm.vllm_flash_attn.cute to match our package structure. +install(CODE " + file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\") + foreach(SRC_FILE \${CUTE_PY_FILES}) + file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE}) + set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\") + get_filename_component(DST_DIR \${DST_FILE} DIRECTORY) + file(MAKE_DIRECTORY \${DST_DIR}) + file(READ \${SRC_FILE} FILE_CONTENTS) + string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\") + file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\") + endforeach() +" COMPONENT _vllm_fa4_cutedsl_C) diff --git a/cmake/hipify.py b/cmake/hipify.py new file mode 100644 index 0000000000000000000000000000000000000000..8504f9defee96bcd1d3f6eb2698c78061b41bcec --- /dev/null +++ b/cmake/hipify.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# +# A command line tool for running pytorch's hipify preprocessor on CUDA +# source files. +# +# See https://github.com/ROCm/hipify_torch +# and /utils/hipify/hipify_python.py +# + +import argparse +import os +import shutil + +from torch.utils.hipify.hipify_python import hipify + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Project directory where all the source + include files live. + parser.add_argument( + "-p", + "--project_dir", + help="The project directory.", + ) + + # Directory where hipified files are written. + parser.add_argument( + "-o", + "--output_dir", + help="The output directory.", + ) + + # Source files to convert. 
+ parser.add_argument( + "sources", help="Source files to hipify.", nargs="*", default=[] + ) + + args = parser.parse_args() + + # Limit include scope to project_dir only + includes = [os.path.join(args.project_dir, "*")] + + # Get absolute path for all source files. + extra_files = [os.path.abspath(s) for s in args.sources] + + # Copy sources from project directory to output directory. + # The directory might already exist to hold object files so we ignore that. + shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) + + hipify_result = hipify( + project_directory=args.project_dir, + output_directory=args.output_dir, + header_include_dirs=[], + includes=includes, + extra_files=extra_files, + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) + + hipified_sources = [] + for source in args.sources: + s_abs = os.path.abspath(source) + hipified_s_abs = ( + hipify_result[s_abs].hipified_path + if ( + s_abs in hipify_result + and hipify_result[s_abs].hipified_path is not None + ) + else s_abs + ) + hipified_sources.append(hipified_s_abs) + + assert len(hipified_sources) == len(args.sources) + + # Print hipified source files. + print("\n".join(hipified_sources)) diff --git a/cmake/utils.cmake b/cmake/utils.cmake new file mode 100644 index 0000000000000000000000000000000000000000..bdb2ba74d944d91c6a487e810cb4d9d54fbbf2f2 --- /dev/null +++ b/cmake/utils.cmake @@ -0,0 +1,548 @@ +# +# Attempt to find the python package that uses the same python executable as +# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`. +# +macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) + file(REAL_PATH ${EXECUTABLE} EXECUTABLE) + set(Python_EXECUTABLE ${EXECUTABLE}) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) + if (NOT Python_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") + endif() + set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}") + set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN}) + if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST) + message(FATAL_ERROR + "Python version (${_VER}) is not one of the supported versions: " + "${_SUPPORTED_VERSIONS_LIST}.") + endif() + message(STATUS "Found python matching: ${EXECUTABLE}.") +endmacro() + +# +# Run `EXPR` in python. The standard output of python is stored in `OUT` and +# has trailing whitespace stripped. If an error is encountered when running +# python, a fatal message `ERR_MSG` is issued. +# +function (run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${Python_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. +macro (append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() + +# +# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set +# of CUDA source files. The names of the corresponding "hipified" sources are +# stored in `OUT_SRCS`. 
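+#
+# Illustrative call (the extension name and sources are hypothetical):
+#   hipify_sources_target(MY_SRCS _my_ext "csrc/foo.cu;csrc/bar.cpp")
+# leaves bar.cpp untouched, maps foo.cu to
+# ${CMAKE_CURRENT_BINARY_DIR}/csrc/foo.hip, and creates a `hipify_my_ext`
+# custom target that runs cmake/hipify.py to generate the .hip file.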
+# +function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS) + # + # Split into C++ and non-C++ (i.e. CUDA) sources. + # + set(SRCS ${ORIG_SRCS}) + set(CXX_SRCS ${ORIG_SRCS}) + list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$") + list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$") + + # + # Generate ROCm/HIP source file names from CUDA file names. + # Since HIP files are generated code, they will appear in the build area + # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir. + # + set(HIP_SRCS) + foreach (SRC ${SRCS}) + string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC}) + string(REGEX REPLACE "cuda" "hip" SRC ${SRC}) + list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}") + endforeach() + + set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc) + add_custom_target( + hipify${NAME} + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS} + DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS} + BYPRODUCTS ${HIP_SRCS} + COMMENT "Running hipify on ${NAME} extension source files.") + + # Swap out original extension sources with hipified sources. + list(APPEND HIP_SRCS ${CXX_SRCS}) + set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE) +endfunction() + +# +# Get additional GPU compiler flags from torch. +# +function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) + if (${GPU_LANG} STREQUAL "CUDA") + # + # Get common NVCC flags from torch. + # + run_python(GPU_FLAGS + "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))" + "Failed to determine torch nvcc compiler flags") + + if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) + list(APPEND GPU_FLAGS "-DENABLE_FP8") + endif() + if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) + list(REMOVE_ITEM GPU_FLAGS + "-D__CUDA_NO_HALF_OPERATORS__" + "-D__CUDA_NO_HALF_CONVERSIONS__" + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" + "-D__CUDA_NO_HALF2_OPERATORS__") + endif() + + elseif(${GPU_LANG} STREQUAL "HIP") + # + # Get common HIP/HIPCC flags from torch. + # + run_python(GPU_FLAGS + "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))" + "Failed to determine torch nvcc compiler flags") + + list(APPEND GPU_FLAGS + "-DUSE_ROCM" + "-DENABLE_FP8" + "-U__HIP_NO_HALF_CONVERSIONS__" + "-U__HIP_NO_HALF_OPERATORS__" + "-Werror=unused-variable" + "-fno-gpu-rdc") + + endif() + set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) +endfunction() + +# Find libgomp that gets shipped with PyTorch wheel and create a shim dir with: +# libgomp.so -> libgomp-.so... +# libgomp.so.1 -> libgomp-.so... +# OUTPUT: TORCH_GOMP_SHIM_DIR ("" if not found) +function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR) + set(${TORCH_GOMP_SHIM_DIR} "" PARENT_SCOPE) + + # Use run_python to locate vendored libgomp; never throw on failure. 
+ run_python(_VLLM_TORCH_GOMP_PATH + " +import os, glob +import torch +torch_pkg = os.path.dirname(torch.__file__) +site_root = os.path.dirname(torch_pkg) + +# Search both torch.libs and torch/lib +roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')] +candidates = [] +for root in roots: + if not os.path.isdir(root): + continue + candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*'))) + +print(candidates[0] if candidates else '') +" + "failed to probe for libgomp") + + if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}") + return() + endif() + + # Create shim under the build tree + set(_shim "${CMAKE_BINARY_DIR}/gomp_shim") + file(MAKE_DIRECTORY "${_shim}") + + execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so") + execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so.1") + execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so") + execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so.1") + + set(${TORCH_GOMP_SHIM_DIR} "${_shim}" PARENT_SCOPE) +endfunction() + +# Macro for converting a `gencode` version number to a cmake version number. +macro(string_to_ver OUT_VER IN_STR) + string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) +endmacro() + +# +# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in +# `CUDA_ARCH_FLAGS`. +# +# Example: +# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" +# clear_cuda_arches(CUDA_ARCH_FLAGS) +# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" +# CMAKE_CUDA_FLAGS="-Wall" +# +macro(clear_cuda_arches CUDA_ARCH_FLAGS) + # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` + string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified + # and passed back via the `CUDA_ARCHITECTURES` property. + string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS + ${CMAKE_CUDA_FLAGS}) +endmacro() + +# +# Extract unique CUDA architectures from a list of compute capabilities codes in +# the form `[]`, convert them to the form sort +# `.`, dedupes them and then sorts them in ascending order and +# stores them in `OUT_ARCHES`. +# +# Example: +# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a" +# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS) +# OUT_ARCHES="7.5;...;9.0" +function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS) + set(_CUDA_ARCHES) + foreach(_ARCH ${CUDA_ARCH_FLAGS}) + string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) + if (_COMPUTE) + set(_COMPUTE ${CMAKE_MATCH_1}) + endif() + + string_to_ver(_COMPUTE_VER ${_COMPUTE}) + list(APPEND _CUDA_ARCHES ${_COMPUTE_VER}) + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHES) + list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING) + set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE) +endfunction() + +# +# For a specific file set the `-gencode` flag in compile options conditionally +# for the CUDA language. +# +# Example: +# set_gencode_flag_for_srcs( +# SRCS "foo.cu" +# ARCH "compute_75" +# CODE "sm_75") +# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for +# `foo.cu` (only for the CUDA language). 
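+#
+# Passing a virtual architecture for CODE embeds PTX instead of SASS, e.g.:
+# set_gencode_flag_for_srcs(
+#   SRCS "foo.cu"
+#   ARCH "compute_80"
+#   CODE "compute_80")
+# adds: "-gencode arch=compute_80,code=compute_80"; this is how the +PTX
+# suffix is handled by set_gencode_flags_for_srcs below.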
+# +macro(set_gencode_flag_for_srcs) + set(options) + set(oneValueArgs ARCH CODE) + set(multiValueArgs SRCS) + cmake_parse_arguments(arg "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE}) + set_property( + SOURCE ${arg_SRCS} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:${_FLAG}>" + ) + + message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}") +endmacro(set_gencode_flag_for_srcs) + +# +# For a list of source files set the `-gencode` flags in the files specific +# compile options (specifically for the CUDA language). +# +# arguments are: +# SRCS: list of source files +# CUDA_ARCHS: list of CUDA architectures in the form `.[letter]` +# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built +# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS +# that is larger than BUILD_PTX_FOR_ARCH. +# +macro(set_gencode_flags_for_srcs) + set(options) + set(oneValueArgs BUILD_PTX_FOR_ARCH) + set(multiValueArgs SRCS CUDA_ARCHS) + cmake_parse_arguments(arg "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + foreach(_ARCH ${arg_CUDA_ARCHS}) + # handle +PTX suffix: generate both sm and ptx codes if requested + string(FIND "${_ARCH}" "+PTX" _HAS_PTX) + if(NOT _HAS_PTX EQUAL -1) + string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}") + string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "compute_${_STRIPPED_ARCH}") + else() + string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + endif() + endforeach() + + if (${arg_BUILD_PTX_FOR_ARCH}) + list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH) + if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH}) + string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_PTX_ARCH}" + CODE "compute_${_PTX_ARCH}") + endif() + endif() +endmacro() + +# +# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form +# `.[letter]` compute the "loose intersection" with the +# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in +# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there +# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the +# architecture in `SRC_CUDA_ARCHS`. +# The loose intersection is defined as: +# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } +# where `<=` is the version comparison operator. +# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version +# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. +# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is +# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add +# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). +# The result is stored in `OUT_CUDA_ARCHS`. 
+# +# Example: +# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a" +# TGT_CUDA_ARCHS="8.0;8.9;9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" +# +# Example With PTX: +# SRC_CUDA_ARCHS="8.0+PTX" +# TGT_CUDA_ARCHS="9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0+PTX" +# +function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) + set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}") + set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS}) + + # handle +PTX suffix: separate base arch for matching, record PTX requests + set(_PTX_ARCHS) + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\+PTX$") + string(REPLACE "+PTX" "" _base "${_arch}") + list(APPEND _PTX_ARCHS "${_base}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + list(APPEND _SRC_CUDA_ARCHS "${_base}") + endif() + endforeach() + list(REMOVE_DUPLICATES _PTX_ARCHS) + list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) + + # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should + # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS + set(_CUDA_ARCHS) + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "[af]$") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + string(REGEX REPLACE "[af]$" "" _base "${_arch}") + if ("${_base}" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") + list(APPEND _CUDA_ARCHS "${_arch}") + endif() + endif() + endforeach() + + list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + + # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that + # is less or equal to ARCH (but has the same major version since SASS binary + # compatibility is only forward compatible within the same major version). + foreach(_ARCH ${_TGT_CUDA_ARCHS}) + set(_TMP_ARCH) + # Extract the major version of the target arch + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") + foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS}) + # Extract the major version of the source arch + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") + # Check version-less-or-equal, and allow PTX arches to match across majors + if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) + if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + set(_TMP_ARCH "${_SRC_ARCH}") + endif() + else() + # If we hit a version greater than the target, we can break + break() + endif() + endforeach() + + # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS + if (_TMP_ARCH) + list(APPEND _CUDA_ARCHS "${_TMP_ARCH}") + endif() + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHS) + + # reapply +PTX suffix to architectures that requested PTX + set(_FINAL_ARCHS) + foreach(_arch ${_CUDA_ARCHS}) + if(_arch IN_LIST _PTX_ARCHS) + list(APPEND _FINAL_ARCHS "${_arch}+PTX") + else() + list(APPEND _FINAL_ARCHS "${_arch}") + endif() + endforeach() + set(_CUDA_ARCHS ${_FINAL_ARCHS}) + + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) +endfunction() + +# +# Override the GPU architectures detected by cmake/torch and filter them by +# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in +# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set +# the architectures on a per file basis. +# +# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. 
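+#
+# Illustrative call (the architecture list is hypothetical):
+#   override_gpu_arches(VLLM_GPU_ARCHES HIP "gfx906;gfx90a;gfx942")
+# keeps only the detected ROCm architectures (from PYTORCH_ROCM_ARCH or
+# CMAKE_HIP_ARCHITECTURES) that also appear in the supported list, and fails
+# with an error if the intersection is empty.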
+# +macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) + set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN}) + message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}") + + if (${GPU_LANG} STREQUAL "HIP") + # + # `GPU_ARCHES` controls the `--offload-arch` flags. + # + # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, + # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling + # "rocm_agent_enumerator" in "enable_language(HIP)" + # (in file Modules/CMakeDetermineHIPCompiler.cmake) + # + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) + else() + set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) + endif() + # + # Find the intersection of the supported + detected architectures to + # set the module architecture flags. + # + set(${GPU_ARCHES}) + foreach (_ARCH ${HIP_ARCHITECTURES}) + if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) + list(APPEND ${GPU_ARCHES} ${_ARCH}) + endif() + endforeach() + + if(NOT ${GPU_ARCHES}) + message(FATAL_ERROR + "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" + " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") + endif() + endif() +endmacro() + +# +# Define a target named `MOD_NAME` for a single extension. The +# arguments are: +# +# DESTINATION - Module destination directory. +# LANGUAGE - The language for this module, e.g. CUDA, HIP, +# CXX, etc. +# SOURCES - List of source files relative to CMakeLists.txt +# directory. +# +# Optional arguments: +# +# ARCHITECTURES - A list of target architectures in cmake format. +# For GPU, refer to CMAKE_CUDA_ARCHITECTURES and +# CMAKE_HIP_ARCHITECTURES for more info. +# ARCHITECTURES will use cmake's defaults if +# not provided. +# COMPILE_FLAGS - Extra compiler flags passed to NVCC/hip. +# INCLUDE_DIRECTORIES - Extra include directories. +# LIBRARIES - Extra link libraries. +# WITH_SOABI - Generate library with python SOABI suffix name. +# USE_SABI - Use python stable api +# +# Note: optimization level/debug info is set via cmake build type. +# +function (define_extension_target MOD_NAME) + cmake_parse_arguments(PARSE_ARGV 1 + ARG + "WITH_SOABI" + "DESTINATION;LANGUAGE;USE_SABI" + "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") + + # Add hipify preprocessing step when building with HIP/ROCm. + if (ARG_LANGUAGE STREQUAL "HIP") + hipify_sources_target(ARG_SOURCES ${MOD_NAME} "${ARG_SOURCES}") + endif() + + if (ARG_WITH_SOABI) + set(SOABI_KEYWORD WITH_SOABI) + else() + set(SOABI_KEYWORD "") + endif() + + run_python(IS_FREETHREADED_PYTHON + "import sysconfig; print(1 if sysconfig.get_config_var(\"Py_GIL_DISABLED\") else 0)" + "Failed to determine whether interpreter is free-threaded") + + # Free-threaded Python doesn't yet support the stable ABI (see PEP 803/809), + # so avoid using the stable ABI under free-threading only. + if (ARG_USE_SABI AND NOT IS_FREETHREADED_PYTHON) + Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}") + else() + Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}") + endif() + + if (ARG_LANGUAGE STREQUAL "HIP") + # Make this target dependent on the hipify preprocessor step. 
+ add_dependencies(${MOD_NAME} hipify${MOD_NAME}) + # Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder + target_include_directories(${MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc + ${ARG_INCLUDE_DIRECTORIES}) + else() + target_include_directories(${MOD_NAME} PRIVATE csrc + ${ARG_INCLUDE_DIRECTORIES}) + endif() + + if (ARG_ARCHITECTURES) + set_target_properties(${MOD_NAME} PROPERTIES + ${ARG_LANGUAGE}_ARCHITECTURES "${ARG_ARCHITECTURES}") + endif() + + target_compile_options(${MOD_NAME} PRIVATE + $<$:${ARG_COMPILE_FLAGS}>) + + target_compile_definitions(${MOD_NAME} PRIVATE + "-DTORCH_EXTENSION_NAME=${MOD_NAME}") + + target_link_libraries(${MOD_NAME} PRIVATE torch ${ARG_LIBRARIES}) + + # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of + # dependencies that are not necessary and may not be installed. + if (ARG_LANGUAGE STREQUAL "CUDA") + target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES}) + else() + target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES}) + endif() + + install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME}) +endfunction() diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000000000000000000000000000000000000..304c0be8105fc2ee3596efd937d505ed9ba7d354 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,12 @@ +codecov: + require_ci_to_pass: false + +fixes: + # Map source code paths to repository root paths + # Wildcards match any Python version (python3.*) + - "/vllm-workspace/src/vllm/::vllm/" + - "/vllm-workspace/vllm/::vllm/" + - "/usr/local/lib/python3.*/dist-packages/vllm/::vllm/" + - "/usr/local/lib/python3.*/site-packages/vllm/::vllm/" + - "/usr/lib/python3.*/dist-packages/vllm/::vllm/" + - "/usr/lib/python3.*/site-packages/vllm/::vllm/" diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..758a777955535e0a948f63c810a5fdef4c1b1e11 --- /dev/null +++ b/csrc/activation_kernels.cu @@ -0,0 +1,587 @@ +#include +#include +#include + +#include + +#include "cuda_compat.h" +#include "cuda_vec_utils.cuh" +#include "dispatch_utils.h" + +namespace vllm { + +template +__device__ __forceinline__ scalar_t compute(const scalar_t& x, + const scalar_t& y) { + return act_first ? ACT_FN(x) * y : x * ACT_FN(y); +} + +template +__device__ __forceinline__ packed_t packed_compute(const packed_t& x, + const packed_t& y) { + return act_first ? packed_mul(PACKED_ACT_FN(x), y) + : packed_mul(x, PACKED_ACT_FN(y)); +} + +// Activation and gating kernel template. 
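+// One thread block handles one token: the input row is laid out as
+// [x(0..d-1) | y(0..d-1)]; ACT_FN is applied to x (or to y when act_first is
+// false) and multiplied element-wise with the other half into out[0..d-1].
+// The use_vec path uses 128-bit (or, with use_256b, 256-bit) packed
+// loads/stores; otherwise a plain scalar loop is used.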
+template +__global__ void act_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const scalar_t* x_ptr = input + blockIdx.x * 2 * d; + const scalar_t* y_ptr = x_ptr + d; + scalar_t* out_ptr = out + blockIdx.x * d; + + if constexpr (use_vec) { + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; + + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + pvec_t x, y; + if constexpr (use_256b) { + ld256(x, &x_vec[i]); + ld256(y, &y_vec[i]); + } else { + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); + } +#pragma unroll + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_compute( + x.elts[j], y.elts[j]); + } + if constexpr (use_256b) { + st256(x, &out_vec[i]); + } else { + st128(x, &out_vec[i]); + } + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = compute(x, y); + } + } +} + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ packed_t packed_silu_kernel(const packed_t& val) { + // x * sigmoid(x) + float2 fval = cast_to_float2(val); + fval.x = fval.x / (1.0f + expf(-fval.x)); + fval.y = fval.y / (1.0f + expf(-fval.y)); + return cast_to_packed(fval); +} + +template +__device__ __forceinline__ T gelu_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'none' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 + const float f = (float)x; + constexpr float ALPHA = M_SQRT1_2; + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); +} + +template +__device__ __forceinline__ packed_t packed_gelu_kernel(const packed_t& val) { + // Equivalent to PyTorch GELU with 'none' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 + constexpr float ALPHA = M_SQRT1_2; + float2 fval = cast_to_float2(val); + fval.x = fval.x * 0.5f * (1.0f + ::erf(fval.x * ALPHA)); + fval.y = fval.y * 0.5f * (1.0f + ::erf(fval.y * ALPHA)); + return cast_to_packed(fval); +} + +template +__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'tanh' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 + const float f = (float)x; + constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; + constexpr float KAPPA = 0.044715; + float x_cube = f * f * f; + float inner = BETA * (f + KAPPA * x_cube); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); +} + +template +__device__ __forceinline__ packed_t +packed_gelu_tanh_kernel(const packed_t& val) { + // Equivalent to PyTorch GELU with 'tanh' approximation. 
+ // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 + float2 fval = cast_to_float2(val); + constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; + constexpr float KAPPA = 0.044715; + + float x_cube = fval.x * fval.x * fval.x; + float inner = BETA * (fval.x + KAPPA * x_cube); + fval.x = 0.5f * fval.x * (1.0f + ::tanhf(inner)); + + x_cube = fval.y * fval.y * fval.y; + inner = BETA * (fval.y + KAPPA * x_cube); + fval.y = 0.5f * fval.y * (1.0f + ::tanhf(inner)); + return cast_to_packed(fval); +} + +} // namespace vllm + +// Launch activation and gating kernel. +// Use ACT_FIRST (bool) indicating whether to apply the activation function +// first. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, true><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel, + true); +} + +void mul_and_silu(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + // The difference between mul_and_silu and silu_and_mul is that mul_and_silu + // applies the silu to the latter half of the input. 
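+  // Concretely, with the input row split into halves x and y:
+  //   silu_and_mul: out = silu(x) * y, where silu(x) = x * sigmoid(x)
+  //   mul_and_silu: out = x * silu(y)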
+ LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel, + false); +} + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel, + true); +} + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, + vllm::packed_gelu_tanh_kernel, true); +} + +namespace vllm { + +template +__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { + const float f = (float)x; + return (T)(f > threshold ? f : 0.0f); +} + +template +__device__ __forceinline__ packed_t +packed_fatrelu_kernel(const packed_t& val, const float threshold) { + float2 fval = cast_to_float2(val); + fval.x = fval.x > threshold ? fval.x : 0.0f; + fval.y = fval.y > threshold ? fval.y : 0.0f; + return cast_to_packed(fval); +} + +template +__global__ void act_and_mul_kernel_with_param( + scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, + const float param) { + const scalar_t* x_ptr = input + blockIdx.x * 2 * d; + const scalar_t* y_ptr = x_ptr + d; + scalar_t* out_ptr = out + blockIdx.x * d; + + if constexpr (use_vec) { + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; + + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + pvec_t x, y; + if constexpr (use_256b) { + ld256(x, &x_vec[i]); + ld256(y, &y_vec[i]); + } else { + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); + } +#pragma unroll + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]); + } + if constexpr (use_256b) { + st256(x, &out_vec[i]); + } else { + st128(x, &out_vec[i]); + } + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = ACT_FN(x, param) * y; + } + } +} + +template +__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up, + float alpha, float limit) { + // Clamp gate to (-inf, limit] and up to [-limit, limit] + const float g = fminf((float)gate, limit); + const float u = fmaxf(fminf((float)up, limit), -limit); + // glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu + return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha))); +} + +// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...]. +template +__global__ void swigluoai_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2 * d] (interleaved) + const int d, const float alpha, const float limit) { + // For interleaved data: input has 2*d elements per token (gate/up pairs) + // output has d elements per token + constexpr int VEC_SIZE = 16 / sizeof(scalar_t); + constexpr int PAIRS = VEC_SIZE / 2; // Number of gate/up pairs per int4 load + const int64_t token_idx = blockIdx.x; + const scalar_t* in_ptr = input + token_idx * 2 * d; + scalar_t* out_ptr = out + token_idx * d; + + // Check alignment for 128-bit vectorized access on input. + // For output we use int2 (64-bit) which has 8-byte alignment requirement. 
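+  // Each 16-byte int4 load covers PAIRS interleaved (gate, up) pairs and
+  // produces PAIRS outputs, i.e. 8 bytes, which is why an int2 store is
+  // sufficient on the output side.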
+ const bool in_aligned = is_16byte_aligned(in_ptr); + const bool out_aligned = + (reinterpret_cast(out_ptr) & 7) == 0; // 8-byte for int2 + + if (in_aligned && out_aligned && d >= PAIRS) { + // Fast path: vectorized loop + // Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs + // Each int2 store writes PAIRS output elements + const int4* in_vec = reinterpret_cast(in_ptr); + int2* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / PAIRS; + const int vec_end = num_vecs * PAIRS; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + int4 v = VLLM_LDG(&in_vec[i]); + int2 r; + auto* vp = reinterpret_cast(&v); + auto* rp = reinterpret_cast(&r); +#pragma unroll + for (int j = 0; j < PAIRS; j++) { + rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit); + } + out_vec[i] = r; + } + // Scalar cleanup for remaining elements + for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) { + out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]), + VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit); + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + // gate = x[..., ::2] (even indices) + const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]); + // up = x[..., 1::2] (odd indices) + const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]); + out_ptr[idx] = ACT_FN(gate, up, alpha, limit); + } + } +} + +} // namespace vllm + +#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PACKED_KERNEL, PARAM) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? 
vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES( \ + dtype, "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL< \ + typename vllm::PackedTypeConverter::Type>, \ + true, true><<>>( \ + out.data_ptr(), input.data_ptr(), d, \ + PARAM); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES( \ + dtype, "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL< \ + typename vllm::PackedTypeConverter::Type>, \ + true, false><<>>( \ + out.data_ptr(), input.data_ptr(), d, \ + PARAM); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + false><<>>( \ + out.data_ptr(), input.data_ptr(), d, PARAM); \ + }); \ + } + +#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ + vllm::swigluoai_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, ALPHA, \ + LIMIT); \ + }); + +void fatrelu_and_mul(torch::Tensor& out, // [..., d], + torch::Tensor& input, // [..., 2 * d] + double threshold) { + LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM( + vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold); +} +void swigluoai_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + double alpha, double limit) { + LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit); +} +namespace vllm { + +// Element-wise activation kernel template. 
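+// One thread block handles one token: threads stride over the token's d
+// elements, either as 128-bit/256-bit vectors (fast path) or one element at a
+// time (fallback), apply ACT_FN, and write the result to `out`.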
+template +__global__ void activation_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { + const scalar_t* in_ptr = input + blockIdx.x * d; + scalar_t* out_ptr = out + blockIdx.x * d; + + if constexpr (use_vec) { + // Fast path: 128-bit/256-bit vectorized loop + using vec_t = typename VecTraits::vec_t; + constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; + constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(scalar_t); + const vec_t* in_vec = reinterpret_cast(in_ptr); + vec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / VEC_SIZE; + + for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { + vec_t v; + if constexpr (use_256b) { + ld256(v, &in_vec[i]); + } else { + v = VLLM_LDG(&in_vec[i]); + } + auto* vp = reinterpret_cast(&v); +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + vp[j] = ACT_FN(vp[j]); + } + if constexpr (use_256b) { + st256(v, &out_vec[i]); + } else { + out_vec[i] = v; + } + } + } else { + // Scalar fallback for unaligned data or small d + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&in_ptr[idx]); + out_ptr[idx] = ACT_FN(x); + } + } +} + +} // namespace vllm + +// Launch element-wise activation kernel. +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, true> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, false> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, false> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); \ + } + +namespace vllm { + +template +__device__ __forceinline__ T gelu_new_kernel(const T& x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_fast_kernel(const T& x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_quick_kernel(const T& x) { + // x * sigmoid(1.702 * x) + return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x))); +} + +} // namespace vllm + +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); +} + +void 
gelu_fast(torch::Tensor& out,    // [..., d]
+               torch::Tensor& input)  // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
+
+void gelu_quick(torch::Tensor& out,    // [..., d]
+                torch::Tensor& input)  // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
+}
diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h
new file mode 100644
index 0000000000000000000000000000000000000000..64f86381d9db902a6ff04ebe9520d332d40ff1ff
--- /dev/null
+++ b/csrc/attention/attention_dtypes.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8.cuh"
diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..62409c0cce93e696cebcb69cb7b34526d6b26a47
--- /dev/null
+++ b/csrc/attention/attention_generic.cuh
@@ -0,0 +1,65 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template <typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template <typename T>
+inline __device__ float sum(T v);
+
+template <typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<float, T, T>(a, b));
+}
+
+template <typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+}  // namespace vllm
diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..052ff168cec4fe15d60711c7f8bf215043ea60b0
--- /dev/null
+++ b/csrc/attention/attention_kernels.cuh
@@ -0,0 +1,670 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" +#include "../cuda_compat.h" + +#ifdef USE_ROCM + #include + #include "../quantization/w8a8/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh" +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. 
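+    // (Illustration: with PARTITION_SIZE = 512 and seq_len = 1300, partitions
+    // 0-2 cover tokens [0, 1300) and any partition index >= 3 exits here.)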
+ return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. 
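+  // The first NUM_WARPS floats hold per-warp maxima for the qk_max reduction;
+  // the second NUM_WARPS floats are handed to block_sum() for the exp-sum
+  // reduction.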
+ __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. 
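+          // The key cache holds FP8 values here; scaled_convert() dequantizes
+          // them back to the compute type using the k_scale factor before the
+          // QK dot product.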
+ Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, *k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. 
However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + *v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. 
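+  // After the tree reduction above, warp 0 holds the fully accumulated rows,
+  // so only its lanes write the (partial) output for this partition.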
+ if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs). 
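+// Merges the per-partition results of paged_attention_v2_kernel: it recomputes
+// the global max logit, rescales each partition's exp-sum accordingly, and
+// combines the partial outputs weighted by those rescaled exp-sums.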
+template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = VLLM_SHFL_SYNC(max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
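+  // i.e. out[i] = sum_j tmp_out[j][i] * shared_exp_sums[j] / global_exp_sum,
+  // the softmax-consistent combination of the per-partition partial sums.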
+ const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..826b0edffae67f772828aefcd44f8a073bf892b9 --- /dev/null +++ b/csrc/attention/attention_utils.cuh @@ -0,0 +1,57 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../cuda_compat.h" +#include "attention_dtypes.h" + +#include +#include + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += VLLM_SHFL_XOR_SYNC(qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh new file mode 100644 index 0000000000000000000000000000000000000000..97a25baa1fc0de977f3068a7a6a901d27fcfa6ad --- /dev/null +++ b/csrc/attention/dtype_bfloat16.cuh @@ -0,0 +1,463 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#include + +namespace vllm { + +// Define custom BF16 vector data types. +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +// BF16 vector types for Q, K, V. +template <> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template <> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template <> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template <> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template <> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat1622float2(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +// Vector addition. +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hadd2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
+template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +template <> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template <> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template <> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template <> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(a, b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(bf162bf162(a), b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template <> +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +template <> +inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); +} + +template <> +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} + +// From float32 to bfloat16. 
+inline __device__ void from_float(__nv_bfloat16& dst, float src) { + dst = __float2bfloat16(src); +} + +inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst = __float22bfloat162_rn(src); +#endif +} + +inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#endif +} + +inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#endif +} + +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + +// Zero-out a variable. +inline __device__ void zero(__nv_bfloat16& dst) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2. + dst = __ushort_as_bfloat16((unsigned short)0x0000U); +#endif +} + +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3a1815f0ed4fc4706840d0136abfe7f96b6fd48a --- /dev/null +++ b/csrc/attention/dtype_float16.cuh @@ -0,0 +1,504 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifdef USE_ROCM + #include +#endif + +#include + +namespace vllm { + +// FP16 vector types for Q, K, V. +template <> +struct Vec { + using Type = uint16_t; +}; +template <> +struct Vec { + using Type = uint32_t; +}; +template <> +struct Vec { + using Type = uint2; +}; +template <> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. 
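+// FP16 values are carried in raw uint16_t/uint32_t/uint2/uint4 registers in
+// this file; the helpers below move between that packed representation and
+// float using inline PTX on CUDA (plain unions / v_cvt instructions on ROCm).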
+inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif +} + +inline __device__ float half_to_float(uint16_t h) { + float f; +#ifndef USE_ROCM + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); +#else + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); +#endif + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#else + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); +#endif + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif +#else + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
+template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#ifndef USE_ROCM + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); +#endif + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template <> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template <> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template <> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. 
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); } + +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a variable. +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } + +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7c6a686db3ba94f114bb965b6a7c94c6a71ecdb7 --- /dev/null +++ b/csrc/attention/dtype_float32.cuh @@ -0,0 +1,251 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. 
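+// The mul specializations below multiply lane-wise; the accumulator type is
+// selected by mul's leading template parameter, mirroring the FloatVec mapping
+// above (for FP32 the accumulator is simply the input vector type). Overloads
+// whose first operand is a plain float broadcast that scalar across every lane
+// of the vector operand, e.g. float * float4 -> {a*b.x, a*b.y, a*b.z, a*b.w}.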
+template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { return a * b; } + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { dst = src; } + +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } + +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } + +// From float to float. +inline __device__ float to_float(float u) { return u; } + +inline __device__ float2 to_float(float2 u) { return u; } + +inline __device__ float4 to_float(float4 u) { return u; } + +inline __device__ Float4_ to_float(Float4_ u) { return u; } + +inline __device__ Float8_ to_float(Float8_ u) { return u; } + +// Zero-out a variable. 
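+// zero() resets an accumulator before a reduction; the FP32 overload is a
+// plain assignment, matching the uint16_t overload in dtype_float16.cuh that
+// writes an all-zero bit pattern.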
+inline __device__ void zero(float& dst) { dst = 0.f; } + +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e714e321b0beb2bd4b03bdabbdcd118502ccea46 --- /dev/null +++ b/csrc/attention/dtype_fp8.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 + +namespace vllm { + +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> +struct Vec { + using Type = uint8_t; +}; + +template <> +struct Vec { + using Type = uint16_t; +}; + +template <> +struct Vec { + using Type = uint32_t; +}; + +template <> +struct Vec { + using Type = uint2; +}; + +} // namespace vllm diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu new file mode 100644 index 0000000000000000000000000000000000000000..27d1e990c611e0c8b3dde41c40530e7c87741ea1 --- /dev/null +++ b/csrc/attention/merge_attn_states.cu @@ -0,0 +1,209 @@ +#include +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +namespace vllm { + +// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +// can be used to combine partial attention results (in the split-KV case) +template +__global__ void merge_attn_states_kernel( + scalar_t* output, float* output_lse, const scalar_t* prefix_output, + const float* prefix_lse, const scalar_t* suffix_output, + const float* suffix_lse, const uint num_tokens, const uint num_heads, + const uint head_size, const uint prefix_head_stride, + const uint output_head_stride) { + using pack_128b_t = uint4; + const uint pack_size = 16 / sizeof(scalar_t); + const uint threads_per_head = head_size / pack_size; + + const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x; + const uint token_head_threads = num_tokens * num_heads * threads_per_head; + + if (global_idx >= token_head_threads) return; + + // global_idx -> token_idx + head_idx + pack_idx + const uint token_head_idx = global_idx / threads_per_head; + const uint pack_idx = global_idx % threads_per_head; + + const uint token_idx = token_head_idx / num_heads; + const uint head_idx = token_head_idx % num_heads; + + const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc. + const uint src_head_offset = token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride; + const uint dst_head_offset = token_idx * num_heads * output_head_stride + + head_idx * output_head_stride; + const scalar_t* prefix_head_ptr = prefix_output + src_head_offset; + const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; + scalar_t* output_head_ptr = output + dst_head_offset; + + float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; + float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; + p_lse = std::isinf(p_lse) ? -std::numeric_limits::infinity() : p_lse; + s_lse = std::isinf(s_lse) ? -std::numeric_limits::infinity() : s_lse; + + const float max_lse = fmaxf(p_lse, s_lse); + + /* In certain edge cases, MLA can produce p_lse = s_lse = -inf; + continuing the pipeline then yields NaN. Root cause: with chunked prefill + a batch may be split into two chunks; if a request in that batch has no + prefix hit, every LSE entry for that request’s position is -inf, and at + this moment we merge cross-attention at first. 
For now we simply emit + prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix + this problem. + */ + if (std::isinf(max_lse)) { + if (pack_offset < head_size) { + // Pack 128b load + pack_128b_t p_out_pack = reinterpret_cast( + prefix_head_ptr)[pack_offset / pack_size]; + + // Pack 128b storage + reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = + p_out_pack; + } + // We only need to write to output_lse once per head. + if (output_lse != nullptr && pack_idx == 0) { + output_lse[head_idx * num_tokens + token_idx] = max_lse; + } + return; + } + + p_lse = p_lse - max_lse; + s_lse = s_lse - max_lse; + const float p_se = expf(p_lse); + const float s_se = expf(s_lse); + const float out_se = p_se + s_se; + const float p_scale = p_se / out_se; + const float s_scale = s_se / out_se; + + if (pack_offset < head_size) { + // Pack 128b load + pack_128b_t p_out_pack = reinterpret_cast( + prefix_head_ptr)[pack_offset / pack_size]; + pack_128b_t s_out_pack = reinterpret_cast( + suffix_head_ptr)[pack_offset / pack_size]; + pack_128b_t o_out_pack; + +#pragma unroll + for (uint i = 0; i < pack_size; ++i) { + // Always use float for FMA to keep high precision. + // half(uint16_t), bfloat16, float -> float. + const float p_out_f = + vllm::to_float(reinterpret_cast(&p_out_pack)[i]); + const float s_out_f = + vllm::to_float(reinterpret_cast(&s_out_pack)[i]); + // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale) + const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale); + // float -> half(uint16_t), bfloat16, float. + vllm::from_float(reinterpret_cast(&o_out_pack)[i], o_out_f); + } + + // Pack 128b storage + reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = + o_out_pack; + } + // We only need to write to output_lse once per head. + if (output_lse != nullptr && pack_idx == 0) { + float out_lse = logf(out_se) + max_lse; + output_lse[head_idx * num_tokens + token_idx] = out_lse; + } +} + +} // namespace vllm + +// The following macro is used to dispatch the conversion function based on +// the output data type. The FN is a macro that calls a function with +// template. +#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \ + { \ + if (scalar_dtype == at::ScalarType::Float) { \ + fn(float); \ + } else if (scalar_dtype == at::ScalarType::Half) { \ + fn(uint16_t); \ + } else if (scalar_dtype == at::ScalarType::BFloat16) { \ + fn(__nv_bfloat16); \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \ + } \ + } + +#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ + { \ + vllm::merge_attn_states_kernel \ + <<>>( \ + reinterpret_cast(output.data_ptr()), output_lse_ptr, \ + reinterpret_cast(prefix_output.data_ptr()), \ + reinterpret_cast(prefix_lse.data_ptr()), \ + reinterpret_cast(suffix_output.data_ptr()), \ + reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ + num_heads, head_size, prefix_head_stride, output_head_stride); \ + } + +/*@brief Merges the attention states from prefix and suffix + * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d + * + * @param output [n,h,d] The output tensor to store the merged attention states. + * @param output_lse [h,d] Optional tensor to store the log-sum-exp values. + * @param prefix_output [n,h,d] The prefix attention states. + * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention + * states. + * @param suffix_output [n,h,d] The suffix attention states. + * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention + * states. 
+ */ +template +void merge_attn_states_launcher(torch::Tensor& output, + std::optional output_lse, + const torch::Tensor& prefix_output, + const torch::Tensor& prefix_lse, + const torch::Tensor& suffix_output, + const torch::Tensor& suffix_lse) { + constexpr uint NUM_THREADS = 128; + const uint num_tokens = output.size(0); + const uint num_heads = output.size(1); + const uint head_size = output.size(2); + const uint prefix_head_stride = prefix_output.stride(1); + const uint output_head_stride = output.stride(1); + const uint pack_size = 16 / sizeof(scalar_t); + TORCH_CHECK(head_size % pack_size == 0, + "headsize must be multiple of pack_size:", pack_size); + float* output_lse_ptr = nullptr; + if (output_lse.has_value()) { + output_lse_ptr = output_lse.value().data_ptr(); + } + // Process one pack elements per thread. for float, the + // pack_size is 4 for half/bf16, the pack_size is 8. + const uint threads_per_head = head_size / pack_size; + const uint total_threads = num_tokens * num_heads * threads_per_head; + + dim3 block(NUM_THREADS); + dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS); + + const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); +} + +#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ + { \ + merge_attn_states_launcher(output, output_lse, prefix_output, \ + prefix_lse, suffix_output, \ + suffix_lse); \ + } + +void merge_attn_states(torch::Tensor& output, + std::optional output_lse, + const torch::Tensor& prefix_output, + const torch::Tensor& prefix_lse, + const torch::Tensor& suffix_output, + const torch::Tensor& suffix_lse) { + DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER); +} diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2d4b4a67d242168dc36d4da63f56a50bc36cd9c2 --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +// clang-format off +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + float seq_length_k = static_cast(K) / 1024.0f; + int max_splits = 1; + + if (B <= 4 && seq_length_k >= 16) { + max_splits = 16; + } + else if (B <= 8 && seq_length_k >= 4) { + max_splits = 8; + } + else if ((B <= 16 && seq_length_k >= 8) || + (B == 48 && seq_length_k >= 32)) { + max_splits = 4; + } + else if ((B <= 32 && 
seq_length_k >= 16) || + (B == 96 && seq_length_k >= 16)) { + max_splits = 2; + } + else { + max_splits = 1; + } + + // Wave-aware scheduling: ensure integer number of waves in K dimension + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + args.split_kv = split_wave_aware; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = Kernel::get_grid_shape(params.fmha_params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms.fmha_params}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params.fmha_params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess != result or Status::kSuccess != launch_result) { + //return Status::kSuccess; + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + if (params.reduction_params.split_kv > 1) { + // launch reduction kernel + dim3 const block = ReductionKernel::get_block_shape(); + dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params); + device_kernel<<>>(params.reduction_params); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + else { + return Status::kSuccess; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
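+  /// Typical host-side flow (an illustrative sketch, not prescriptive): fill
+  /// Arguments, call set_split_kv(args), check can_implement(args), allocate
+  /// get_workspace_size(args) bytes of workspace, then either invoke
+  /// operator()(args, workspace, stream) once, or initialize(args, workspace,
+  /// stream) followed by repeated run(stream) / operator()(stream) re-launches
+  /// such as the overload below.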
+ Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b6e1dd2657da5205c6d83399a2c91cc6d216e40 --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? 
gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = max(lse_max, local_lse[i]); + } + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); + } + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f62c37ba4b7f86eef5ce77d02b6e6280b810508 --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2023 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 
0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? 
StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. 
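+  // (With a persistent scheduler the same CTA immediately starts staging Q for
+  // its next work tile; the load warps arrive_and_wait on this barrier after
+  // each tile so the shared-memory Q buffer is not overwritten before the
+  // previous tile's epilogue has drained it.)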
+ static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + alignas(2048) cute::array> smem_p; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + 
EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + + if (args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + 
} + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + return false; + } + if (get<1>(args.problem_shape) <= 0) { + return false; + } + if (args.split_kv <= 0) { + return false; + } + return true; + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename 
PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? 
+ pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); 
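+            // load_tma() takes distinct QK and PV load pipelines; the call above
+            // passes the same pipeline and producer state for both, so the K and V
+            // stages share a single producer sequence. The named barrier below
+            // appears to hold the load warp until the compute warps have reached
+            // their epilogue for this work tile, so the shared-memory staging
+            // buffers are not refilled out from under them.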
+ cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + 
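+      // Free the TMEM columns the MMA warp allocated at startup; the
+      // cluster_sync() and tmem-dealloc barrier arrival above presumably
+      // guarantee no compute warp is still reading accumulators at this point.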
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + +#if 1 + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } +#endif + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int 
batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = 
make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + 
k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + 
group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier 
= pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! 
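+            // The PV GEMM reads the latent cache as V^T, hence the dedicated
+            // transposed TMA descriptor; the EVICT_FIRST hint suggests these
+            // lines were already streamed once for the K loads and are not
+            // expected to be reused from L2.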
+ // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + 
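+    // Each split-kv slice owns k_tile_per_cta consecutive k-tiles; with
+    // illustrative numbers, K=1024 and TileShapeS=128 give 8 tiles, so
+    // split_kv=4 leaves 2 tiles per slice starting at get<3>(blk_coord) * 2.
+    // The S accumulator alternates between two TMEM buffers (kS0/kS1) keyed off
+    // the S-pipeline producer index, letting the next tile's QK GEMM overlap the
+    // softmax of the previous one, while O accumulates across tiles at
+    // kO0 + j * kSizeAccO and is rescaled by the compute warps.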
pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tItI.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tItI); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tItI.data() = uint32_t(TmemAllocation::kO0) + j * 
uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tItI); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + +#ifndef B2B + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } +#endif + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); 
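+    // Stage P for the PV GEMM: the fp32 scores are converted to the MMA input
+    // element type and written into smem_p, which the MMA warp consumes as the
+    // A operand (tOrP).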
+ + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing +#ifndef B2B + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tItI.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tItI) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tItI) == _1{}); + Tensor tAcc = tItI(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); +#endif + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tItI = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tItI.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tItI) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tItI) == _1{}); + Tensor tAcc = tItI(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + if (epilogue_args.ptr_o_acc != nullptr) { + using 
ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + #endif + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! 
kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } +#endif + } + } + + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive_and_wait(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + +#ifdef B2B + row_sum = 1; +#else + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } +#endif + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv + ); + } + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c990ee2d856fbf8a3632276c0e26659134abb5f3 --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,165 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split 
KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d1874515cc8fd53b814b24c9453872767a156c1a --- /dev/null +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -0,0 +1,291 @@ +/* +Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ +#include "core/registration.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "cutlass_sm100_mla/device/sm100_mla.hpp" +#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" + +// clang-format off +#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 +void sm100_cutlass_mla_decode( + torch::Tensor const& out, + torch::Tensor const& lse, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, + torch::Tensor const& workspace, + double sm_scale, + int64_t num_kv_splits) { + TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); +} +int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { + TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); +} +#else + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; +using namespace cutlass::fmha::kernel; + +template +struct IsPersistent { + static const bool value = v; +}; + +template > +struct MlaSm100 { + using Element = T; + using ElementAcc = float; + using ElementOut = TOut; + + using TileShape = Shape<_128, _128, Shape<_512, _64>>; + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = + std::conditional_t; + + using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, + Element, + ElementAcc, + ElementOut, + ElementAcc, + TileScheduler, + /*kIsCpAsync=*/!IsPaged128>; + using Fmha = cutlass::fmha::device::MLA; +}; + +template +typename T::Fmha::Arguments args_from_options( + at::Tensor const& out, + at::Tensor const& lse, + at::Tensor const& q_nope, + at::Tensor const& q_pe, + at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, + at::Tensor const& page_table, + double sm_scale, + int64_t num_kv_splits) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = q_nope.device().index(); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + int batches = q_nope.sizes()[0]; + int page_count_per_seq = page_table.sizes()[1]; + int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; + int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int max_seq_len = page_size * page_count_per_seq; + using TileShapeH = typename T::TileShapeH; + using TileShapeD = typename T::TileShapeD; + auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + float scale = float(sm_scale); + + using StrideQ = typename T::StrideQ; + using StrideK = typename T::StrideK; + using StrideO = typename T::StrideO; + using StrideLSE = typename T::StrideLSE; + + StrideQ stride_Q_nope = cute::make_tuple( + static_cast(q_nope.stride(1)), _1{}, 
static_cast(q_nope.stride(0))); + StrideQ stride_Q_pe = cute::make_tuple( + static_cast(q_pe.stride(1)), _1{}, static_cast(q_pe.stride(0))); + + StrideK stride_C = cute::make_tuple( + static_cast(0 + D_latent + D_rope), _1{}, static_cast(page_size * (D_latent + D_rope))); + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); + StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H); + StrideO stride_O = cute::make_tuple(static_cast(0 + D_latent), _1{}, static_cast(0 + H * D_latent)); + + using Element = typename T::Element; + using ElementOut = typename T::ElementOut; + using ElementAcc = typename T::ElementAcc; + auto Q_nope_ptr = static_cast(q_nope.data_ptr()); + auto Q_pe_ptr = static_cast(q_pe.data_ptr()); + auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); + typename T::Fmha::Arguments arguments{ + problem_shape, + {scale, + Q_nope_ptr, + stride_Q_nope, + Q_pe_ptr, + stride_Q_pe, + C_ptr, + stride_C, + C_ptr + D_latent, + stride_C, + static_cast(seq_lens.data_ptr()), + static_cast(page_table.data_ptr()), + stride_PT, + page_count_total, + page_size}, + {static_cast(out.data_ptr()), + stride_O, + static_cast(lse.defined() ? lse.data_ptr() : nullptr), + stride_LSE}, + hw_info, + // TODO(trevor-m): Change split_kv back to -1 when + // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will + // perform worse with larger context length and smaller batch sizes. + static_cast(num_kv_splits), // split_kv + nullptr, // is_var_split_kv + }; + // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute + // split_kv automatically based on batch size and sequence length to balance + // workload across available SMs. Consider using var_split_kv for manual + // control if needed. + T::Fmha::set_split_kv(arguments); + return arguments; +} + +template +void runMla( + at::Tensor const& out, + at::Tensor const& lse, + at::Tensor const& q_nope, + at::Tensor const& q_pe, + at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, + at::Tensor const& page_table, + at::Tensor const& workspace, + double sm_scale, + int64_t num_kv_splits, + cudaStream_t stream) { + using MlaSm100Type = MlaSm100; + typename MlaSm100Type::Fmha fmha; + auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + + CUTLASS_CHECK(fmha.can_implement(arguments)); + + CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); +} + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +void sm100_cutlass_mla_decode( + torch::Tensor const& out, + torch::Tensor const& lse, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, + torch::Tensor const& workspace, + double sm_scale, + int64_t num_kv_splits) { + auto in_dtype = q_nope.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); + const int page_size = kv_c_and_k_pe_cache.sizes()[1]; + + // NOTE(alcanderian): IsPersistent has bug with manual split_kv. + // Kernel will hang if batch is too large with large num_kv_splits. 
(for example bs=8, num_kv_splits=8) + // Maybe per batch split kv will fix this. + DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { + DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { + if (in_dtype == at::ScalarType::Half) { + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } + return true; + }); + return true; + }); +} + +int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { + // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) + // which are float, so Element type here doesn't matter. + using MlaSm100Type = MlaSm100; + + // Get split kv. Requires problem shape and sm_count only. + typename MlaSm100Type::Fmha::Arguments arguments; + using TileShapeH = typename MlaSm100Type::TileShapeH; + using TileShapeD = typename MlaSm100Type::TileShapeD; + arguments.problem_shape = + cute::make_tuple(TileShapeH{}, static_cast(max_seq_len), TileShapeD{}, static_cast(num_batches)); + // Assumes device 0 when getting sm_count. + arguments.hw_info.sm_count = + sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; + arguments.split_kv = static_cast(num_kv_splits); + MlaSm100Type::Fmha::set_split_kv(arguments); + + return MlaSm100Type::Fmha::get_workspace_size(arguments); +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { + m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); +} + +// clang-format on diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu new file mode 100644 index 0000000000000000000000000000000000000000..307300e556660be4a679269f87051878e634d461 --- /dev/null +++ b/csrc/attention/paged_attention_v1.cu @@ -0,0 +1,186 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "attention_kernels.cuh" +#include "../cuda_compat.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. +template +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + + const int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
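    // For example, a model with head_size = 96 dispatches
    // LAUNCH_PAGED_ATTENTION_V1(96) below, while an unlisted size such as 160
    // currently hits the TORCH_CHECK in the default branch even though it is a
    // multiple of 16 and could be added as another case.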
+ case 32: + LAUNCH_PAGED_ATTENTION_V1(32); + break; + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu new file mode 100644 index 0000000000000000000000000000000000000000..eb9b4feb4a892c5f0e781d581af1b4b023b3b94d --- /dev/null +++ b/csrc/attention/paged_attention_v2.cu @@ -0,0 +1,196 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "attention_kernels.cuh" +#include "../cuda_compat.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + + const int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. 
+ dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V2(32); + break; + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
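// Consequently only block_size values of 8, 16 and 32 are dispatchable below;
// any other value falls through to the TORCH_CHECK in the default branch.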
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu new file mode 100644 index 0000000000000000000000000000000000000000..c1b45b143f4e1ad11548ecd981572257482694a7 --- /dev/null +++ b/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
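//
// This file builds block-sparse attention indices for MInference-style
// vertical/slash sparsity (Algorithm 4 of https://arxiv.org/abs/2407.02490):
// vertical indices are kept as individual column indices, while slash
// (diagonal) indices are expanded into ranges and merged into block offsets,
// one set per row block of the query. See the doc comments on
// convert_vertical_slash_indexes below for details.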
+ +#include + +#include + +#include + +__device__ int64_t save_blocks(int* block_offset, int64_t range_start, + int64_t range_end, int64_t block_size, + int64_t input_block_count, int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return input_block_count; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + int64_t current_block_count = input_block_count; + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[current_block_count++] = idx; + } + return current_block_count; +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && 
causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count, + block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, + BLOCK_SIZE_N, NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * This function builds the index of each row of blocks from vertical indices + * and slash indices. The vertical indices are treated as points, while the + * slash indices are converted as ranges. The output consists of the merged + * ranges and separate column indices, where the ranges are represented by + * block indices. + * + * The implementation is referenced from the original MInference repo: + * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu. 
+ */ +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + block_count.data_ptr(), block_offset.data_ptr(), + column_count.data_ptr(), column_index.data_ptr(), batch_size, + num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash, + causal); +} + +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. 
(NNZ_V,NNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool 
causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, + per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * Like the above convert_vertical_slash_indexes, but with + * pre-computed vertical and slash counts. + */ +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, // [N_HEADS, ] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + vertical_indices_count.data_ptr(), + slash_indices_count.data_ptr(), block_count.data_ptr(), + block_offset.data_ptr(), column_count.data_ptr(), + column_index.data_ptr(), batch_size, num_heads, num_rows, + block_size_M, block_size_N, nnz_vertical, nnz_slash, causal); +} diff --git a/csrc/cache.h b/csrc/cache.h new file mode 100644 index 0000000000000000000000000000000000000000..0c7823ffe9e2ed92c1469788ab571d0337cea48e --- /dev/null +++ b/csrc/cache.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include + +#include +#include + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + int64_t block_size_in_bytes, + const torch::Tensor& block_mapping); + +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); + +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); + +void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, + torch::Tensor& kv_cache, torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& scale); + +// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla +void concat_and_cache_mla_rope_fused( + torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe, + torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox, + torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache, + const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale); + +// Just for unittest +void 
convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype); + +void gather_and_maybe_dequant_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS] + int64_t num_tokens, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt); + +// TODO(hc): cp_gather_cache need support scaled kvcahe in the future. +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); + +// Gather and upconvert FP8 KV cache to BF16 workspace +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size); + +// Indexer K quantization and cache function +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt); + +// Extract function to gather quantized K cache +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens); // [batch_size + 1] diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..3e8ffe15b42d48e0ca6a09fdf1338c99b6f96494 --- /dev/null +++ b/csrc/cache_kernels.cu @@ -0,0 +1,1367 @@ +#include +#include +#include +#include +#include + +#include "cuda_utils.h" +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "quantization/vectorization_utils.cuh" + +#ifdef USE_ROCM + #include "quantization/w8a8/fp8/amd/quant_utils.cuh" +#else + #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh" +#endif + +#include +#include +#include + +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#if defined(__gfx942__) +constexpr float kFp8ScaleDivisor = 224.f; +#else +constexpr float kFp8ScaleDivisor = 448.f; +#endif + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + int64_t block_size_in_bytes, + const torch::Tensor& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cuda()) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + memcpy_type = cudaMemcpyDeviceToDevice; + } else if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; + } else if (src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type 
= cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "Invalid device combination"); + } + + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); + + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // NOTE(woosuk): This can be slow if the number of blocks is large. + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); + } +} + +namespace vllm { + +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + + scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +// Kernel for MLA, which works on a single joint kv_cache +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_mla_kernel( + int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping, + const int mem_footprint_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + scalar_t* cache = reinterpret_cast(cache_ptrs[layer_idx]); + int64_t src_block = block_mapping[2 * pair_idx]; + int64_t dst_block = block_mapping[2 * pair_idx + 1]; + int64_t src_offset = src_block * mem_footprint_per_block; + int64_t dst_offset = dst_block * mem_footprint_per_block; + for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) { + cache[dst_offset + i] = cache[src_offset + i]; + } +} + +} // namespace vllm + +namespace vllm { + +// Used to copy/convert one element +template +struct CopyWithScaleOp { + float scale; + + __device__ __forceinline__ void operator()(OutT& dst, const InT src) const { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst = static_cast(src); + } else { + dst = fp8::scaled_convert(src, scale); + } + } +}; + +template +__global__ void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, 
head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int h_block_count = head_size / x; // head_size//x + + const int h_block_idx = threadIdx.x; + if (h_block_idx >= num_heads * h_block_count) { + return; + } + + const int head_idx = h_block_idx / h_block_count; + const int h_block = h_block_idx % h_block_count; + + const scalar_t* __restrict__ key_src = + key + token_idx * key_stride + head_idx * head_size + h_block * x; + const int64_t src_value_start = + token_idx * value_stride + head_idx * head_size + h_block * x; + + cache_t* __restrict__ key_dst = + key_cache + block_idx * num_heads * h_block_count * block_size * x + + head_idx * h_block_count * block_size * x + h_block * block_size * x + + block_offset * x; + const int64_t tgt_value_start = + block_idx * num_heads * h_block_count * x * block_size + + head_idx * h_block_count * x * block_size + h_block * x * block_size + + block_offset; + + constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4; + float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale; + CopyWithScaleOp k_op{k_scale_val}; + float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale; + CopyWithScaleOp v_op{v_scale_val}; + + vectorize_with_alignment(key_src, key_dst, x, 0, 1, k_op); + + const scalar_t* __restrict__ value_src = value + src_value_start; + cache_t* __restrict__ value_dst = value_cache + tgt_value_start; +#pragma unroll + for (int i = 0; i < x; i++) { + v_op(value_dst[i * block_size], value_src[i]); + } +} + +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // NHD or HND, shape see comments below + cache_t* __restrict__ value_cache, // same above + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int64_t block_stride, const int64_t page_stride, + const int64_t head_stride, const int64_t key_stride, + const int64_t value_stride, const int num_heads, const int head_size, + const int block_size, const float* k_scale, const float* v_scale, + const int kv_scale_stride) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n_elems = num_heads * head_size; + + // pointers to the beginning of the source row for this token. + const scalar_t* __restrict__ key_src = key + token_idx * key_stride; + const scalar_t* __restrict__ value_src = value + token_idx * value_stride; + + // find the start position inside the kv-cache for this token. 
+ cache_t* __restrict__ key_dst = + key_cache + block_idx * block_stride + block_offset * page_stride; + cache_t* __restrict__ value_dst = + value_cache + block_idx * block_stride + block_offset * page_stride; + + // this is true for the NHD layout where `head_stride == head_size` + const bool is_contiguous_heads = (head_stride == head_size); + + constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4; + + if (is_contiguous_heads && kv_scale_stride == 0) { + // NHD layout and k/v_scales are [1] (i.e. single scale for all heads) + // kv cache: [num_blocks, block_size, num_heads, head_size] + float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale; + float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale; + + CopyWithScaleOp k_op{k_scale_val}; + CopyWithScaleOp v_op{v_scale_val}; + + vectorize_with_alignment(key_src, key_dst, n_elems, threadIdx.x, + blockDim.x, k_op); + vectorize_with_alignment(value_src, value_dst, n_elems, + threadIdx.x, blockDim.x, v_op); + } else { + // HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head) + // HND layout: heads are strided, but each head_size segment is contiguous + // kv cache: [num_blocks, num_heads, block_size, head_size] + const int lane = threadIdx.x & 31; // 0..31 within warp + const int warp_id = threadIdx.x >> 5; // warp index within block + const int warps_per_block = blockDim.x >> 5; + + for (int head = warp_id; head < num_heads; head += warps_per_block) { + const scalar_t* __restrict__ k_src_h = key_src + head * head_size; + const scalar_t* __restrict__ v_src_h = value_src + head * head_size; + + cache_t* __restrict__ k_dst_h = + key_dst + static_cast(head) * head_stride; + cache_t* __restrict__ v_dst_h = + value_dst + static_cast(head) * head_stride; + + float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) + ? 0.f + : k_scale[head * kv_scale_stride]; + float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) + ? 
0.f + : v_scale[head * kv_scale_stride]; + + CopyWithScaleOp k_op{k_scale_val}; + CopyWithScaleOp v_op{v_scale_val}; + + // within each head, let the 32 threads of the warp perform the vector + // copy + vectorize_with_alignment(k_src_h, k_dst_h, head_size, lane, 32, + k_op); + + vectorize_with_alignment(v_src_h, v_dst_h, head_size, lane, 32, + v_op); + } + } +} + +template +__global__ void concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + +template +__global__ void concat_and_cache_ds_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int64_t dst_idx_start = + block_idx * block_stride + block_offset * entry_stride; + + // For the NoPE part, each tile of 128 elements is handled by half of one warp + // (16 threads). There are 4 total tiles, so 2 warps (64 threads). + // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. + // The RoPE part (last 64 elements) is handled by another 1 warp (32 threads). + // So in total, we use 3 warps (96 threads) per block. 
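  // Resulting per-token cache entry layout (for the shapes implied above,
  // i.e. kv_lora_rank = 512 and pe_dim = 64):
  //   bytes [  0, 512): 512 fp8 NoPE values (4 tiles of 128 elements)
  //   bytes [512, 528): 4 float32 per-tile scales
  //   bytes [528, 656): 64 16-bit RoPE values
  // i.e. 656 bytes per entry, matching the [..., 656] cache shape used by
  // cp_gather_and_upconvert_fp8_kv_cache in cache.h.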
+ + // Cast kv_cache to 16_bit for RoPE values + scalar_t* kv_cache_16bit = + reinterpret_cast(&kv_cache[dst_idx_start]); + + // The last warp handles the RoPE part + if (threadIdx.x >= 64) { + // Each thread handles two elements of RoPE + const int8_t pe_idx_start = (threadIdx.x - 64) * 2; + const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start; + // Vectorized load of two 16-bit values, performed as one 32-bit load + const int32_t vals = *reinterpret_cast(&k_pe[src_idx]); + // RoPE values start after the packed 8-bit NoPE values and the + // 32-bit scales + const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start; + // Vectorized store of two 16-bit values, performed as one 32-bit store + *reinterpret_cast(&kv_cache_16bit[dst_idx]) = vals; + return; + } + + // The first two warps handle the NoPE part + const int8_t warp_idx = threadIdx.x >> 5; + const int8_t lane_idx = threadIdx.x & 31; + const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4); + + // Each thread handles 8 elements of NoPE + // Load the NoPE elements for this thread into registers + const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8); + // Vectorized load of eight 16-bit values, performed as an int4 load + const int4 vals_i4 = *reinterpret_cast(&kv_c[src_idx_start]); + const scalar_t* vals = reinterpret_cast(&vals_i4); + + // Max absolute value of this thread's elements + float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])), + fmaxf(fabsf(vals[2]), fabsf(vals[3]))), + fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])), + fmaxf(fabsf(vals[6]), fabsf(vals[7])))); + + // Warp-level reduction to find the max absolute value in each half-warp +#pragma unroll + for (int offset = 8; offset > 0; offset /= 2) { + max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16)); + } + + // Compute the scale for the tile + float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN); + + // The first lane of each half-warp writes the scale to kv_cache + if ((lane_idx == 0) || (lane_idx == 16)) { + float* kv_cache_32bit = reinterpret_cast(&kv_cache[dst_idx_start]); + const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; + kv_cache_32bit[dst_idx] = tile_scale; + } + + // Now all threads in the block scale and write their elements + // NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes) + const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8); + + uint8_t result[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = + fp8::scaled_convert( + vals[i], tile_scale); + } + + // Store as aligned 64-bit writes + *reinterpret_cast(&kv_cache[dst_idx_base]) = + *reinterpret_cast(result); +} + +template +__global__ void indexer_k_quant_and_cache_kernel( + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int head_dim, // dimension of each head + const int quant_block_size, // quantization block size + const int cache_block_size, // cache block size + const int cache_stride, // stride for each token in kv_cache + + const bool use_ue8m0 // use ue8m0 scale format +) { + constexpr int VEC_SIZE = 4; + const int64_t token_idx = blockIdx.x; + const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x) * + VEC_SIZE; + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / cache_block_size; + const int64_t block_offset = slot_idx % 
cache_block_size; + + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0 || (head_dim_idx >= head_dim)) { + return; + } + + float2 k_val = (reinterpret_cast( + k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; + scalar_t* k_val_ptr = reinterpret_cast(&k_val); + float amax = 0.0f; + for (int i = 0; i < VEC_SIZE; i++) { + amax = fmaxf(amax, fabsf(float(k_val_ptr[i]))); + } + + // Reduced amax + for (int mask = 16; mask > 0; mask /= 2) { +#ifdef USE_ROCM + amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask)); +#else + amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask)); +#endif + } + + float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor; + + if (use_ue8m0) { + scale = exp2f(ceilf(log2f(scale))); + } + + const int64_t dst_offset = block_idx * cache_block_size * cache_stride + + block_offset * head_dim + head_dim_idx; + for (int i = 0; i < VEC_SIZE; i++) { + kv_cache[dst_offset + i] = + fp8::scaled_convert(k_val_ptr[i], scale); + } + if (threadIdx.x == 0) { + const int64_t dst_scale_idx = + block_idx * cache_block_size * cache_stride + + cache_block_size * head_dim + + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; + } +} + +template +__global__ void cp_gather_indexer_k_quant_cache_kernel( + const char* __restrict__ kv_cache, // [num_blocks, block_size, + // cache_stride] + char* __restrict__ dst_k, // [num_tokens, head_dim] + char* __restrict__ dst_scale, // [num_tokens, head_dim / quant_block_size * + // 4] + const int* __restrict__ block_table, // [batch_size, num_blocks] + const int* __restrict__ cu_seq_lens, // [batch_size + 1] + const int batch_size, // batch size + const int64_t token_stride, // stride for each token in dst_k + const int64_t head_dim, // dimension of each head + const int64_t block_stride, // stride for each block in kv_cache + const int64_t cache_token_stride, // stride for each token in kv_cache + const int64_t cache_block_size, // num_tokens for each block in kv_cache + const int num_blocks, // number of blocks + const int num_tokens, // number of tokens + const int quant_block_size // quantization block size +) { + constexpr int VEC_SIZE = sizeof(float4) / sizeof(char); + const int token_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + // Find batch index within a block + __shared__ int batch_idx[BLOCK_Y_SIZE]; + for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); + iter++) { + int tid = iter * blockDim.x + threadIdx.x; + if (tid < batch_size) { + const int seq_start = cu_seq_lens[tid]; + const int seq_end = cu_seq_lens[tid + 1]; + if (token_idx >= seq_start && token_idx < seq_end) { + batch_idx[threadIdx.y] = tid; + } + } + } + +#ifndef USE_ROCM + __syncwarp(); +#endif + + if (head_idx >= head_dim || token_idx >= num_tokens) { + return; + } + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; + const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks + + inbatch_seq_idx / cache_block_size]; + const int64_t src_block_offset = block_idx * block_stride; + const int64_t cache_inblock_offset = + (inbatch_seq_idx % cache_block_size) * head_dim + head_idx; + const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset; + const int64_t dst_inblock_offset = token_idx * token_stride + head_idx; + + reinterpret_cast(dst_k)[dst_inblock_offset / VEC_SIZE] = + reinterpret_cast(kv_cache)[src_inblock_offset / 
VEC_SIZE]; + ; + if (threadIdx.x == 0) { + const int64_t src_scale_offset = + src_block_offset + cache_block_size * head_dim + + cache_inblock_offset * 4 / quant_block_size; + reinterpret_cast(dst_scale)[dst_inblock_offset / quant_block_size] = + reinterpret_cast(kv_cache)[src_scale_offset / 4]; + } +} + +} // namespace vllm + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); + +void reshape_and_cache( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_tokens = slot_mapping.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int head_div_x = head_size / x; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_div_x, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE); +} + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, page_stride, \ + head_stride, key_stride, value_stride, num_heads, head_size, \ + block_size, reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr()), \ + kv_scale_stride); + +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, // [1] or [num_heads] + torch::Tensor& v_scale) { // [1] or [num_heads] + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. 
+ // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int64_t block_stride = key_cache.stride(0); + int64_t page_stride = key_cache.stride(1); + int64_t head_stride = key_cache.stride(2); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + TORCH_CHECK(k_scale.sizes() == v_scale.sizes(), + "k_scale and v_scale must have the same shape"); + TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads, + "k_scale and v_scale must be of shape [1] or [num_heads]"); + int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_FLASH); +} + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_ds_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + +void concat_and_cache_mla( + torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. 
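+  // Illustrative layout note (inferred from concat_and_cache_ds_mla_kernel
+  // above, not an extra contract): for "fp8_ds_mla" each 656-byte cache
+  // entry is laid out as
+  //   [  0, 512): 512 NoPE values quantized to fp8 (1 byte each)
+  //   [512, 528): 4 float32 scales, one per 128-element NoPE tile
+  //   [528, 656): 64 RoPE values kept as 16-bit values
+  // which is what the size and itemsize checks below enforce.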
+ int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + if (kv_cache_dtype == "fp8_ds_mla") { + TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"); + TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"); + TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(), + "kv_cache.size(2) must be 656 bytes for fp8_ds_mla"); + TORCH_CHECK(kv_c.itemsize() == 2, + "kv_c.itemsize() must be 2 for fp8_ds_mla"); + TORCH_CHECK(k_pe.itemsize() == 2, + "k_pe.itemsize() must be 2 for fp8_ds_mla"); + } else { + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + } + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (kv_cache_dtype == "fp8_ds_mla") { + dim3 grid(num_tokens); + // For the NoPE part, each tile of 128 elements is handled by half of one + // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads). + // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. + // The RoPE part (last 64 elements) is handled by another 1 warp (32 + // threads). So in total, we use 3 warps (96 threads) per block. + dim3 block(96); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_DS_MLA); + } else { + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); + } +} + +namespace vllm { + +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float scale, + const int64_t block_stride) { + const int64_t block_idx = blockIdx.x; + for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { + int64_t idx = block_idx * block_stride + i; + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], scale); + } +} + +} // namespace vllm + +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); + +// Only for testing. 
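+// convert_fp8 copies src_cache into dst_cache, converting between the
+// cache's native 16/32-bit dtype and its uint8 fp8 representation using a
+// single scale. A minimal usage sketch (tensor names are illustrative):
+//
+//   convert_fp8(dst_u8_cache, src_fp16_cache, /*scale=*/1.0, "fp8");   // quantize
+//   convert_fp8(dst_fp16_cache, src_u8_cache, /*scale=*/1.0, "fp8");   // dequantize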
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + + int64_t num_blocks = src_cache.size(0); + int64_t block_stride = src_cache.stride(0); + + dim3 grid(num_blocks); + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); + } +} + +namespace vllm { + +// grid is launched with dimensions (batch, num_splits) +template +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRIES...] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK] + const int32_t num_tokens, const int32_t block_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, + const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per + // batch + constexpr int vec_size = sizeof(float4) / sizeof(scalar_t); + using ltype = vllm::vec_n_t; + using stype = vllm::vec_n_t; + // We are adding this for code readability which will be optimized out when + // build in release. 
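+  // Each loop iteration below moves sizeof(float4) bytes per thread, i.e.
+  // vec_size elements of scalar_t (8 elements for 16-bit dtypes); the scalar
+  // tail loop covers any remainder when ENTRY_SIZE is not a multiple of
+  // vec_size.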
+ assert(CTA_SIZE == blockDim.x); + +#pragma unroll + for (int token_id = blockIdx.x; token_id < num_tokens; + token_id += gridDim.x) { + int64_t batch_id = token_to_seq[token_id]; + int64_t batch_start = cu_seq_lens[batch_id]; + int64_t batch_end = cu_seq_lens[batch_id + 1]; + int32_t batch_offset = token_id - batch_start; + + if (token_id >= batch_end) return; + int32_t offset = 0; + if (seq_starts != nullptr) { + offset = seq_starts[batch_id]; + } + batch_offset += offset; + int32_t block_table_id = batch_offset / block_size; + int32_t slot_id = batch_offset % block_size; + int32_t block_table_offset = batch_id * block_table_stride + block_table_id; + int32_t block_id = block_table[block_table_offset]; + int64_t cache_offset = + block_id * cache_block_stride + slot_id * cache_entry_stride; + constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size; + scalar_t* dst_ = dst + token_id * dst_entry_stride; + cache_t* src_ = const_cast(src_cache) + cache_offset; + +#pragma unroll + for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + reinterpret_cast(dst_)[idx] = + static_cast(reinterpret_cast(src_)[idx]); + } else { + ltype loaded_val = reinterpret_cast(src_)[idx]; + stype store_val; +#pragma unroll + for (int j = 0; j < vec_size; ++j) { + store_val.val[j] = fp8::scaled_convert( + loaded_val.val[j], *scale); + } + reinterpret_cast(dst_)[idx] = store_val; + } + } + // process tail + constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size; + dst_ = dst_ + ENTRY_SIZE - tail_cnt; + src_ = src_ + ENTRY_SIZE - tail_cnt; +#pragma unroll + for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst_[idx] = static_cast(src_[idx]); + } else { + dst_[idx] = + fp8::scaled_convert(src_[idx], *scale); + } + } + } +} + +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + token_to_seq.data_ptr(), num_tokens, block_size, \ + block_table_stride, cache_block_stride, cache_entry_stride, \ + dst_entry_stride, reinterpret_cast(scale.data_ptr()), \ + seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - token_to_seq contains the back mapping from token_id to batch_id +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) +void gather_and_maybe_dequant_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS] + int64_t num_tokens, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t head_dim = dst.size(-1); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + TORCH_CHECK(head_dim == 576, + "gather_and_maybe_dequant_cache only support the head_dim to 576 " + "for better performance") + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + constexpr int32_t thread_block_size = 64; + dim3 grid(num_tokens); + dim3 block(thread_block_size); + + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? 
seq_starts.value().data_ptr() : nullptr; + + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); +} + +namespace vllm { + +// Gather and upconvert FP8 KV cache tokens to BF16 workspace +// Similar to cp_gather_cache but specifically for FP8->BF16 conversion +__global__ void cp_gather_and_upconvert_fp8_kv_cache( + const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + __nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ seq_lens, // [BATCH] + const int32_t* __restrict__ workspace_starts, // [BATCH] + const int32_t block_size, const int32_t head_dim, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = workspace_starts[bid]; + const int32_t seq_len = seq_lens[bid]; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths + dst += seq_start * dst_entry_stride; + + const int tid = threadIdx.x; + + // Process each token in this split + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + const uint8_t* token_ptr = + src_cache + block_id * cache_block_stride + offset * cache_entry_stride; + __nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride; + + // FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16) + const uint8_t* no_pe_ptr = token_ptr; + const float* scales_ptr = reinterpret_cast(token_ptr + 512); + const __nv_bfloat16* rope_ptr = + reinterpret_cast(token_ptr + 512 + 16); + + // Parallelize fp8 dequant (512 elements) and rope copy (64 elements) + if (tid < 512) { + // FP8 dequantization + const int tile = tid >> 7; // each tile is 128 elements + const float scale = scales_ptr[tile]; + const uint8_t val = no_pe_ptr[tid]; + dst_ptr[tid] = + fp8::scaled_convert<__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale); + } else if (tid < 576) { + // Rope copy (64 bf16 elements) + const int rope_idx = tid - 512; + dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; + } + + // Move to next token + offset += 1; + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} + +template +// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by +// block_size. 
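+// It folds seq_starts[bid] into the running slot offset and then splits it
+// into (offset / block_size, offset % block_size), so a start that lands in
+// the middle of a block simply begins the walk partway through it. With
+// illustrative numbers block_size = 16 and seq_starts[bid] = 20, the first
+// copied token is read from batch_block_table[1] at in-block slot 4.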
+__global__ void cp_gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRY_SIZE] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts // Optional: starting offsets per + // batch +) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch. + // If seq_starts is provided, compute an offset based on it + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + if (seq_starts != nullptr) { + offset += seq_starts[bid]; + } + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * dst_entry_stride; + copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr); + offset += 1; + // bump to next block + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \ + vllm::cp_gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting slot index by +// seq_starts[bid] +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? 
seq_starts.value().data_ptr() : nullptr; + + if (dtype_bits == 32) { + CALL_CP_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_CP_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_CP_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} + +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t head_dim = dst.size(1); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); + TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, + "workspace_starts must be int32"); + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == seq_lens.device(), + "src_cache and seq_lens must be on the same device"); + TORCH_CHECK(src_cache.device() == workspace_starts.device(), + "src_cache and workspace_starts must be on the same device"); + auto dtype = src_cache.scalar_type(); + TORCH_CHECK( + dtype == at::ScalarType::Byte || // uint8 + dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3 + dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2 + "src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ", + src_cache.dtype()); + TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); + TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + const uint8_t* src_ptr = nullptr; + if (dtype == at::ScalarType::Byte) { + src_ptr = src_cache.data_ptr(); + } else { + // float8_e4m3fn or float8_e5m2 + src_ptr = reinterpret_cast(src_cache.data_ptr()); + } + + // Decide on the number of splits based on the batch size + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(576); + + vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( + src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), + block_table.data_ptr(), seq_lens.data_ptr(), + workspace_starts.data_ptr(), block_size, head_dim, + block_table_stride, cache_block_stride, cache_entry_stride, + dst_entry_stride); +} + +// Macro to dispatch the kernel based on the data type. 
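+// Within each cache block written by indexer_k_quant_and_cache_kernel, the
+// first cache_block_size * head_dim bytes hold the fp8-quantized values
+// (head_dim bytes per token); the float32 scales, one per quant_block_size
+// elements, are stored after that region. The macro below only dispatches on
+// the input dtype and forwards the tensors to that kernel.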
+#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::indexer_k_quant_and_cache_kernel \ + <<>>( \ + reinterpret_cast(k.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), head_dim, quant_block_size, \ + cache_block_size, cache_stride, use_ue8m0); + +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt) { + int num_tokens = k.size(0); + int head_dim = k.size(1); + int cache_block_size = kv_cache.size(1); + int cache_stride = kv_cache.size(2); + bool use_ue8m0 = scale_fmt == "ue8m0"; + + TORCH_CHECK(k.device() == kv_cache.device(), + "k and kv_cache must be on the same device"); + TORCH_CHECK(k.device() == slot_mapping.device(), + "k and slot_mapping must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 4; + dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) / + (quant_block_size * vec_size)); + dim3 block(32, vec_size); + const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + static const std::string kv_cache_dtype = "fp8_e4m3"; + DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), kv_cache_dtype, + CALL_INDEXER_K_QUANT_AND_CACHE); +} + +// Macro to dispatch the kernel based on the data amount. +#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ + vllm::cp_gather_indexer_k_quant_cache_kernel \ + <<>>( \ + reinterpret_cast(kv_cache.data_ptr()), \ + reinterpret_cast(dst_k.data_ptr()), \ + reinterpret_cast(dst_scale.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \ + kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \ + num_tokens, quant_block_size); + +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens // [batch_size + 1] +) { + int batch_size = block_table.size(0); + int num_tokens = dst_k.size(0); + int head_dim = dst_k.size(1); + int quant_block_size = head_dim * 4 / dst_scale.size(1); + + TORCH_CHECK(kv_cache.device() == dst_k.device(), + "kv_cache and dst_k must be on the same device"); + TORCH_CHECK(kv_cache.device() == dst_scale.device(), + "kv_cache and dst_scale must be on the same device"); + TORCH_CHECK(kv_cache.device() == block_table.device(), + "kv_cache and block_table must be on the same device"); + TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(), + "kv_cache and cu_seq_lens must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 16; + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens < 32) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1); + } else if (num_tokens < 64) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2); + } else if (num_tokens < 128) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4); + } else if (num_tokens < 256) { + 
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8); + } else if (num_tokens < 512) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16); + } else { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); + } +} diff --git a/csrc/cache_kernels_fused.cu b/csrc/cache_kernels_fused.cu new file mode 100644 index 0000000000000000000000000000000000000000..be037b2fdec2be66a67c4a61144a0e45b009fc7a --- /dev/null +++ b/csrc/cache_kernels_fused.cu @@ -0,0 +1,279 @@ +#include +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#include "quantization/w8a8/fp8/common.cuh" +#ifdef USE_ROCM + #include "quantization/w8a8/fp8/amd/quant_utils.cuh" +#else + #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh" +#endif + +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +namespace vllm { + +// NOTE Be EXTRA careful with raw_kv_scalar_t, for __half and __nv_bfloat16 it's +// using u16 as the backing type. +template +__global__ void concat_and_cache_mla_rope_fused_kernel( + const int64_t* __restrict__ positions, // [num_tokens] + qk_t* __restrict__ q_pe, // [num_tokens, num_q_heads, rot_dim] + qk_t* __restrict__ k_pe, // [num_tokens, rot_dim] + const qk_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const qk_t* __restrict__ rope_cos_sin_cache, // [max_position, 2, + // rot_dim // 2] + const int rot_dim, const int64_t q_pe_stride_token, + const int64_t q_pe_stride_head, const int64_t k_pe_stride, + const int64_t kv_c_stride, const int num_q_heads, + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // rot_dim)] + const int64_t* __restrict__ kv_cache_slot_mapping, // [num_tokens] + const int block_stride, const int entry_stride, const int kv_lora_rank, + const int block_size, const float* kv_cache_quant_scale) { + // Each thread block is responsible for one token. + const int64_t token_idx = blockIdx.x; + const int64_t pos = positions[token_idx]; + + const qk_t* cos_sin_ptr = rope_cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + + // Q ROPE + const int nq = num_q_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + int head_idx = i / embed_dim; + int pair_idx = i % embed_dim; + + // NOTE: Would be nice to have interleaved sin/cos so we could just load + // both at the same time. + qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx); + qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim); + + qk_t* q_pe_head_ptr = + q_pe + token_idx * q_pe_stride_token + head_idx * q_pe_stride_head; + + int pair_idx_x, pair_idx_y; + if constexpr (IS_NEOX) { + // GPT-NeoX style rotary embedding. + pair_idx_x = pair_idx; + pair_idx_y = embed_dim + pair_idx; + } else { + // GPT-J style rotary embedding. 
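+      // e.g. pair 0 rotates elements (0, 1), pair 1 rotates (2, 3), ...,
+      // whereas the NeoX branch above pairs element i with element
+      // i + embed_dim.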
+ pair_idx_x = pair_idx * 2; + pair_idx_y = pair_idx * 2 + 1; + } + + qk_t x_src = q_pe_head_ptr[pair_idx_x]; + qk_t y_src = q_pe_head_ptr[pair_idx_y]; + + qk_t x_dst = x_src * cos - y_src * sin; + qk_t y_dst = y_src * cos + x_src * sin; + + q_pe_head_ptr[pair_idx_x] = x_dst; + q_pe_head_ptr[pair_idx_y] = y_dst; + } + + const int64_t slot_idx = kv_cache_slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / block_size; + const int64_t entry_idx = slot_idx % block_size; + + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + + // K with 1 HEAD + for (int i = threadIdx.x; i < embed_dim; i += blockDim.x) { + int pair_idx = i; + + qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx); + qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim); + + qk_t* k_pe_head_ptr = k_pe + token_idx * k_pe_stride; + + int pair_idx_x, pair_idx_y; + if constexpr (IS_NEOX) { + // GPT-NeoX style rotary embedding. + pair_idx_x = pair_idx; + pair_idx_y = embed_dim + pair_idx; + } else { + // GPT-J style rotary embedding. + pair_idx_x = pair_idx * 2; + pair_idx_y = pair_idx * 2 + 1; + } + + qk_t x_src = k_pe_head_ptr[pair_idx_x]; + qk_t y_src = k_pe_head_ptr[pair_idx_y]; + + qk_t x_dst = x_src * cos - y_src * sin; + qk_t y_dst = y_src * cos + x_src * sin; + + k_pe_head_ptr[pair_idx_x] = x_dst; + k_pe_head_ptr[pair_idx_y] = y_dst; + + // NOTE Why is this monster necessary? + // When K is of type float16, the actual template replacement for + // raw_kv_scalar_t with be u16. That's why it's used at the last moment + // otherwise CUDA ALU would break. + const raw_kv_scalar_t raw_x_value = + *reinterpret_cast(&x_dst); + const raw_kv_scalar_t raw_y_value = + *reinterpret_cast(&y_dst); + + cache_t* kv_cache_ptr = kv_cache + block_idx * block_stride + + entry_idx * entry_stride + kv_lora_rank; + + // MLA Cache Store + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + kv_cache_ptr[pair_idx_x] = raw_x_value; + kv_cache_ptr[pair_idx_y] = raw_y_value; + } else { + kv_cache_ptr[pair_idx_x] = + fp8::scaled_convert( + raw_x_value, *kv_cache_quant_scale); + kv_cache_ptr[pair_idx_y] = + fp8::scaled_convert( + raw_y_value, *kv_cache_quant_scale); + } + } + + // NOPE + for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) { + const qk_t* src_ptr = kv_c + token_idx * kv_c_stride + i; + const raw_kv_scalar_t src_value = + *reinterpret_cast(src_ptr); + + cache_t* kv_cache_ptr = + kv_cache + block_idx * block_stride + entry_idx * entry_stride; + + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + kv_cache_ptr[i] = src_value; + } else { + kv_cache_ptr[i] = fp8::scaled_convert( + src_value, *kv_cache_quant_scale); + } + } +} + +} // namespace vllm + +#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE) \ + do { \ + VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] { \ + using qk_t = scalar_t; \ + if (rope_is_neox) { \ + vllm::concat_and_cache_mla_rope_fused_kernel \ + <<>>( \ + positions.data_ptr(), q_pe.data_ptr(), \ + k_pe.data_ptr(), kv_c.data_ptr(), \ + rope_cos_sin_cache.data_ptr(), rot_dim, \ + q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \ + num_q_heads, reinterpret_cast(kv_cache.data_ptr()), \ + kv_cache_slot_mapping.data_ptr(), block_stride, \ + entry_stride, kv_lora_rank, block_size, \ + kv_cache_quant_scale.data_ptr()); \ + } else { \ + vllm::concat_and_cache_mla_rope_fused_kernel \ + <<>>( \ + positions.data_ptr(), q_pe.data_ptr(), \ + k_pe.data_ptr(), kv_c.data_ptr(), \ + rope_cos_sin_cache.data_ptr(), rot_dim, \ + 
q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \ + num_q_heads, reinterpret_cast(kv_cache.data_ptr()), \ + kv_cache_slot_mapping.data_ptr(), block_stride, \ + entry_stride, kv_lora_rank, block_size, \ + kv_cache_quant_scale.data_ptr()); \ + } \ + }); \ + } while (false) + +// Executes RoPE on q_pe and k_pe, then writes k_pe and kv_c in the kv cache. +// q_pe and k_pe are modified in place. +// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and +// concat_and_cache_mla. +void concat_and_cache_mla_rope_fused( + torch::Tensor& positions, // [num_tokens] + torch::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim] + torch::Tensor& k_pe, // [num_tokens, rot_dim] + torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& rope_cos_sin_cache, // [max_position, rot_dim] + bool rope_is_neox, + torch::Tensor& + kv_cache_slot_mapping, // [num_tokens] or [num_actual_tokens] + torch::Tensor& + kv_cache, // [num_blocks, block_size, (kv_lora_rank + rot_dim)] + const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) { + const int64_t num_tokens = q_pe.size(0); + + const int num_q_heads = q_pe.size(1); + const int rot_dim = q_pe.size(2); + const int kv_lora_rank = kv_c.size(1); + + TORCH_CHECK(positions.size(0) >= + num_tokens); // CUDA Graphs might pad this for us + TORCH_CHECK_EQ(positions.dim(), 1); + TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long); + + TORCH_CHECK_EQ(q_pe.size(0), num_tokens); + TORCH_CHECK_EQ(q_pe.size(1), num_q_heads); + TORCH_CHECK_EQ(q_pe.size(2), rot_dim); + TORCH_CHECK_EQ(q_pe.dim(), 3); + + TORCH_CHECK_EQ(k_pe.size(0), num_tokens); + TORCH_CHECK_EQ(k_pe.size(1), rot_dim); + TORCH_CHECK_EQ(k_pe.dim(), 2); + TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type()); + + TORCH_CHECK_EQ(kv_c.size(0), num_tokens); + TORCH_CHECK_EQ(kv_c.size(1), kv_lora_rank); + TORCH_CHECK_EQ(kv_c.dim(), 2); + TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type()); + TORCH_CHECK_EQ(kv_c.dtype(), q_pe.dtype()); + + TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim); + TORCH_CHECK_EQ(rope_cos_sin_cache.scalar_type(), q_pe.scalar_type()); + + TORCH_CHECK_EQ(kv_cache_slot_mapping.size(0), num_tokens); + TORCH_CHECK_EQ(kv_cache_slot_mapping.scalar_type(), c10::ScalarType::Long); + + TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim); + TORCH_CHECK_EQ(kv_cache.dim(), 3); + + TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1); + TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float); + + int64_t q_pe_stride_token = q_pe.stride(0); + int64_t q_pe_stride_head = q_pe.stride(1); + + int64_t k_pe_stride = k_pe.stride(0); + int64_t kv_c_stride = kv_c.stride(0); + + int block_size = kv_cache.size(1); + + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + + int rope_block_size = std::min(num_q_heads * rot_dim / 2, 512); + int mla_block_size = kv_lora_rank; + int thread_block_size = + std::min(std::max(rope_block_size, mla_block_size), 512); + + dim3 grid(num_tokens, 1, 1); + dim3 block(thread_block_size, 1, 1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(positions)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED); +} diff --git a/csrc/core/batch_invariant.hpp b/csrc/core/batch_invariant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fffe96b868575d17b9191ab610a2862ff00b0a43 --- /dev/null +++ b/csrc/core/batch_invariant.hpp @@ -0,0 +1,19 @@ 
+#pragma once +#include +#include +#include + +namespace vllm { + +// vllm_is_batch_invariant(); returns true +// if env VLLM_BATCH_INVARIANT=1 +inline bool vllm_is_batch_invariant() { + static bool cached = []() { + std::string env_key = "VLLM_BATCH_INVARIANT"; + const char* val = std::getenv(env_key.c_str()); + return (val && std::atoi(val) != 0) ? 1 : 0; + }(); + return cached; +} + +} // namespace vllm diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3b2ffaef6cce0b85f25fdd5090a227b581d4d3f --- /dev/null +++ b/csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6764e1fd60545ad89d809934d6be02b04475ed2d --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +inline constexpr uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +static inline constexpr auto div_ceil(A a, B b) { + return (a + b - 1) / b; +} + +// Round a down to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_previous_multiple_of(T a, T b) { + return a % b == 0 ? a : (a / b) * b; +} + +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) { + return a % b == 0 ? a : ((a / b) + 1) * b; +} diff --git a/csrc/core/registration.h b/csrc/core/registration.h new file mode 100644 index 0000000000000000000000000000000000000000..4d0ce1c572c1c1ea947db0720ace5e7abe2a5624 --- /dev/null +++ b/csrc/core/registration.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..68a8750f583b46d344cd3180ffb334f38e3ae1f8 --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,352 @@ +#pragma once + +// For TORCH_CHECK +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. 
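+// As an illustration, an 8-bit e4m3fn float is described as
+// ScalarType::float_(4, 3, true, NAN_EXTD_RANGE_MAX_MIN) (kFE4M3fn below),
+// and a plain signed 4-bit integer as ScalarType::int_(4) (kS4).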
+// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... 
rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... 
args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = + ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr 
auto kFloat16Id = kFloat16.id(); +}; // namespace vllm diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..039b8d5c30d46e29110ca64a641aab165308550e --- /dev/null +++ b/csrc/cpu/activation.cpp @@ -0,0 +1,163 @@ +#include "cpu_types.hpp" + +namespace { +template +void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + TORCH_CHECK(d % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < d; j += VEC_ELEM_NUM) { + int start = i * d; + if constexpr (is_gated) { + start *= 2; + } + + const scalar_vec_t x(input + start + j); + const vec_op::FP32Vec8 f32_x(x); + vec_op::FP32Vec8 f32_ans = func(f32_x); + + if constexpr (is_gated) { + const scalar_vec_t y(input + start + d + j); + const vec_op::FP32Vec8 f32_y(y); + f32_ans = f32_y * f32_ans; + } + + const scalar_vec_t result(f32_ans); + result.save(output + i * d + j); + } + } +} + +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + return x / (ones + (zeros - x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 x3 = x * x * x; + const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); + const vec_op::FP32Vec8 w3(0.5); + const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh(); + return w3 * x * (ones + t); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(1.702f); + return x / (ones + (zeros - w1 * x).exp()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT1_2); + const vec_op::FP32Vec8 w2(0.5); + return x * w2 * (ones + (x * w1).er()); +} + +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); + const vec_op::FP32Vec8 w2(0.5); + const vec_op::FP32Vec8 w3(0.044715); + const vec_op::FP32Vec8 x_3 = x * x * x; + const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); + return x * w2 * (ones + inner.tanh()); +} +}; // namespace + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); +} + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) 
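An aside on the gated kernels above (the gelu_and_mul dispatch continues right after this sketch): with is_gated == true, each input row of length 2*d is split into an activation half and a gate half, and the output row has length d. A scalar reference that mirrors the semantics of silu_and_mul, illustrative only and with an invented name:

#include <cmath>

void silu_and_mul_ref(int num_tokens, int d, const float* input, float* output) {
  for (int i = 0; i < num_tokens; ++i) {
    const float* row = input + i * 2 * d;  // [x_0..x_{d-1}, y_0..y_{d-1}]
    for (int j = 0; j < d; ++j) {
      float x = row[j];      // activated half
      float y = row[d + j];  // gate half
      float silu = x / (1.0f + std::exp(-x));  // same formula as silu_act above
      output[i * d + j] = silu * y;
    }
  }
}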
+ activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); +} + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "gelu_tanh_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl) + }); +} + +void gelu_new(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_new_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_new_impl) + }); +} + +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_fast_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_fast_impl) + }); +} + +void gelu_quick(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_quick_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_quick_impl) + }); +} diff --git a/csrc/cpu/cpu_arch_macros.h b/csrc/cpu/cpu_arch_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..c73b62ecdec901f4cf543bc12176aa6645b2a7dc --- /dev/null +++ b/csrc/cpu/cpu_arch_macros.h @@ -0,0 +1,113 @@ +#ifndef CPU_ARCH_MACROS_H +#define CPU_ARCH_MACROS_H + +// x86_64 +#ifdef __x86_64__ + #define FAST_SPINNING _mm_pause(); + + #ifdef __AVX512F__ + #define DEFINE_FAST_EXP \ + const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \ + const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \ + const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \ + const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \ + const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \ + const __m512 vec_exp_log2ef = \ + _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \ + const __m512 vec_half = _mm512_set1_ps(0.5f); \ + const __m512 vec_one = _mm512_set1_ps(1.f); \ + const __m512 vec_zero = _mm512_set1_ps(0.f); \ + const __m512 vec_two = _mm512_set1_ps(2.f); \ + const __m512 vec_ln2f = \ + _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \ + const __m512 vec_ln_flt_min = \ + _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \ + const __m512 vec_ln_flt_max = \ + _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \ + const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \ + const int n_mantissa_bits = 23; \ + auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \ + always_inline)) { \ + __m512 values = vec.reg; \ + auto less_ln_flt_min_mask = \ + _mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \ + auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \ + vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \ + auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \ + auto vec_fx_i = _mm512_cvt_roundps_epi32( \ + vec_fx, 
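The DEFINE_FAST_EXP macro (which continues below) evaluates exp via range reduction plus a degree-5 polynomial. A scalar sketch of the same scheme, illustrative only, with the coefficients taken from the macro; the production code additionally flushes inputs below ln(FLT_MIN) to zero and builds 2^n with integer bit tricks rather than ldexp:

#include <cmath>

float fast_exp_ref(float x) {
  // clamp to the same bounds as the macro: [ln(FLT_MIN), ln(FLT_MAX)]
  x = std::fmax(-87.3365f, std::fmin(x, 88.7228f));
  float n = std::floor(x * 1.44269504f + 0.5f);  // nearest integer to x / ln(2)
  float r = x - n * 0.693147181f;                // reduced argument, |r| <= ln(2)/2
  // degree-5 polynomial approximating exp(r), Horner form
  float p = 0.00828929059f;
  p = p * r + 0.0418978221f;
  p = p * r + 0.166676521f;
  p = p * r + 0.499991506f;
  p = p * r + 0.999999701f;
  p = p * r + 1.0f;
  return std::ldexp(p, (int)n);                  // exp(x) = 2^n * exp(r)
}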
_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \ + vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \ + auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \ + auto vec_res = \ + _mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \ + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \ + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \ + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \ + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \ + auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \ + auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \ + auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \ + vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \ + auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \ + vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \ + vec_two_pow_n, vec_zero); \ + vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \ + vec_res = _mm512_mul_ps(vec_res, vec_two); \ + vec_op::FP32Vec16 res(vec_res); \ + return res; \ + }; + #endif + +#endif + +#ifdef __aarch64__ + // Implementation copied from Arm Optimized Routines (expf AdvSIMD) + // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c + #include + #define DEFINE_FAST_EXP \ + const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \ + const float ln2_hi = 0x1.62e4p-1f; \ + const float ln2_lo = 0x1.7f7d1cp-20f; \ + const float c0 = 0x1.0e4020p-7f; \ + const float c2 = 0x1.555e66p-3f; \ + const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \ + const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \ + const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \ + const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \ + const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \ + const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \ + const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \ + const float32x4_t inf = \ + vdupq_n_f32(std::numeric_limits::infinity()); \ + const float32x4_t zero = vdupq_n_f32(0.0f); \ + auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \ + float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \ + float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \ + r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \ + uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \ + float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \ + float32x4_t r2 = vmulq_f32(r, r); \ + float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \ + float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \ + q = vfmaq_f32(q, p, r2); \ + p = vmulq_f32(c4, r); \ + float32x4_t poly = vfmaq_f32(p, q, r2); \ + poly = vfmaq_f32(scale, poly, scale); \ + const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \ + const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \ + poly = vbslq_f32(hi_mask, inf, poly); \ + return vbslq_f32(lo_mask, zero, poly); \ + }; \ + auto fast_exp = [&](const vec_op::FP32Vec16& vec) \ + __attribute__((always_inline)) { \ + float32x4x4_t result; \ + result.val[0] = neon_expf(vec.reg.val[0]); \ + result.val[1] = neon_expf(vec.reg.val[1]); \ + result.val[2] = neon_expf(vec.reg.val[2]); \ + result.val[3] = neon_expf(vec.reg.val[3]); \ + return vec_op::FP32Vec16(result); \ + }; + +#endif // __aarch64__ + +#endif diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp new file mode 100644 index 
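Both the AVX-512 and NEON exp paths above construct 2^n by writing n + 127 into the exponent field of an IEEE-754 float rather than calling a scalbn-style routine. A minimal scalar illustration (standalone, invented name, valid while n stays in the normal exponent range):

#include <cstdint>
#include <cstdio>
#include <cstring>

float two_pow_n(int32_t n) {
  uint32_t bits = (uint32_t)(n + 127) << 23;  // biased exponent, zero sign/mantissa
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  std::printf("%g %g %g\n", two_pow_n(0), two_pow_n(10), two_pow_n(-3));  // 1 1024 0.125
  return 0;
}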
0000000000000000000000000000000000000000..a582b4b4d7cc7004d423025228d94cca1ea2bc46 --- /dev/null +++ b/csrc/cpu/cpu_attn.cpp @@ -0,0 +1,189 @@ +#include "cpu_attn_dispatch_generated.h" + +torch::Tensor get_scheduler_metadata( + const int64_t num_req, const int64_t num_heads_q, + const int64_t num_heads_kv, const int64_t head_dim, + const torch::Tensor& seq_lens, at::ScalarType dtype, + const torch::Tensor& query_start_loc, const bool casual, + const int64_t window_size, const std::string& isa_hint, + const bool enable_kv_split) { + cpu_attention::ISA isa; + if (isa_hint == "amx") { + isa = cpu_attention::ISA::AMX; + } else if (isa_hint == "vec") { + isa = cpu_attention::ISA::VEC; + } else if (isa_hint == "vec16") { + isa = cpu_attention::ISA::VEC16; + } else if (isa_hint == "neon") { + isa = cpu_attention::ISA::NEON; + } else if (isa_hint == "vxe") { + isa = cpu_attention::ISA::VXE; + } else { + TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); + } + + cpu_attention::AttentionScheduler::ScheduleInput input; + input.num_reqs = num_req; + input.num_heads_q = num_heads_q; + input.num_heads_kv = num_heads_kv; + input.head_dim = head_dim; + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + if (window_size != -1) { + input.left_sliding_window_size = window_size - 1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = window_size - 1; + } + } else { + input.left_sliding_window_size = -1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = -1; + } + } + input.casual = casual; + input.isa = isa; + input.enable_kv_split = enable_kv_split; + + VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { + CPU_ATTN_DISPATCH(head_dim, isa, [&]() { + input.elem_size = sizeof(scalar_t); + input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t); + input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t); + input.output_buffer_elem_size = + sizeof(attn_impl::partial_output_buffer_t); + input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration; + input.kv_block_alignment = attn_impl::BlockSizeAlignment; + }); + }); + + cpu_attention::AttentionScheduler scheduler; + torch::Tensor metadata = scheduler.schedule(input); + return metadata; +} + +void cpu_attn_reshape_and_cache( + const torch::Tensor& key, // [token_num, head_num, head_size] + const torch::Tensor& value, // [token_num, head_num, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& slot_mapping, const std::string& isa) { + TORCH_CHECK_EQ(key.dim(), 3); + TORCH_CHECK_EQ(value.dim(), 3); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + TORCH_CHECK_EQ(key.stride(2), 1); + TORCH_CHECK_EQ(value.stride(2), 1); + + const int64_t token_num = key.size(0); + const int64_t key_token_num_stride = key.stride(0); + const int64_t value_token_num_stride = value.stride(0); + const int64_t head_num = value.size(1); + const int64_t key_head_num_stride = key.stride(1); + const int64_t value_head_num_stride = value.stride(1); + const int64_t num_blocks = key_cache.size(0); + const int64_t num_blocks_stride = key_cache.stride(0); + const int64_t cache_head_num_stride = key_cache.stride(1); + const int64_t block_size = key_cache.size(2); + const int64_t block_size_stride = key_cache.stride(2); + const 
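An aside to get_scheduler_metadata above (the stride bookkeeping continues after this sketch): the single window_size argument plus the causal flag are folded into per-side limits. Under the usual sliding-window convention, a query at position q may attend to kv positions in [q - left, q + right], with -1 meaning unbounded on that side. An illustrative standalone helper (invented name) that mirrors the branch structure above:

#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> window_bounds(int64_t window_size, bool causal) {
  int64_t left = (window_size != -1) ? window_size - 1 : -1;
  int64_t right = causal ? 0  // causal attention never looks ahead
                         : ((window_size != -1) ? window_size - 1 : -1);
  return {left, right};
}
// window_bounds(128, /*causal=*/true) == {127, 0}: the query sees itself plus the
// previous 127 tokens, matching the fields filled into ScheduleInput above.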
int64_t head_dim = key.size(-1); + + cpu_attention::ISA isa_tag = [&]() { + if (isa == "amx") { + return cpu_attention::ISA::AMX; + } else if (isa == "vec") { + return cpu_attention::ISA::VEC; + } else if (isa == "vec16") { + return cpu_attention::ISA::VEC16; + } else if (isa == "neon") { + return cpu_attention::ISA::NEON; + } else if (isa == "vxe") { + return cpu_attention::ISA::VXE; + } else { + TORCH_CHECK(false, "Invalid ISA type: " + isa); + } + }(); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() { + CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() { + attn_impl::reshape_and_cache( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), token_num, key_token_num_stride, + value_token_num_stride, head_num, key_head_num_stride, + value_head_num_stride, num_blocks, num_blocks_stride, + cache_head_num_stride, block_size, block_size_stride); + }); + }); +} + +void cpu_attention_with_kv_cache( + const torch::Tensor& query, // [num_tokens, num_heads, head_size] + const torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& output, // [num_tokens, num_heads, head_size] + const torch::Tensor& query_start_loc, // [num_tokens + 1] + const torch::Tensor& seq_lens, // [num_tokens] + const double scale, const bool causal, + const std::optional& alibi_slopes, // [num_heads] + const int64_t sliding_window_left, const int64_t sliding_window_right, + const torch::Tensor& block_table, // [num_tokens, max_block_num] + const double softcap, const torch::Tensor& scheduler_metadata, + const std::optional& s_aux // [num_heads] +) { + TORCH_CHECK_EQ(query.dim(), 3); + TORCH_CHECK_EQ(query.stride(2), 1); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + + cpu_attention::AttentionInput input; + input.metadata = reinterpret_cast( + scheduler_metadata.data_ptr()); + input.num_tokens = query.size(0); + input.num_heads = query.size(1); + input.num_kv_heads = key_cache.size(1); + input.block_size = key_cache.size(2); + input.query = query.data_ptr(); + input.query_num_tokens_stride = query.stride(0); + input.query_num_heads_stride = query.stride(1); + input.cache_num_blocks_stride = key_cache.stride(0); + input.cache_num_kv_heads_stride = key_cache.stride(1); + input.blt_num_tokens_stride = block_table.stride(0); + input.key_cache = key_cache.data_ptr(); + input.value_cache = value_cache.data_ptr(); + input.output = output.data_ptr(); + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + input.block_table = block_table.data_ptr(); + input.alibi_slopes = + alibi_slopes.has_value() ? alibi_slopes->data_ptr() : nullptr; + // For now sink must be bf16 + input.s_aux = s_aux.has_value() ? 
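For cpu_attn_reshape_and_cache above, each slot_mapping entry is a flat slot index into the paged [num_blocks, num_kv_heads, block_size, head_size] cache. A logical-layout sketch, illustrative only (invented name; the dispatched kernels may additionally prepack blocks for the chosen ISA, and skipping negative slots is an assumption about how padded tokens are marked):

#include <cstdint>

void reshape_and_cache_ref(const float* key, float* key_cache,
                           const int64_t* slot_mapping, int64_t token_num,
                           int64_t num_kv_heads, int64_t head_dim,
                           int64_t block_size) {
  for (int64_t t = 0; t < token_num; ++t) {
    const int64_t slot = slot_mapping[t];
    if (slot < 0) continue;                   // assumed: negative slot = padded token
    const int64_t block = slot / block_size;  // which physical block
    const int64_t offset = slot % block_size; // position inside the block
    for (int64_t h = 0; h < num_kv_heads; ++h) {
      const float* src = key + (t * num_kv_heads + h) * head_dim;
      float* dst = key_cache +
                   ((block * num_kv_heads + h) * block_size + offset) * head_dim;
      for (int64_t d = 0; d < head_dim; ++d) dst[d] = src[d];
    }
  }
}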
s_aux->data_ptr() : nullptr; + input.scale = scale; + input.causal = causal; + input.sliding_window_left = sliding_window_left; + input.sliding_window_right = sliding_window_right; + if (input.causal) { + // to make boundary calculation easier + input.sliding_window_right = 0; + } + float softcap_fp32 = softcap; + input.softcap = softcap_fp32; + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "cpu_attention_with_kv_cache", [&]() { + CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() { + TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0); + cpu_attention::AttentionMainLoop mainloop; + mainloop(&input); + }); + }); +} diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8da458b99119c31667ff875eeb947e5979f65968 --- /dev/null +++ b/csrc/cpu/cpu_attn_amx.hpp @@ -0,0 +1,511 @@ +#ifndef CPU_ATTN_AMX_HPP +#define CPU_ATTN_AMX_HPP + +#include "cpu_attn_impl.hpp" + +namespace cpu_attention { +namespace { +// AMX specific +constexpr static int64_t AMX_TILE_ROW_BYTES = 64; +constexpr static int64_t AMX_TILE_ROW_NUM = 16; +constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM; + +typedef struct __tile_config { + uint8_t palette_id = 1; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +} __tilecfg; + +// 2-2-4 pattern, for 16 < m <= 32 +// TILE 0, 1: load A matrix, row num should be 16, m - 16 +// TILE 2, 3: load B matrix, row num should be 16 +// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m +// - 16 +template +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } +}; + +template <> +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM; + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + // 
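One aside on the softcap field set above (the AMX tile comments continue after this sketch): the value is only forwarded to the dispatched kernel here, but logit soft-capping is conventionally applied to the raw QK scores as softcap * tanh(score / softcap), which bounds them to (-softcap, softcap). An illustrative helper (invented name; the exact placement inside the kernel is not shown in this file):

#include <cmath>

inline float apply_softcap(float score, float softcap) {
  return (softcap > 0.0f) ? softcap * std::tanh(score / softcap) : score;
}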
k_cache, v_cache are prepacked + const int32_t b_tile_stride = AMX_TILE_ROW_BYTES; + + // logits_buffer, output_buffer are not prepacked + float* __restrict__ c_tile_4 = c_tile; + float* __restrict__ c_tile_5 = + c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float); + float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc; + float* __restrict__ c_tile_7 = + c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float); + const int32_t c_tile_stride = ldc * sizeof(float); + + if (accum_c) { + _tile_loadd(4, c_tile_4, c_tile_stride); + _tile_loadd(5, c_tile_5, c_tile_stride); + _tile_loadd(6, c_tile_6, c_tile_stride); + _tile_loadd(7, c_tile_7, c_tile_stride); + } else { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_tile_stride); + _tile_dpbf16ps(4, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_tile_stride); + _tile_dpbf16ps(5, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + _tile_stored(4, c_tile_4, c_tile_stride); + _tile_stored(5, c_tile_5, c_tile_stride); + _tile_stored(6, c_tile_6, c_tile_stride); + _tile_stored(7, c_tile_7, c_tile_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + const int32_t m_0 = AMX_TILE_ROW_NUM; + const int32_t m_1 = m - AMX_TILE_ROW_NUM; + config.rows[0] = m_0; + config.rows[1] = m_1; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = m_0; + config.rows[5] = m_0; + config.rows[6] = m_1; + config.rows[7] = m_1; + _tile_loadconfig(&config); + } +}; + +// 1-2-2 pattern, for 0 < m <= 16 +// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be +// m, m +// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row +// num should be 16 +// TILE 6, 7, (6, 7): store results C matrix, row num should be +// m +template +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } +}; + +template <> +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = [&]() 
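The _tile_dpbf16ps calls above accumulate bf16 pairs into fp32 tiles, which is why the K/V blocks are kept "prepacked": the B operand must be stored pair-interleaved (VNNI-style), with logical element B[2*k + p][n] living at row k, column 2*n + p. A scalar reference of that single instruction's math, illustrative only (bf16 modeled as float, invented name; the TileGemm122 code continues after this sketch):

#include <vector>

// C (M x N, fp32) += A (M x 2*K_pairs, row-major) * B (K_pairs x 2*N, pair-interleaved)
void dpbf16ps_ref(int M, int N, int K_pairs,
                  const std::vector<std::vector<float>>& A,
                  const std::vector<std::vector<float>>& B,
                  std::vector<std::vector<float>>& C) {
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K_pairs; ++k)
      for (int n = 0; n < N; ++n)
        C[m][n] += A[m][2 * k] * B[k][2 * n] + A[m][2 * k + 1] * B[k][2 * n + 1];
}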
{ + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + c10::BFloat16* __restrict__ b_tile_4 = + b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + c10::BFloat16* __restrict__ b_tile_5 = + b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + int64_t b_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_6 = c_tile; + float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float); + int64_t c_stride = ldc * sizeof(float); + + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + const int32_t k_group_times = k_times / 2; + const bool has_tail = (k_times % 2 == 1); + + if (accum_c) { + _tile_loadd(6, c_tile_6, c_stride); + _tile_loadd(7, c_tile_7, c_stride); + } else { + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_group_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_stream_loadd(4, b_tile_4, b_stride); + _tile_dpbf16ps(6, 1, 4); + _tile_stream_loadd(5, b_tile_5, b_stride); + _tile_dpbf16ps(7, 1, 5); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } + b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + if (has_tail) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + } + + _tile_stored(6, c_tile_6, c_stride); + _tile_stored(7, c_tile_7, c_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + config.rows[0] = m; + config.rows[1] = m; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = AMX_TILE_ROW_NUM; + config.rows[5] = AMX_TILE_ROW_NUM; + config.rows[6] = m; + config.rows[7] = m; + 
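A quick note on the magic numbers used throughout these tile kernels (init_tile_config finishes just below): an AMX tile is 16 rows of 64 bytes, so a bf16 tile holds 16 x 32 elements, an fp32 accumulator tile holds 16 x 16, and each GEMM step advances the reduction dimension by AMX_TILE_ROW_NUM * 4 / sizeof(bf16) = 32 elements, which is the divisor in the k_times computations above. Spelled out as a compile-time sketch:

#include <cstdint>

constexpr int64_t kTileRowBytes = 64;
constexpr int64_t kTileRows = 16;
constexpr int64_t kBf16Bytes = 2;
constexpr int64_t kBf16PerRow = kTileRowBytes / kBf16Bytes;              // 32 bf16 per tile row
constexpr int64_t kFp32PerRow = kTileRowBytes / (int64_t)sizeof(float);  // 16 fp32 per tile row
constexpr int64_t kKStep = kTileRows * 4 / kBf16Bytes;                   // 32: K elements per tile pass
static_assert(kBf16PerRow == 32 && kFp32PerRow == 16 && kKStep == 32,
              "AMX bf16 tile sizing");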
_tile_loadconfig(&config); + } +}; +} // namespace + +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = scalar_t; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = scalar_t; + + constexpr static int64_t BlockSizeAlignment = + AMX_TILE_ROW_BYTES / + sizeof(kv_cache_t); // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + 2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = 32; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::AMX; + constexpr static bool scale_on_logits = true; + + public: + AttentionImpl() : current_q_head_num_(0) { + // Use all columns in AMX tiles + vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; }); + } + + ~AttentionImpl() { _tile_release(); } + + template