diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..964712a78845f7bd81bc2a7b28fcb0cdcc45d080 --- /dev/null +++ b/.clang-format @@ -0,0 +1,8 @@ +--- +BasedOnStyle: LLVM +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +Language: Cpp +Standard: c++17 diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000000000000000000000000000000000000..f9b77bce8a0c2e6640e7d5f72a004527256ff510 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,60 @@ +--- +InheritParentConfig: true +ExtraArgs: [] +FormatStyle: file +UseColor: true +WarningsAsErrors: '*' +# FIXME: Use `ExcludeHeaderFilterRegex` instead when all maintainers upgraded their `clang-tidy` +HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*' +# ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' + +# NOTE: there must be no spaces before the '-', so put the comma last. +Checks: >- + # 1. Retained categories: easier to find bugs/performance issues + clang-analyzer-*, + cppcoreguidelines-pro-type-static-cast-downcast, + cppcoreguidelines-pro-type-member-init, + cppcoreguidelines-pro-bounds-array-to-pointer-decay, + cppcoreguidelines-pro-bounds-pointer-arithmetic, + cppcoreguidelines-slicing, + cppcoreguidelines-narrowing-conversions, + performance-*, + + # 2. Readability: only keep useful rules + readability-braces-around-statements, + readability-container-size-empty, + readability-delete-null-pointer, + readability-redundant-member-init, + readability-redundant-smartptr-get, + readability-redundant-string-cstr, + + # 3. Disable all intrusive/style-breaking rules + -readability-identifier-length, + -readability-avoid-const-params-in-decls, + -readability-else-after-return, + -cppcoreguidelines-avoid-magic-numbers, + -modernize-use-trailing-return-type, + -modernize-use-nodiscard, + -modernize-use-auto, + -modernize-pass-by-value, + -modernize-return-braced-init-list, + -modernize-use-default-member-init, + -modernize-loop-convert, + -modernize-concat-nested-namespaces, + -llvm-include-order, + -bugprone-unused-return-value, + -clang-diagnostic-unused-result, + -cppcoreguidelines-special-member-functions, + -performance-noexcept-move-constructor, + -cppcoreguidelines-narrowing-conversions, + -clang-diagnostic-error, + -cppcoreguidelines-pro-type-member-init, + -clang-analyzer-optin.cplusplus.UninitializedObject, + -cppcoreguidelines-pro-type-static-cast-downcast, + -performance-unnecessary-value-param, + -performance-enum-size, + -cppcoreguidelines-pro-bounds-pointer-arithmetic, + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -clang-analyzer-deadcode.DeadStores, + -clang-analyzer-optin.cplusplus.VirtualCall, + -clang-diagnostic-tautological-constant-compare, diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..a9e8a6df4a64c1f2fe0d0be940e9b318ed0e6e2b --- /dev/null +++ b/.editorconfig @@ -0,0 +1,44 @@ +# https://editorconfig.org/ + +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{py,pyi}] +indent_size = 4 + +[*.{cpp,hpp,cxx,cc,c,h,cu,cuh}] +indent_size = 2 + +[{*.cmake,CMakeLists.txt}] +indent_size = 2 + +[*.{yaml,yml}] +indent_size = 2 + +[.clang-{format,tidy}] +indent_size = 2 + +[Makefile] +indent_style = tab + +[*.sh] +indent_size = 4 + +[*.bat] +indent_size = 4 +end_of_line = crlf + +[*.md] +indent_size = 2 +x-soft-wrap-text = true + +[*.rst] +indent_size = 4 +x-soft-wrap-text = true diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..bbb14db3706e0be486e78629b4a282709a00b6ac --- /dev/null +++ b/.gitattributes @@ -0,0 +1,10 @@ +* text eol=lf +*.bat eol=crlf + +*.svg binary +*.jpg binary +*.jpeg binary +*.png binary +*.gif binary + +*.h linguist-language=C++ diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..64235155241baebcfc741f1e1892f466477a8661 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,112 @@ +name: ๐Ÿ› Bug Report +description: File an issue about a bug. +title: "[BUG] " +labels: [bug] +assignees: [] +body: + - type: markdown + attributes: + value: >- + Please do your best to make the issue as easy to act on as possible, + and only submit here if there is clearly a problem with TileLang. + + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have read the documentation . + required: true + - label: >- + I have searched the [Issue Tracker](https://github.com/tile-ai/tilelang/issues) + that this hasn't already been reported. (comment there if it has.) + required: true + + - type: input + id: version + attributes: + label: What version of TileLang are you using? + description: >- + Run command `python3 -c 'print(__import__("tilelang").__version__)'` in your shell + and paste the output here. + placeholder: E.g., 0.1.5 + validations: + required: true + + - type: textarea + id: system-info + attributes: + label: System information + description: | + Describe the characteristic of your environment: + + - Describe how the library was installed (pip, conda, source, ...) + - Python version + - Versions of any other relevant libraries + + ```python + import sys, tilelang, torch + print(sys.version, sys.platform) + print(tilelang.__version__) + print(torch.__version__) + ``` + + ```bash + python3 -m torch.utils.collect_env + ``` + validations: + required: true + + - type: textarea + id: description + attributes: + label: Problem description + description: >- + Provide a short description, state the expected behavior and what actually happens. Include + relevant information like what version of TileLang you are using, what system you are on, and + any useful commands / output. + validations: + required: true + + - type: textarea + id: code + attributes: + label: Reproducible example code + description: >- + The code should be minimal, have minimal external dependencies, and isolate the functions + that cause breakage. Submit matched and complete snippets that can be easily run to diagnose + the issue. + value: | + The Python snippets: + + ```python + + ``` + validations: + required: true + + - type: textarea + id: traceback + attributes: + label: Traceback + description: Put the Python traceback information here. + placeholder: | + Traceback (most recent call last): + File ... + render: pytb + + - type: textarea + id: expected + attributes: + label: Expected behavior + description: Provide a clear and concise description of what you expected to happen. + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: >- + Add any other context about the problem here. Screenshots may also be helpful. + + If you know or suspect the reason for this bug, paste the code lines and suggest modifications. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..0086358db1eb971c0cfa8739c27518bbc18a5ff4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 0000000000000000000000000000000000000000..c1b520f72ed044e17200756cc6b4879f62167d28 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,45 @@ +name: โœจ Feature Request +description: Suggest an idea for this project. +title: "[Feature Request] " +labels: [enhancement] +body: + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: >- + I have searched the [Issue Tracker](https://github.com/tile-ai/tilelang/issues) + that this hasn't already been reported. (comment there if it has.) + required: true + + - type: textarea + id: motivation + attributes: + label: Motivation + description: Outline the motivation for the proposal. + value: | + + validations: + required: true + + - type: textarea + id: solution + attributes: + label: Solution + description: Provide a clear and concise description of what you want to happen. + + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: A clear and concise description of any alternative solutions or features you've considered. + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add any other context about the problem here. Screenshots may also be helpful. diff --git a/.github/ISSUE_TEMPLATE/questions.yml b/.github/ISSUE_TEMPLATE/questions.yml new file mode 100644 index 0000000000000000000000000000000000000000..e7f948d4e38dfafb362212c86d6f9f5adbc512ce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions.yml @@ -0,0 +1,25 @@ +name: ๐Ÿค” Questions / Help / Support +description: Do you need support? +title: "[Question] " +labels: [question] +body: + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: I have read the documentation . + required: true + - label: >- + I have searched the [Issue Tracker](https://github.com/tile-ai/tilelang/issues) + that this hasn't already been reported. (comment there if it has.) + required: true + + - type: textarea + id: questions + attributes: + label: Questions + description: Describe your questions with relevant resources such as snippets, links, images, etc. + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/release-plan.yml b/.github/ISSUE_TEMPLATE/release-plan.yml new file mode 100644 index 0000000000000000000000000000000000000000..a3528275c8b15d1a2928c5fd4d2eb0315da0bb08 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/release-plan.yml @@ -0,0 +1,63 @@ +name: "Release Plan" +description: "Plan the next release" +title: "[Release Plan] vX.Y.Z" +labels: + - release-plan + - tracking +assignees: [] +body: + - type: input + id: version + attributes: + label: "Version" + placeholder: "v0.2.0" + validations: + required: true + + - type: input + id: milestone + attributes: + label: "Milestone" + description: "Link or name of the milestone for this release" + placeholder: "https://github.com/tile-ai/tilelang/milestone/XX" + + - type: textarea + id: scope + attributes: + label: "Scope" + description: "Goals and non-goals (brief)" + placeholder: | + - Goals: ... + - Non-goals: ... + + - type: textarea + id: tasks + attributes: + label: "Tasks" + description: "Task list; link issues/PRs" + value: | + - [ ] Features + - [ ] Fixes + - [ ] Docs + - [ ] API/Breaking changes + - [ ] Benchmarks + - [ ] Release notes + + - type: checkboxes + id: readiness + attributes: + label: "Readiness" + options: + - label: "All planned issues closed or deferred" + - label: "Docs updated" + - label: "CI green; artifacts verified" + - label: "Release notes drafted" + + - type: textarea + id: notes + attributes: + label: "Notes" + description: "Risks or communications (optional)" + placeholder: | + - Risk: ... + - Communication: ... diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..63e1f3bd558e988616673ecd0e19277d04e79696 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "12:00" + timezone: "Asia/Shanghai" + commit-message: + prefix: "[CI]" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..d7abaeb0f1baf93c416675abea4848d541ba35c4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,428 @@ +name: CI +on: + pull_request: + types: + - labeled + - unlabeled + - opened + - synchronize + - reopened + # Allow to trigger the workflow manually + workflow_dispatch: + +permissions: + contents: read + +concurrency: + group: "${{ github.workflow }}-${{ github.ref }}" + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +env: + CLANG_TIDY_CMAKE_OPTIONS: "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON" # to be updated + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated + PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pip/.pre-commit" # to be updated + +jobs: + lint: + name: Quick Lint + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: recursive + + - name: Setup Python 3.8 + id: setup-pylowest + uses: actions/setup-python@v6 + with: + python-version: "3.8" # use lowest supported version for linting + update-environment: false + + - name: Check AST with Python 3.8 + run: | + "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang + + - name: Setup Python 3.9 + uses: actions/setup-python@v6 + with: + python-version: "3.9" + update-environment: true + cache: pip + cache-dependency-path: | + pyproject.toml + requirements*.txt + .pre-commit-config.yaml + + - name: Pre-commit Lint + run: | + if ! pipx run pre-commit run --all-files --color=always --show-diff-on-failure; then + echo "::error::Pre-commit checks failed. Please run 'pre-commit install' and 'pre-commit run --all-files' locally to see the issues." + exit 1 + fi + + tests: + name: Test for Python ${{ matrix.python-version }} with ${{ matrix.runner.toolkit }} (on ${{ matrix.runner.name }}) + if: | + github.repository_owner == 'tile-ai' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) + needs: [lint] + runs-on: ${{ matrix.runner.tags }} + strategy: + matrix: + runner: + - tags: [self-hosted, nvidia] + name: self-hosted-nvidia + # Format: [Nightly-]CUDA-.[.]. E.g., "CUDA-12.8" or "Nightly-CUDA-13.0". + # Use "Nightly-" prefix to use torch nightly builds. + toolkit: CUDA-12.8 + - tags: [self-hosted, amd, gpu] + name: self-hosted-amd + # Format: [Nightly-]ROCm-.[.]. E.g., "ROCm-6.4" or "Nightly-ROCm-7.0". + # Use "Nightly-" prefix to use torch nightly builds. + toolkit: Nightly-ROCm-7.1 + - tags: [macos-latest] + name: macos-latest + toolkit: Metal # or Nightly-Metal + python-version: + - "3.12" + fail-fast: false + timeout-minutes: 120 + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: recursive + + - name: Set environment (self-hosted runners) + if: startsWith(matrix.runner.name, 'self-hosted') + run: | + # Hide sensitive data in logs for self-hosted runners + if [[ -n "${{ secrets.SECRET_PATH_PREFIXES }}" ]]; then + echo "::add-mask::${{ secrets.SECRET_PATH_PREFIXES }}" + # Colon separated list of secrets to mask + for secret in $(echo "${{ secrets.SECRET_PATH_PREFIXES }}" | tr ':' '\n'); do + echo "::add-mask::${secret}" + done + fi + + # Use runner tool_cache as cache root for self-hosted runners to avoid internet connection + # issues and to share cache between jobs. + export XDG_CACHE_HOME="${{ runner.tool_cache }}/.ci-cache-${{ github.workflow }}" + echo "XDG_CACHE_HOME=${XDG_CACHE_HOME}" | tee -a "${GITHUB_ENV}" + echo "PIP_CACHE_DIR=${XDG_CACHE_HOME}/pip" | tee -a "${GITHUB_ENV}" + echo "UV_CACHE_DIR=${XDG_CACHE_HOME}/uv" | tee -a "${GITHUB_ENV}" + echo "PRE_COMMIT_HOME=${XDG_CACHE_HOME}/pip/.pre-commit" | tee -a "${GITHUB_ENV}" + + # Do not use ccache on self-hosted runners, as it will download/upload caches which is slow. + # Self-hosted runners usually have more CPU power to compile without ccache. + - name: Setup ccache (GitHub-hosted runners) + id: setup-ccache + if: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + uses: hendrikmuhs/ccache-action@v1 + with: + create-symlink: true + evict-old-files: "7d" + append-timestamp: false + key: ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }}-${{ hashFiles('**/*.cc') }} + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.runner.toolkit }} + ${{ runner.os }}-${{ runner.arch }} + + - name: Set environment (CUDA) + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + TOOLKIT="${{ matrix.runner.toolkit }}" + CUDA_VERSION="${TOOLKIT##*-}" + CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)" + CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}" + if [[ "${TOOLKIT}" == "Nightly-"* ]]; then + # Use torch nightly builds + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}" + else + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}" + fi + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_CUDA=ON" + + echo "USE_CUDA=ON" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN=${CUDA_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}" + echo "CUDA_VERSION_MAJMIN_NODOT=${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" + + if [[ ! -x "$(command -v nvcc)" ]]; then + export PATH="/usr/local/cuda/bin:${PATH}" + export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}" + fi + if [[ -x "$(command -v nvcc)" ]]; then + echo "\$ $(command -v nvcc) --version" && nvcc --version + else + echo "::warning::nvcc not found in PATH!" + fi + + - name: Set environment (ROCm) + if: contains(matrix.runner.toolkit, 'ROCm') + run: | + TOOLKIT="${{ matrix.runner.toolkit }}" + ROCM_VERSION="${TOOLKIT##*-}" + ROCM_VERSION_MAJMIN="$(echo ${ROCM_VERSION} | cut -d '.' -f-2)" + ROCM_VERSION_MAJMIN_NODOT="${ROCM_VERSION_MAJMIN//./}" + if [[ "${TOOLKIT}" == "Nightly-"* ]]; then + # Use torch nightly builds + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/nightly/rocm${ROCM_VERSION_MAJMIN}" + else + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm${ROCM_VERSION_MAJMIN}" + fi + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_ROCM=ON" + + echo "USE_ROCM=ON" | tee -a "${GITHUB_ENV}" + echo "ROCM_VERSION=${ROCM_VERSION}" | tee -a "${GITHUB_ENV}" + echo "ROCM_VERSION_MAJMIN=${ROCM_VERSION_MAJMIN}" | tee -a "${GITHUB_ENV}" + echo "ROCM_VERSION_MAJMIN_NODOT=${ROCM_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" + + if [[ ! -x "$(command -v hipcc)" ]]; then + export PATH="/opt/rocm/bin:${PATH}" + export LD_LIBRARY_PATH="/opt/rocm/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + echo "PATH=${PATH}" | tee -a "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" | tee -a "${GITHUB_ENV}" + fi + if [[ -x "$(command -v hipcc)" ]]; then + echo "\$ $(command -v hipcc) --version" && hipcc --version + else + echo "::warning::hipcc not found in PATH!" + fi + + - name: Set environment (Metal) + if: contains(matrix.runner.toolkit, 'Metal') + run: | + if [[ "${{ matrix.runner.toolkit }}" == "Nightly-"* ]]; then + # Use torch nightly builds + export PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/nightly/cpu" + export UV_INDEX="${PIP_EXTRA_INDEX_URL}" + echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}" + echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}" + fi + export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_METAL=ON" + + echo "USE_METAL=ON" | tee -a "${GITHUB_ENV}" + echo "CLANG_TIDY_CMAKE_OPTIONS=${CLANG_TIDY_CMAKE_OPTIONS}" | tee -a "${GITHUB_ENV}" + + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: ${{ matrix.python-version }} + activate-environment: true + # Do not use cache for self-hosted runners, as it will download/upload caches which is slow. + enable-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + prune-cache: ${{ !startsWith(matrix.runner.name, 'self-hosted') }} + # Use runner tool_cache for self-hosted runners + cache-local-path: ${{ env.UV_CACHE_DIR }} + ignore-nothing-to-cache: true + # Extra cache key to upload/download caches on GitHub-hosted runners + cache-suffix: uv-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}-${{ matrix.runner.name }}-${{ matrix.runner.toolkit }} + cache-dependency-glob: | + pyproject.toml + requirements*.txt + .pre-commit-config.yaml + + - name: Setup venv + id: setup-venv + run: | + set -o pipefail + + uv pip install --upgrade pip setuptools wheel + if [[ "${UV_INDEX}" == *"/nightly/"* ]]; then + uv pip install --prerelease=allow -v torch + fi + uv pip install -v -r requirements-test.txt + echo "import torch; print(f'torch: {torch.__version__}')" | uv run --no-project --script - + if [[ "${{ matrix.runner.toolkit }}" == *"CUDA"* ]]; then + uv pip install --no-build-isolation-package=flash-attn -v -r requirements-test-cuda.txt + echo "import flash_attn; print(f'flash_attn: {flash_attn.__version__}')" | uv run --no-project --script - + elif [[ "${{ matrix.runner.toolkit }}" == *"ROCm"* ]]; then + uv pip install -v -r requirements-test-rocm.txt + elif [[ "${{ matrix.runner.toolkit }}" == *"Metal"* ]]; then + uv pip install -v -r requirements-test-metal.txt + else + echo "::error::Unknown toolkit: ${{ matrix.runner.toolkit }}" + exit 1 + fi + echo "::group::torch.utils.collect_env" + uv run --no-project -m -- torch.utils.collect_env + echo "::endgroup::" + + - name: Clear uv cache for self-hosted runners (if setup failed) + if: >- + ${{ + failure() && + startsWith(matrix.runner.name, 'self-hosted') && + (steps.setup-uv.conclusion == 'failure' || steps.setup-venv.conclusion == 'failure') + }} + run: | + echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." + uv cache clean + + - name: Enable core dump generation (Linux / GitHub-hosted runners) + if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kernel.core_uses_pid=0 + sudo sysctl -w fs.suid_dumpable=1 + sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable + + - name: Enable core dump generation (macOS / GitHub-hosted runners) + if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }} + run: | + sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P" + sudo sysctl -w kern.coredump=1 + sudo sysctl -w kern.sugid_coredump=1 + sysctl kern.corefile kern.coredump kern.sugid_coredump + + - name: Install project (wheel form) + run: | + uv pip install -v . + + - name: Run clang-tidy + id: clang-tidy + if: runner.os == 'Linux' + run: | + echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version + + # Download run-clang-tidy script + RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py + echo "Downloading run-clang-tidy script from ${RCT_URL}" + echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script - + RUN_CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py) + + if [[ -x "$(command -v clang-apply-replacements)" ]]; then + echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)" + RUN_CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)") + else + echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled." + fi + + # Run cmake to create the build directory with compile_commands.json + cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here + echo "::group::compile_commands.json" + ls -alh cmake-build/compile_commands.json + uv run --no-project -m -- json.tool --no-ensure-ascii cmake-build/compile_commands.json + echo "::endgroup::" + + CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h") + rc=0 + echo "::group::run-clang-tidy" + "${RUN_CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \ + -exclude-header-filter='^(3rdparty|tvm)/.*$' \ + -p="cmake-build" ${CXX_FILES} || rc="$?" + echo "::endgroup::" + rm -rf cmake-build run-clang-tidy.py + if (( rc != 0 )); then + echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them." + git diff --color=always || true + exit "${rc}" + fi + + - name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + cd testing + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear + ) + "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + ../examples + + # NVIDIA CUDA tests + - name: Run CUDA tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + id: cuda-tests + if: contains(matrix.runner.toolkit, 'CUDA') + run: | + cd testing + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear + ) + "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + --ignore=./python/jit/test_tilelang_jit_cutedsl.py \ + ./python + + # CuTeDSL JIT tests require GEMM v1 (must be set before importing tilelang). + # Run them in a dedicated step to avoid changing the default GEMM selection + # (and to keep the rest of the CUDA tests on GEMM v2). + - name: Run CuTeDSL JIT tests (GEMM v1) with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + id: cutedsl-tests + if: contains(matrix.runner.toolkit, 'CUDA') + env: + TILELANG_USE_GEMM_V1: "1" + run: | + cd testing + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear + ) + # Avoid xdist contention on a single GPU by running this file in one worker. + "${PYTEST[@]}" --maxfail=3 --numprocesses=1 \ + ./python/jit/test_tilelang_jit_cutedsl.py + + # AMD ROCm tests + - name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + id: rocm-tests + if: contains(matrix.runner.toolkit, 'ROCm') + run: | + cd testing + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear + ) + "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + ./python/amd + + # Apple Metal tests + - name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) + id: metal-tests + if: contains(matrix.runner.toolkit, 'Metal') + run: | + cd testing + PYTEST=( + uv run --no-project -m -- + pytest --verbose --color=yes --durations=0 --showlocals --cache-clear + ) + "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ + -k metal \ + ./python + + - name: List generated files + if: ${{ !cancelled() }} + run: | + find . -type f -name '*.py[co]' -delete + find . -depth -type d -name "__pycache__" -exec rm -r "{}" + + if git status --ignored --porcelain | grep -qvE '/$'; then + ls -alh $(git status --ignored --porcelain | grep -vE '/$' | grep -oE '\S+$') + fi diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml new file mode 100644 index 0000000000000000000000000000000000000000..6666871100bfe7934b31e1fa342bf0779d26d51d --- /dev/null +++ b/.github/workflows/dist.yml @@ -0,0 +1,208 @@ +name: Dist +on: + workflow_dispatch: + schedule: + # gemini said this is 6:00 china time + - cron: "0 22 * * *" + pull_request: + types: + - opened + - synchronize + - reopened + - ready_for_review + paths: + - setup.py + - setup.cfg + - pyproject.toml + - MANIFEST.in + - CMakeLists.txt + - version_provider.py + - .github/workflows/dist.yml + release: + types: + - published + +permissions: + contents: read + +concurrency: + group: "${{ github.workflow }}-${{ github.ref }}" + cancel-in-progress: true + +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + +jobs: + build-sdist: + name: Build SDist + if: | + github.repository_owner == 'tile-ai' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) + runs-on: macos-latest + timeout-minutes: 30 + env: + # `NO_VERSION_LABEL=ON` disables embedding the toolchain / git commit hash in version metadata. + # Otherwise, the version of the SDist has a git hash suffix (e.g., 0.1.0+gitabcdef12), + # but the package built from the SDist has no way to get the git hash (it is not a git repo), + # leading to inconsistent versions between SDist and built packages (+gitabcdef12 vs. +gitunknown). + NO_VERSION_LABEL: 'ON' + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 1 + submodules: recursive + + - name: Setup Python and uv with caching + id: setup-uv + uses: astral-sh/setup-uv@v7 + with: + python-version: "3.12" + activate-environment: true + + - name: Build SDist + run: | + uv run --no-project --with=build -m -- build --sdist --outdir=dist + + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1 + with: + create-symlink: true + evict-old-files: "7d" + append-timestamp: false + key: sdist-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + sdist-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.cc') }} + sdist-${{ runner.os }}-${{ runner.arch }} + ${{ runner.os }}-${{ runner.arch }} + + - name: Test SDist buildable + run: | + TEMP_DIR="$(mktemp -d -t tilelang-sdist-test)" + cp -r dist "${TEMP_DIR}/dist" + cd "${TEMP_DIR}" + uv pip install -v dist/*.tar.gz + python3 -c "import tilelang; print(tilelang.__version__)" + + - name: Upload SDist + # Not PR to save artifact storage, as SDist is only needed for releases. + if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') + uses: actions/upload-artifact@v6 + with: + name: sdist + path: dist/*.tar.gz + if-no-files-found: error + + build-wheels: + name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.target.runner }} with ${{ matrix.target.toolkit }} + if: | + github.repository_owner == 'tile-ai' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) + strategy: + matrix: + target: + - { runner: ubuntu-latest, toolkit: "CUDA-12.8" } + - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } + - { runner: macos-latest, toolkit: "Metal" } + python-version: + # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8. + # Only build wheels against Python 3.8 Limited API to save CI resources. + - "3.9" + fail-fast: false + timeout-minutes: 120 + runs-on: ${{ matrix.target.runner }} + env: + NO_VERSION_LABEL: ${{ github.event_name == 'release' && 'OFF' || 'ON' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 1 + submodules: recursive + + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1 + with: + create-symlink: true + evict-old-files: "7d" + append-timestamp: false + key: wheel-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}-${{ hashFiles('**/*.cc') }} + restore-keys: | + wheel-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}-${{ hashFiles('**/*.cc') }} + wheel-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} + wheel-${{ runner.os }}-${{ runner.arch }} + ${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} + ${{ runner.os }}-${{ runner.arch }} + + - name: Set CIBW_BUILD + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + PYTHON_VERSION_MAJMIN="$(echo "${PYTHON_VERSION}" | cut -d '.' -f-2)" + PYTHON_VERSION_MAJMIN_NODOT="${PYTHON_VERSION_MAJMIN//./}" + echo "CIBW_BUILD=cp${PYTHON_VERSION_MAJMIN_NODOT}-*" | tee -a "${GITHUB_ENV}" + + if [[ "${{ matrix.target.toolkit }}" == *"CUDA"* ]]; then + CUDA_VERSION="${{ matrix.target.toolkit }}" + CUDA_VERSION="${CUDA_VERSION#CUDA-}" + echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}" + fi + + if [[ "${{ runner.os }}" == "Linux" ]]; then + HOST_CCACHE_DIR="$(ccache --get-config cache_dir)" + echo "CIBW_BEFORE_BUILD_LINUX=yum install -y ccache && ccache -o cache_dir=/host${HOST_CCACHE_DIR}" | tee -a "${GITHUB_ENV}" + fi + + - name: Build wheels + uses: pypa/cibuildwheel@v3.3 + with: + package-dir: . + output-dir: wheelhouse + config-file: "{package}/pyproject.toml" + + - name: Upload wheels + # Not PR to save artifact storage, as wheels are only needed for releases. + if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') + uses: actions/upload-artifact@v6 + with: + name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} + path: wheelhouse/*.whl + if-no-files-found: error + + list-artifacts: + name: List artifacts + # Not PR to save artifact storage, as artifacts are only needed for releases. + if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') + runs-on: ubuntu-latest + needs: [build-sdist, build-wheels] + timeout-minutes: 15 + steps: + - name: Download built SDist + uses: actions/download-artifact@v7 + with: + # unpacks default artifact into dist/ + # if `name: artifact` is omitted, the action will create extra parent dir + name: sdist + path: dist + + - name: Download built wheels + uses: actions/download-artifact@v7 + with: + pattern: wheels-* + path: dist + merge-multiple: true + + - name: List distributions + run: ls -lh dist/* + + - name: Upload artifacts + uses: actions/upload-artifact@v6 + with: + name: artifacts + path: dist/* + if-no-files-found: error diff --git a/.github/workflows/pr-perfbench-bot.yml b/.github/workflows/pr-perfbench-bot.yml new file mode 100644 index 0000000000000000000000000000000000000000..e6954bcc45b08eb60b5d2ba83f26f18ce02635cb --- /dev/null +++ b/.github/workflows/pr-perfbench-bot.yml @@ -0,0 +1,88 @@ +name: Performance Benchmark Bot + +on: + issue_comment: + types: + - created + +permissions: + contents: read + +concurrency: + group: "${{ github.workflow }}-${{ github.ref }}" + cancel-in-progress: true # always cancel in-progress + +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + +jobs: + perfbench: + name: Benchmark between PR and main + if: | + github.repository_owner == 'tile-ai' && + github.event.issue.pull_request && + (contains(github.event.comment.body, '/performance-report') || contains(github.event.comment.body, '/perf')) + runs-on: [self-hosted, nvidia] + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + ref: refs/pull/${{ github.event.issue.number }}/merge + fetch-depth: 0 + submodules: recursive + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + update-environment: true + cache: pip + cache-dependency-path: | + pyproject.toml + requirements*.txt + + - name: Install merged version + run: | + python -m venv tll + source tll/bin/activate + pip install -r requirements-test.txt + pip install . + + - name: Install original version + run: | + echo "Check files to be deleted!" + git clean -dxf -e tll/ + echo "Delete files completed!" + git checkout main + python -m venv tl + source tl/bin/activate + pip install -r requirements-test.txt + pip install . + + - name: Run performance test + id: perfbench + run: | + source tl/bin/activate + python maint/scripts/ci_performance.py + + - name: Post test results as PR comment + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '๐Ÿ“Š โ€‹**Performance Test Results** (triggered by @' + context.payload.comment.user.login + '):\n\n' + + 'Run listed here: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\n\n' + + "${{ steps.perfbench.outputs.stdout }}" + }) diff --git a/.github/workflows/pr-reminder-bot.yml b/.github/workflows/pr-reminder-bot.yml new file mode 100644 index 0000000000000000000000000000000000000000..67e12936c6bcdbc12d8e44a24c5dde0e34e6169f --- /dev/null +++ b/.github/workflows/pr-reminder-bot.yml @@ -0,0 +1,28 @@ +name: PR Reminder Bot + +on: + pull_request_target: + types: + - opened + +jobs: + remind: + runs-on: ubuntu-latest + if: github.repository_owner == 'tile-ai' + steps: + - name: Remind + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '๐Ÿ‘‹ Hi! Thank you for contributing to the **TileLang** project.\n\n' + + 'Please remember to run `pre-commit run --all-files` in the root directory of the project ' + + 'to ensure your changes are properly linted and formatted. ' + + 'This will help ensure your contribution passes the format check.\n\n' + + 'We appreciate you taking this step! ' + + 'Our team will review your contribution, and we look forward to your awesome work! ๐Ÿš€' + }) diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml new file mode 100644 index 0000000000000000000000000000000000000000..2197015b66df1c35babf2a941a9dced814ef5699 --- /dev/null +++ b/.github/workflows/publish-docs.yml @@ -0,0 +1,62 @@ +name: Documentation + +on: + pull_request_target: + types: + - closed + workflow_dispatch: + +permissions: + contents: write + +jobs: + docs: + name: Build and Publish Docs + if: | + github.repository_owner == 'tile-ai' && + ( + ( + github.event_name == 'pull_request_target' && + github.event.pull_request.merged == true && + github.event.pull_request.base.ref == 'main' + ) || + github.event_name == 'workflow_dispatch' + ) + runs-on: [self-hosted, nvidia] + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: recursive + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.10" + + - name: Build docs + run: | + bash -ex maint/scripts/build_docs.sh + + - name: Push built docs to another repo + run: | + # Hide sensitive info in logs + echo "::add-mask::${{ secrets.TARGET_TOKEN }}" + echo "::add-mask::${{ secrets.TARGET_REPO }}" + TARGET_REPO_URL="https://github.com/${{ secrets.TARGET_REPO }}.git" + + git clone "${TARGET_REPO_URL}" -b main target_repo + cd target_repo + git config --local user.name "github-actions[bot]" + git config --local user.email "github-actions[bot]@users.noreply.github.com" + find . -mindepth 1 -maxdepth 1 ! -name ".github" ! -name "." ! -name ".git" -exec rm -rf {} + + cp -r ../docs/_build/html/* ./ + git add . + if [[ -n "$(git status --porcelain)" ]]; then + # If there are changes, commit and push + git commit -m "Update docs" + git push "https://github-actions[bot]:${{ secrets.TARGET_TOKEN }}@${TARGET_REPO_URL##*://}" main + else + echo "No changes detected, skipping commit and push." + fi diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8dc9fd8804d20688d64170e84f387467e235afba --- /dev/null +++ b/.gitignore @@ -0,0 +1,122 @@ +# Compiled Object files +*.slo +*.lo +*.o +*.so +*.obj +*.pyc + +# Precompiled Headers +*.gch +*.pch + +# emacs +*~ + +# vim +*.swp +*.swo + +debug/ +build/ +*dist/ +dist*/ +wheelhouse/ +__pycache__ +nnfusion.tar.gz + +# makeenv and test intermediate files +tmp/ + +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.vscode/ +.vs/ + +# VisualGDB files +VisualGDB/ +toolchain.cmake + +# docbuild artifacts +doc/sphinx/build/* +doc/doxygen/*.xml +doc/doxygen/*.html +doc/doxygen/man/* +doc/doxygen/latex/* +doc/doxygen/xml/* +doc/doxygen/html/* + +# git merge +*.orig +\#* +\.#* + +# idea +.idea/* + +# python egg +*.egg-info + +# Macos +**/.DS_Store + +nnfusion_rt/ +models/frozenmodels/ + +# log +*.log + +# pkl +*.pkl_* + +# .pytest_cache +.pytest_cache + +# .hypothesis +.hypothesis + +# .ruff_cache +.ruff_cache + +# exclude debug testing folder +!testing/python/debug + +# ignore lib with develop mode +tilelang/lib + +# cython +tilelang/jit/adapter/cython/.cycache + +# cache directory for clangd +.cache/ + +# claude +**/.claude + +# CMake +cmake-build/ +cmake-build-*/ + +# Git version for sdist +.git_commit.txt + +# pre-commit cache +.pre-commit-cache/* + +# host checks logs +maint/host_checks/logs/* + +# ncu +*.ncu-rep + +# csv +*.csv + +# clang-tidy +/run-clang-tidy.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..67ce3488ae726121186027e5ca04b26fe0efc94b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass +[submodule "3rdparty/tvm"] + path = 3rdparty/tvm + url = https://github.com/TileLang/tvm +[submodule "3rdparty/composable_kernel"] + path = 3rdparty/composable_kernel + url = https://github.com/ROCm/composable_kernel diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3504adb6d1b417b628080255c497d681e0df512d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,56 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +ci: + autofix_prs: false + autofix_commit_msg: "[Lint]: [pre-commit.ci] auto fixes [...]" + autoupdate_commit_msg: "[CI] [pre-commit.ci] autoupdate" + autoupdate_schedule: monthly +default_stages: [pre-commit, pre-push, manual] +exclude: '^(build|3rdparty)/.*$' # exclude build and 3rdparty directories +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-symlinks + - id: destroyed-symlinks + # FIXME: enable these hooks + # - id: trailing-whitespace + # - id: end-of-file-fixer + - id: check-added-large-files + - id: check-merge-conflict + fail_fast: true + # FIXME: enable these hooks + # - id: check-executables-have-shebangs + # - id: check-shebang-scripts-are-executable + - id: detect-private-key + - id: check-yaml + - id: check-toml + - id: check-ast + fail_fast: true + - id: debug-statements + - id: file-contents-sorter + args: [--ignore-case] + files: ^docs/spelling_wordlist\.txt$ + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v21.1.7 # sync with requirements-lint.txt + hooks: + - id: clang-format + types_or: [c++, c] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.9 # sync with requirements-lint.txt + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format + args: [--exit-non-zero-on-format] + - repo: https://github.com/codespell-project/codespell + rev: v2.4.1 # sync with requirements-lint.txt + hooks: + - id: codespell + additional_dependencies: [".[toml]"] + exclude: | + (?x)( + ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$| + ^.+\.svg$| + ^.*\brequirements\b.*\.txt$ + ) diff --git a/3rdparty/.gitignore b/3rdparty/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f2ce68266f7aad7c9ff3152452c3772822961c8c --- /dev/null +++ b/3rdparty/.gitignore @@ -0,0 +1,3 @@ +clang* + +llvm* diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel new file mode 160000 index 0000000000000000000000000000000000000000..b38bb492a1a55b5abb0c345962143c0f9c482cfb --- /dev/null +++ b/3rdparty/composable_kernel @@ -0,0 +1 @@ +Subproject commit b38bb492a1a55b5abb0c345962143c0f9c482cfb diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 0000000000000000000000000000000000000000..b2dd65dc864e09688245b316ac46c4a6cd07e15c --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit b2dd65dc864e09688245b316ac46c4a6cd07e15c diff --git a/3rdparty/tvm b/3rdparty/tvm new file mode 160000 index 0000000000000000000000000000000000000000..79ed747db67e60d3a1889d8afd33473bc2424ade --- /dev/null +++ b/3rdparty/tvm @@ -0,0 +1 @@ +Subproject commit 79ed747db67e60d3a1889d8afd33473bc2424ade diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7af7f854fd2e8918adc9ffc160760104abc8e61d --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,332 @@ +# Learn a lot from the MLC - LLM Project +# https://github.com/mlc-ai/mlc-llm/blob/main/CMakeLists.txt + +cmake_minimum_required(VERSION 3.26) +project(TILE_LANG C CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}") + # Warning came from tvm submodule + string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference") +endif() + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake) + +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git") + find_package(Git QUIET) + if(Git_FOUND) + execute_process( + COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE TILELANG_GIT_SUBMODULE_RESULT + ) + if(NOT TILELANG_GIT_SUBMODULE_RESULT EQUAL 0) + message( + FATAL_ERROR + "Failed to initialize git submodules. Please run " + "`git submodule update --init --recursive` and re-run CMake." + ) + endif() + else() + message( + FATAL_ERROR + "Git is required to initialize TileLang submodules. " + "Please install git or fetch the submodules manually." + ) + endif() +endif() + +find_program(CCACHE_PROGRAM ccache) +if(CCACHE_PROGRAM) + message(STATUS "Using ccache: ${CCACHE_PROGRAM} with base_dir=${CMAKE_SOURCE_DIR}") + if(APPLE) + # Passing configs like `ccache base_dir=/xxx cc ...` is supported + # (likely) since ccache 4.x, which has been provided by homebrew. + # Our Linux builder image (manylinux2014 & manylinux_2_28) still + # provides ccache 3.x and do not support this form. + # `cibuildwheel` uses fixed folder on Linux (`/project`) as working directory, + # so cache would work without setting `base_dir`. + set(CCACHE_PROGRAM "${CCACHE_PROGRAM};base_dir=${CMAKE_SOURCE_DIR}") + endif() + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") + set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") +else() + find_program(SCCACHE_PROGRAM sccache) + if(SCCACHE_PROGRAM) + message(STATUS "Using sccache: ${SCCACHE_PROGRAM}") + set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") + set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") + endif() +endif() + +# Configs +set(TILELANG_BACKENDS CUDA ROCM METAL) + +set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)") +set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)") +set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend") + +# TVM's config.cmake redefines USE_* options later, so we cache the user's choice +# (including explicit -DUSE_XXX arguments) before we include TVM and restore it +# afterwards. + +macro(tilelang_define_backend_option BACKEND) + set(_backend_var "USE_${BACKEND}") + set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}") + set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}") + + set(_user_override OFF) + if(DEFINED ${_user_override_var}) + set(_user_override "${${_user_override_var}}") + endif() + + if(DEFINED CACHE{${_backend_var}}) + get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE) + if(_cache_type STREQUAL "UNINITIALIZED") + set(_user_override ON) + endif() + endif() + + set(_default OFF) + if(DEFINED ${_backend_var}) + set(_default "${${_backend_var}}") + endif() + + option(${_backend_var} "${_doc}" "${_default}") + # Remember if the user explicitly set this option so that later logic + # won't auto-toggle backends they configured on the command line. + set(${_user_override_var} ${_user_override} CACHE INTERNAL + "User explicitly set ${_backend_var} during configuration" FORCE) + set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}") +endmacro() + +foreach(BACKEND IN LISTS TILELANG_BACKENDS) + tilelang_define_backend_option(${BACKEND}) +endforeach() + +set(PREBUILD_CYTHON ON) +# Configs end + +include(cmake/load_tvm.cmake) + +if(EXISTS ${TVM_SOURCE}/cmake/config.cmake) + include(${TVM_SOURCE}/cmake/config.cmake) +else() + message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.") +endif() +# Re-apply TileLang's preferred backend settings after TVM's config may have +# overridden the USE_* cache entries. +foreach(BACKEND IN LISTS TILELANG_BACKENDS) + set(_backend_var "USE_${BACKEND}") + set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}") + set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE) + set(${_backend_var} ${TILELANG_OPTION_${_backend_var}}) +endforeach() + +# Include directories for TileLang +set(TILE_LANG_INCLUDES ${TVM_INCLUDES}) + +# Collect source files +file(GLOB TILE_LANG_SRCS + src/*.cc + src/layout/*.cc + src/transform/*.cc + src/transform/common/*.cc + src/op/*.cc + src/target/utils.cc + src/target/codegen_c_host.cc + src/target/codegen_cpp.cc + src/target/rt_mod_cpp.cc + # intrin_rule doesn't have system dependency + src/target/intrin_rule*.cc +) + +# Always include CPU-safe runtime helpers +list(APPEND TILE_LANG_SRCS + src/runtime/error_helpers.cc +) + +# Track if the user explicitly selected a backend via cache options. +set(TILELANG_BACKEND_USER_SELECTED OFF) +foreach(BACKEND IN LISTS TILELANG_BACKENDS) + set(_backend_var "USE_${BACKEND}") + set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}") + if(${_backend_var} OR ${_override_var}) + set(TILELANG_BACKEND_USER_SELECTED ON) + endif() +endforeach() + +# Only auto-select a backend when the user didn't specify one explicitly. +if(NOT TILELANG_BACKEND_USER_SELECTED) + if($ENV{USE_METAL}) + set(USE_METAL ON) + elseif(APPLE) + message(STATUS "Enable Metal support by default.") + set(USE_METAL ON) + elseif($ENV{USE_ROCM}) + set(USE_ROCM ON) + else() + if($ENV{USE_CUDA}) + set(USE_CUDA ON) + elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA}) + # Build CPU-only when we explicitly disable CUDA + set(USE_CUDA OFF) + else() + message(STATUS "Enable CUDA support by default.") + set(USE_CUDA ON) + endif() + endif() +endif() + +if(USE_METAL) + file(GLOB TILE_LANG_METAL_SRCS + src/target/rt_mod_metal.cc + ) + list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS}) + # FIXME: CIBW failed with backtrace, why??? + set(TVM_FFI_USE_LIBBACKTRACE OFF) +elseif(USE_ROCM) + set(CMAKE_HIP_STANDARD 17) + include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake) + find_rocm(${USE_ROCM}) + add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1) + + file(GLOB TILE_LANG_HIP_SRCS + src/target/codegen_hip.cc + src/target/rt_mod_hip.cc + ) + list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS}) + list(APPEND TILE_LANG_INCLUDES ${ROCM_INCLUDE_DIRS}) +elseif(USE_CUDA) + set(CMAKE_CUDA_STANDARD 17) + find_package(CUDAToolkit REQUIRED) + set(CMAKE_CUDA_COMPILER "${CUDAToolkit_BIN_DIR}/nvcc") + add_compile_definitions("CUDA_MAJOR_VERSION=${CUDAToolkit_VERSION_MAJOR}") + + # Set `USE_CUDA=/usr/local/cuda-x.y` + cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) + + file(GLOB TILE_LANG_CUDA_SRCS + src/runtime/runtime.cc + src/target/ptx.cc + src/target/codegen_cuda.cc + src/target/codegen_py.cc + src/target/codegen_utils.cc + src/target/codegen_cutedsl.cc + src/target/rt_mod_cuda.cc + src/target/rt_mod_cutedsl.cc + ) + list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS}) + + list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS}) +endif() + +set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations") +set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package") + +if(USE_Z3 AND USE_PYPI_Z3) + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3") + find_package(Z3 REQUIRED) +endif() + +# Include tvm after configs have been populated +add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) + +# Resolve compile warnings in tvm +add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=) + +add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS}) + +# Set debug mode compile definitions +# We open the deubg option of TVM, i.e. TVM_LOG_DEBUG +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + message(STATUS "Building TileLang with DEBUG mode") + target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") +endif() + +target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES}) + +add_library(tilelang SHARED $) +add_library(tilelang_module SHARED $) +target_link_libraries(tilelang PUBLIC tvm_runtime tvm) +target_link_libraries(tilelang_module PUBLIC tvm) + +# Place dev build outputs under build/lib for consistency +set_target_properties(tilelang PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) +set_target_properties(tilelang_module PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) +# Build cython extension +find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) + +add_custom_command( + OUTPUT "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" + COMMENT + "Cythoning tilelang/jit/adapter/cython/cython_wrapper.pyx" + COMMAND Python::Interpreter -m cython + "${CMAKE_CURRENT_SOURCE_DIR}/tilelang/jit/adapter/cython/cython_wrapper.pyx" + --module-name tilelang_cython_wrapper + --cplus --output-file "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" + DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/tilelang/jit/adapter/cython/cython_wrapper.pyx" + VERBATIM) + +if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "") + set(USE_SABI USE_SABI ${SKBUILD_SABI_VERSION}) +endif() + +python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI) + +# Ensure dev builds drop the extension into build/lib alongside other shared libs +set_target_properties(tilelang_cython_wrapper PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" +) + +# Install the extension into tilelang/lib inside the wheel +install(TARGETS tilelang_cython_wrapper + LIBRARY DESTINATION tilelang/lib + RUNTIME DESTINATION tilelang/lib + ARCHIVE DESTINATION tilelang/lib) + +# add python z3 lib path to rpath for running tests and dev in current folder +if(USE_Z3 AND USE_PYPI_Z3) + set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/lib) + set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/bin) +endif() + +if(APPLE) + set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + if(USE_Z3 AND USE_PYPI_Z3) + # some z3 is placed in lib/ and some in bin/, we add both in rpath + list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin") + endif() +elseif(UNIX) + set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + if(USE_Z3 AND USE_PYPI_Z3) + # cmake uses ; by default, we explicitly use : for linux + string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib") + endif() +endif() + +# let libtilelang to search tvm/tvm_runtime in same dir +set_target_properties(tilelang PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") +set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") +set_target_properties(tvm PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") +set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") + +install( + TARGETS tvm tvm_runtime tilelang_module tilelang + LIBRARY DESTINATION tilelang/lib +) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..9e380d83170dd3876efcfdc6acc3678840273348 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,132 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socioeconomic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +[leiwang1999@outlook.com](mailto:leiwang1999@outlook.com) +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..7edc20f6a70ec6c8957d17a2ed93de339ff1e528 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,110 @@ +# Contributing + +That would be awesome if you want to contribute something to TileLang! + +### Table of Contents + +- [Report Bugs](#report-bugs) +- [Ask Questions](#ask-questions) +- [Submit Pull Requests](#submit-pull-requests) +- [Setup Development Environment](#setup-development-environment) +- [Install Develop Version](#install-develop-version) +- [Lint Check](#lint-check) +- [Test Locally](#test-locally) +- [Build Wheels](#build-wheels) +- [Documentation](#documentation) + +## Report Bugs + +If you run into any weird behavior while using TileLang, feel free to open a new issue in this repository! Please run a **search before opening** a new issue, to make sure that someone else hasn't already reported or solved the bug you've found. + +Any issue you open must include: + +- Code snippet that reproduces the bug with a minimal setup. +- A clear explanation of what the issue is. + +## Ask Questions + +Please ask questions in issues. + +## Submit Pull Requests + +All pull requests are super welcomed and greatly appreciated! Issues in need of a solution are marked with a [`โ™ฅ help`](https://github.com/ianstormtaylor/TileLang/issues?q=is%3Aissue+is%3Aopen+label%3A%22%E2%99%A5+help%22) label if you're looking for somewhere to start. + +If you're new to contributing to TileLang, you can follow the following guidelines before submitting a pull request. + +> [!NOTE] +> Please include tests and docs with every pull request if applicable! + +## Setup Development Environment + +Before contributing to TileLang, please follow the instructions below to setup. + +1. Fork TileLang ([fork](https://github.com/tile-ai/tilelang/fork)) on GitHub and clone the repository. + + ```bash + git clone --recurse-submodules git@github.com:/tilelang.git # use the SSH protocol + cd tilelang + + git remote add upstream git@github.com:tile-ai/tilelang.git + ``` + +2. Setup a development environment: + + ```bash + uv venv --seed .venv # use `python3 -m venv .venv` if you don't have `uv` + + source .venv/bin/activate + python3 -m pip install --upgrade pip setuptools wheel "build[uv]" + uv pip install --requirements requirements-dev.txt + ``` + +3. Setup the [`pre-commit`](https://pre-commit.com) hooks: + + ```bash + pre-commit install --install-hooks + ``` + +Then you are ready to rock. Thanks for contributing to TileLang! + +## Install Develop Version + +To install TileLang in an "editable" mode, run: + +```bash +python3 -m pip install --no-build-isolation --verbose --editable . +``` + +in the main directory. This installation is removable by: + +```bash +python3 -m pip uninstall tilelang +``` + +We also recommend installing TileLang in a more manual way for better control over the build process, by compiling the C++ extensions first and set the `PYTHONPATH`. See [Working from Source via `PYTHONPATH`](https://tilelang.com/get_started/Installation.html#working-from-source-via-pythonpath) for detailed instructions. + +## Lint Check + +To check the linting, run: + +```bash +pre-commit run --all-files +``` + +## Test Locally + +To run the tests, start by building the project as described in the [Setup Development Environment](#setup-development-environment) section. + +Then you can rerun the tests with: + +```bash +python3 -m pytest testing +``` + +## Build Wheels + +_TBA_ + +## Documentation + +_TBA_ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2122252e91032bd38a4fd7e51b38068b2e895912 --- /dev/null +++ b/LICENSE @@ -0,0 +1,23 @@ + MIT License + + Copyright (c) Tile-AI. + **During the period from December 1, 2024, to Mar 14, 2025, this project is + subject to additional collaboration terms with Microsoft Corporation.** + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/README.md b/README.md index 22454adfaab17b287dbf8fb74a5145d111c7e460..131c8c047402ddb679848fd00f09565b5539fe17 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,251 @@ -# tilelang + +
+ +# Tile Language +[![PyPI version](https://badge.fury.io/py/tilelang.svg)](https://badge.fury.io/py/tilelang) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tile-ai/tilelang) [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?logo=discord&logoColor=white)](https://discord.gg/TUrHyJnKPG) + +
+ +Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance. + + + +## Latest News +- 12/18/2025 ๐Ÿš€: Added [CuTeDSL backend](https://github.com/tile-ai/tilelang/pull/1421) support, enabling compilation to NVIDIA CUTLASS CuTe DSL! Join us in building and optimizing this exciting new backend: [Issue #1454](https://github.com/tile-ai/tilelang/issues/1454). +- 12/17/2025 ๐Ÿ”ฌ: Integrated [Z3 theorem prover](https://github.com/tile-ai/tilelang/pull/1367) into TVM Arith Analyzer, bringing SMT-based symbolic reasoning for enhanced optimizations and automatic correctness verification! +- 10/31/2025 ๐Ÿ”ง: Migrated to [apache-tvm-ffi](https://github.com/tile-ai/tilelang/pull/1108), significantly reducing CPU overhead! +- 10/30/2025 ๐Ÿ“ฆ: We have released v0.1.6.post2, which is the last version compatible with Python 3.8. +- 10/07/2025 ๐ŸŽ: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details. +- 09/29/2025 ๐ŸŽ‰: Thrilled to announce that โ€‹โ€‹AscendCโ€‹โ€‹ and โ€‹Ascendโ€‹NPU IRโ€‹โ€‹ backends targeting Huawei Ascend chips are now supported! +Check out the preview here: +๐Ÿ”— [link](https://github.com/tile-ai/tilelang-ascend). +This includes implementations across two branches: +[ascendc_pto](https://github.com/tile-ai/tilelang-ascend) and +[npuir](https://github.com/tile-ai/tilelang-ascend/tree/npuir). +Feel free to explore and share your feedback! +- 07/04/2025 ๐Ÿš€: Introduced `T.gemm_sp` for 2:4 sparse tensor core support, check out [Pull Request #526](https://github.com/tile-ai/tilelang/pull/526) for details. +- 06/05/2025 โœจ: Added [NVRTC Backend](https://github.com/tile-ai/tilelang/pull/461) to significantly reduce compilation time for cute templates! +- 04/14/2025 ๐Ÿš€: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See [example_mla_amd](./examples/deepseek_mla/amd/README.md) for details. +- 03/03/2025 ๐Ÿš€: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see [example_mla_decode.py](./examples/deepseek_mla/example_mla_decode.py))! We also provide [documentation](./examples/deepseek_mla/README.md) explaining how TileLang achieves this. +- 02/15/2025 โœจ: Added WebGPU Codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)! +- 02/12/2025 โœจ: Excited to announce the release of [v0.1.0](https://github.com/tile-ai/tilelang/releases/tag/v0.1.0)! +- 02/10/2025 ๐Ÿš€: Added debug tools for TileLangโ€”`T.print` for printing variables/buffers ([docs](https://tilelang.com/tutorials/debug_tools_for_tilelang.html)) and a memory layout plotter ([examples/plot_layout](./examples/plot_layout)). +- 01/20/2025 โœจ: We are excited to announce that tile-lang, a dsl for high performance AI workloads, is now open source and available to the public! + +## Tested Devices +Although tile-lang aims to be portable across a range of Devices, it has been specifically tested and validated on the following devices: for NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000; for AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support). + +## OP Implementation Examples +**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include: + +- [Matrix Multiplication](./examples/gemm/) +- [Dequantization GEMM](./examples/dequantize_gemm/) +- [Flash Attention](./examples/flash_attention/) +- [Flash Linear Attention](./examples/linear_attention/) +- [Flash MLA Decoding](./examples/deepseek_mla/) +- [Native Sparse Attention](./examples/deepseek_nsa/) + +Within the `examples` directory, you will also find additional complex kernelsโ€”such as convolutions, forward/backward passes for FlashAttention, more operators will continuously be added. + + +## Benchmark Summary + +TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at [tilelang-benchmark](https://github.com/tile-ai/tilelang-benchmark). Below are selected results showcasing its capabilities: + +- MLA Decoding Performance on H100 + +
+
+ mla decode performance bs64 on H100 +
+
+ mla decode performance bs128 on H100 +
+
+ +- Flash Attention Performance on H100 + +
operator performance on H100 +
+ +- Matmul Performance on GPUs (RTX 4090, A100, H100, MI300X) + +
+ gemm fp16 performance on Gpus +
+ +- Dequantize Matmul Performance on A100 + +
+ dequantize gemv performance on A100 +
+ +## Installation +### Method 1: Install with Pip + +The quickest way to get started is to install the latest release from PyPI: + +```bash +pip install tilelang +``` + +Alternatively, you can install directly from the GitHub repository: + +```bash +pip install git+https://github.com/tile-ai/tilelang +``` + +Or install locally: + +```bash +# install required system dependencies +sudo apt-get update +sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +pip install -e . -v # remove -e option if you don't want to install in editable mode, -v for verbose output +``` + +### Method 2: Build from Source +We currently provide three ways to install **tile-lang** from source: + - [Install from Source (using your own TVM installation)](./docs/get_started/Installation.md#method-1-install-from-source-using-your-own-tvm-installation) + - [Install from Source (using the bundled TVM submodule)](./docs/get_started/Installation.md#method-2-install-from-source-using-the-bundled-tvm-submodule) + - [Install Using the Provided Script](./docs/get_started/Installation.md#method-3-install-using-the-provided-script) + +### Method 3: Install with Nightly Version + +For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**. + +```bash +pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/ +# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/ +``` + +> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet. + +## Quick Start + +In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cacheโ€“friendly swizzling. + +### GEMM Example with Annotations (Layout, L2 Cache Swizzling, and Pipelining, etc.) + +Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware. + +```python +import tilelang +import tilelang.language as T + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): + + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 1024 # M = T.dynamic("m") if you want to use dynamic shape +N = 1024 +K = 1024 +block_M = 128 +block_N = 128 +block_K = 32 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) + +# 3. Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = matmul_relu_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") +``` + +### Dive Deep into TileLang Beyond GEMM + +In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including: + +- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilizing magic layout transformation and intrins to accelerate dequantize gemm. +- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax, and we also provide an example of auto tuning. +- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations. +- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col. + +## Upcoming Features + +Check our [tilelang v0.2.0 release plan](https://github.com/tile-ai/tilelang/issues/79) for upcoming features. + +--- + +TileLang has now been used in project [BitBLAS](https://github.com/microsoft/BitBLAS) and [AttentionEngine](https://github.com/microsoft/AttentionEngine). + +## Join the Discussion + +Welcome to join our Discord community for discussions, support, and collaboration! + +[![Join our Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?logo=discord&style=for-the-badge)](https://discord.gg/TUrHyJnKPG) + +## Acknowledgments + +We would like to express our gratitude to the [TVM](https://github.com/apache/tvm) community for their invaluable contributions. The initial version of this project was mainly developed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) with supervision from Prof. [Zhi Yang](https://yangzhihome.github.io) at Peking University. Part of this work was carried out during an internship at Microsoft Research, where Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang offered valuable advice and support. We deeply appreciate their mentorship and contributions. diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt new file mode 100644 index 0000000000000000000000000000000000000000..b7c48184117f3f601c6dc28f4c384689a0cc2f5f --- /dev/null +++ b/THIRDPARTYNOTICES.txt @@ -0,0 +1,616 @@ +BitBLAS uses third-party material as listed below. The attached notices are +provided for informational purposes only. + +Notice for apache/tvm +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +Notice for IST-DASLab/marlin/ +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +------------------------------------------------------------------------------------ + +Notice for flashinfer-ai/flashinfer +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..11808190d4b90b20fe074a2dad43af6c0c1427ee --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.1.7 diff --git a/benchmark/blocksparse_attention/benchmark_configs.py b/benchmark/blocksparse_attention/benchmark_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a23e2136ad722b91214afe495596cd198f81f6bf --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_configs.py @@ -0,0 +1,2 @@ +# BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK +configs = [[4, 2, 256, 64, 2, 64]] diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd82aa5e5218cf37379ed69a2ff93ba1020c199 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -0,0 +1,54 @@ +# ruff: noqa +import torch +from tilelang.profiler import do_bench + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + import flash_attn + + def benchmark_fn(): + flash_attn.flash_attn_func(q, k, v, causal=True) + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py new file mode 100644 index 0000000000000000000000000000000000000000..e645ae1475bd7bc89f6b08a5c110bcd459d9e3d5 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -0,0 +1,208 @@ +# ruff: noqa +import math +import torch + +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench + + +def is_hip(): + return False + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 2 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k]: + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = tilelang.compile(program, out_idx=4) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + kernel(q, k, v, block_mask) + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py new file mode 100644 index 0000000000000000000000000000000000000000..85d754ae3a77f679a9d714eb1c2f83573a47c911 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -0,0 +1,75 @@ +# ruff: noqa +import math +import torch + +import torch.nn.functional as F +from tilelang.profiler import do_bench + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + return ref_output + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebca93a6a3735ac0fae1cdcd706587d3521e5b6 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -0,0 +1,291 @@ +# ruff: noqa +import math +import torch + +import triton +import triton.language as tl +from tilelang.profiler import do_bench + + +def is_hip(): + return False + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + seqlen_k, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK: + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + N_CTX, + PAST_LEN, + col_idx == k_block_end - 1, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. + raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + block_sparse_triton_fn(q, k, v, block_mask, sm_scale) # noqa: B023 + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/requirements.txt b/benchmark/blocksparse_attention/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0f50993616a1344cdab84f7851406605a0089b9 --- /dev/null +++ b/benchmark/blocksparse_attention/requirements.txt @@ -0,0 +1 @@ +flash-attn diff --git a/benchmark/mamba2/README.md b/benchmark/mamba2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0b6de19b1e9df2fea1f5ce6a2ae5ac2cb032e789 --- /dev/null +++ b/benchmark/mamba2/README.md @@ -0,0 +1,59 @@ +# Mamba2_chunk_scan Benchmark + +This document records the throughput achieved by `benchmark_mamba_chunk_scan.py` when computing `batch = 8`, `heads = 80`, `groups = 1`, `chunk_size = 256`, `dim = 64`, and `dstate = 128` across different `seq_len` using the default autotuning search space. + +## Environment + +- Repository commit: `8a5eb569704bfea64478c29adcfe3a09e3c2b12c` +- GPUs: `NVIDIA H800 SXM` on driver `560.35.05` + +## How to Reproduce + +```bash +cd benchmark/mamba2 +python - <<'PY' +from benchmark_mamba_chunk_scan import chunk_scan_fwd + +batch = 8 +heads = 80 +groups = 1 +chunk_size = 256 +dim = 64 +dstate = 128 +for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]: + res = chunk_scan_fwd( + batch, + seq_len, + chunk_size, + groups, + heads, + dim, + dstate) + tflops = (2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate) / res.latency * 1e-9 + print(f"seq_len={seq_len:5d} latency={res.latency:.6f}ms TFlops={tflops:.3f}") +PY +``` + +## Results + +| Seq_len| Latency (ms) | Throughput (TFLOPs) | +|-------|-------------|---------------------| +| 1024 | 0.169 | 126.477 | +| 2048 | 0.329 | 130.195 | +| 4096 | 0.645 | 133.054 | +| 8192 | 1.278 | 134.362 | +| 16384 | 2.531 | 135.711 | +| 32768 | 5.076 | 135.379 | + + +## Compare with Baselines + +- Triton: v3.5.0, mamba-ssm: v2.2.6.post3 +- Helion: v0.2.1 + +
+ + Mamba2_chunk_scan Performance Comparison on H100 + +
Performance comparison across compilers on NVIDIA H100
+
\ No newline at end of file diff --git a/benchmark/mamba2/benchmark_mamba_chunk_scan.py b/benchmark/mamba2/benchmark_mamba_chunk_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f5cec670a40aad2e62d78847613277c8827648 --- /dev/null +++ b/benchmark/mamba2/benchmark_mamba_chunk_scan.py @@ -0,0 +1,389 @@ +import argparse +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, repeat +import itertools +import math +from tilelang.profiler import do_bench + +try: + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd +except ImportError as err: + raise ImportError("Please install mamba-ssm to use the triton chunk scan operator.") from err + +try: + import helion + from helion._testing import run_example + import helion.language as hl +except ImportError as err: + raise ImportError("Please install helion to use the helion chunk scan operator.") from err + + +def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + _, _, ngroups, _, _ = cb.shape + batch, seqlen, nheads, headdim = x.shape + # _, _, ngroups, dstate = B.shape + # assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + # assert C.shape == B.shape + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups) + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out + + +def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) + return out + + +def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): + @helion.kernel() + def helion_mamba2_chunk_scan_kernel( + cb: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + dA_cumsum: torch.Tensor, + C: torch.Tensor, + prev_states: torch.Tensor, + D: torch.Tensor, + ) -> torch.Tensor: + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads,) + Return: + out: (batch, seqlen, nheads, headdim) + """ + + batch, nchunks, ngroups, chunk_size, _ = cb.shape + _, seqlen, nheads, headdim = x.shape + _, _, _, dstate = C.shape + assert nchunks == (seqlen + chunk_size - 1) // chunk_size + + block_m = hl.register_block_size(chunk_size) + block_n = hl.register_block_size(headdim) + block_k = hl.register_block_size(64, 64) + dstate = hl.specialize(dstate) + + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert C.shape == (batch, seqlen, ngroups, dstate) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert D.shape == (nheads,) + + dtype = cb.dtype + accum_dtype = torch.float32 + assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype + + out = torch.empty_like(x) + + p = 1.44269504 + + for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( + [nheads, chunk_size, headdim, batch, nchunks], + block_size=[1, block_m, block_n, 1, 1], + ): + acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) + dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32) + scale_m_local = torch.exp2(dA_cumsum_local_m * p) + + C_local = C[ + tile_b.begin, + tile_m.index + tile_c.begin * chunk_size, + tile_h.begin // (nheads // ngroups), + :, + ] + prev_states_local = prev_states[tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :] + acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o) + acc_o *= scale_m_local[:, None] + + for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k): + cb_local = cb[ + tile_b.begin, + tile_c.begin, + tile_h.begin // (nheads // ngroups), + tile_m, + tile_k, + ] + dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p) + dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) + cb_local = (cb_local * dt_local[None, :]).to(dtype) + pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] + cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local)) + x_local = x[ + tile_b.begin, + tile_c.begin * chunk_size + tile_k.index, + tile_h.begin, + tile_n, + ] + acc_o = hl.dot(cb_local, x_local, acc=acc_o) + + D_local = D[tile_h.begin].to(torch.float32) + x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32) + acc_o += x_residual * D_local + out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype) + + return out + + args = (cb, x, dt, dA_cumsum, C, states, D) + run_example(helion_mamba2_chunk_scan_kernel, ref_program, args) + + +def get_configs(): + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[7], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): + dtype = T.float16 + accum_dtype = T.float32 + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + + @T.prim_func + def main( + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + ): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_local = T.alloc_fragment((block_M, block_K), dtype) + dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_local = T.alloc_fragment((block_K), accum_dtype) + x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") + dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + scale_m_local = T.alloc_fragment((block_M), accum_dtype) + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) + D_local = T.alloc_fragment((1), accum_dtype) + x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + batch_idx = by % batch + chunk_idx = by // batch + # m: chunk_size + # n : headdim + m_idx = bx // T.ceildiv(headdim, block_N) + n_idx = bx % T.ceildiv(headdim, block_N) + + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) + + T.no_set_max_nreg() + + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) + T.copy(dA_cs_m_shared, dA_cs_m_local) + T.clear(acc_o) + + for i in T.Parallel(block_M): + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) + T.copy( + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] *= scale_m_local[i] + + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) + T.copy(cb_shared, cb_local) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) + T.copy(dA_cs_k_shared, dA_cs_k_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] *= dt_local[j] + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) + T.gemm(cb_local, x_shared, acc_o) + + D_local[0] = D[bz] + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) + T.copy(x_residual_shared, x_residual_local) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] += x_residual_local[i, j] * D_local[0] + + T.copy(acc_o, acc_o_shared) + T.copy( + acc_o_shared, + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) + + return main + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) + nchunks = math.ceil(seq_len / chunk_size) + total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate + + print("Benchmarking TileLang...") + kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + + cb = torch.randn(batch, nchunks, groups, chunk_size, chunk_size).half().cuda() + x = torch.randn(batch, seq_len, heads, dim).half().cuda() + dt = torch.randn(batch, heads, nchunks, chunk_size).half().cuda() + dA_cumsum = torch.randn(batch, heads, nchunks, chunk_size).half().cuda() + C = torch.randn(batch, seq_len, groups, dstate).half().cuda() + states = torch.randn(batch, nchunks, heads, dim, dstate).half().cuda() + D = torch.randn(heads).half().cuda() + + print("Benchmarking Triton...") + triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10) + print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}") + + print("Benchmarking Helion...") + chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D) diff --git a/benchmark/mamba2/mamba_benchmark_result.png b/benchmark/mamba2/mamba_benchmark_result.png new file mode 100644 index 0000000000000000000000000000000000000000..6915508c25d3e212d3e06188bd54f6eb16a3df95 Binary files /dev/null and b/benchmark/mamba2/mamba_benchmark_result.png differ diff --git a/benchmark/matmul/README.md b/benchmark/matmul/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3ecafa5de48866bc00c4b67b746698fd5af5df17 --- /dev/null +++ b/benchmark/matmul/README.md @@ -0,0 +1,36 @@ +# FP16 Matmul Benchmark (8192ร—8192) + +This document records the throughput achieved by `benchmark_matmul.py` when multiplying FP16 matrices sized `M = N = 8192` across different `K` dimensions using the default autotuning search space. + +## Environment + +- Repository commit: `17bd0a6c651f599bec1397e0b91830c3ddc93076` +- GPUs: `NVIDIA H800 SXM` on driver `560.35.05` + +## How to Reproduce + +```bash +cd benchmark/matmul +python - <<'PY' +from benchmark_matmul import matmul + +M = 8192 +N = 8192 +for K in [256, 512, 1024, 2048, 4096, 8192, 16384]: + res = matmul(M, N, K, False) + tflops = 2 * M * N * K / res.latency * 1e-12 + print(f"K={K:5d} latency={res.latency:.6f}s TFlops={tflops:.3f}") +PY +``` + +## Results + +| K | Latency (s) | Throughput (TFLOPs) | +|-------|-------------|---------------------| +| 256 | 0.089056 | 386 | +| 512 | 0.132064 | 520 | +| 1024 | 0.218816 | 628 | +| 2048 | 0.390112 | 705 | +| 4096 | 0.746752 | 736 | +| 8192 | 1.449888 | 758 | +| 16384 | 2.871168 | 766 | diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..643c1fd5e9661634a48068f6a313433ea257a179 --- /dev/null +++ b/benchmark/matmul/benchmark_matmul.py @@ -0,0 +1,250 @@ +import argparse +import itertools +import logging + +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune +from tilelang import jit + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def ref_program(A, B): + """ + A reference matrix multiplication program, used to compare performance. + + Parameters + ---------- + A : numpy.ndarray + The matrix with shape (M, K). + B : numpy.ndarray + The matrix with shape (N, K). + + Returns + ------- + np.ndarray + The result of A @ B.T, shape (M, N). + """ + return A @ B.T + + +def get_configs(args, kwargs): + """ + Generate a list of configuration dictionaries that will be used for tuning. + + Parameters + ---------- + with_roller : bool + Whether to enable bitblas roller to deduce search spaces + + Returns + ------- + list of dict + Each configuration dict includes various block sizes, pipeline stages, + thread numbers, and other parameters to explore during autotuning. + """ + M, N, K, with_roller = args[:4] + + if with_roller: + from tilelang.carver.template import MatmulTemplate + from tilelang.carver.arch import CUDA + from tilelang.carver.arch import CDNA + from tilelang.carver.roller.rasterization import NoRasterization + import torch + + arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + topk = 10 + + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + roller_hints = carve_template.recommend_hints(topk=topk) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + # block_rows, block_cols represents warp partitioning + block_rows, block_cols = block_m // warp_m, block_n // warp_n + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage + config["thread_num"] = block_rows * block_cols * 32 + config["policy"] = T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols) + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + for config in configs: + print(config) + else: + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[32, 64], + num_stages=[0, 1, 2, 3], + thread_num=[128, 256], + policy=[T.GemmWarpPolicy.Square], + enable_rasteration=[True, False], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + return configs + + +@autotune( + configs=get_configs, + warmup=3, + rep=20, +) +@jit( + out_idx=[2], +) +def matmul( + M, + N, + K, + with_roller, + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + policy=None, + enable_rasteration=None, +): + """ + Create an autotuned matrix multiplication kernel for matrices of shape: + - A: (M, K) + - B: (N, K) + - C: (M, N) + + Parameters + ---------- + M : int + The dimension M of the matrix multiplication. + N : int + The dimension N of the matrix multiplication. + K : int + The dimension K of the matrix multiplication. + + Returns + ------- + (best_latency, best_config, ref_latency) + best_latency : float + The best latency found among the tuned configurations. + best_config : dict + The parameter configuration that yielded best_latency. + ref_latency : float + The baseline latency of the reference program (for computing speedup). + """ + + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Allocate a shared memory for C sub-block of shape (block_M, block_N) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasteration) + # to utilize swizzle tma layout + T.annotate_layout({C_shared: tilelang.layout.make_swizzled_layout(C_shared)}) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy(A[by * block_M, k * block_K], A_shared) + # Load a sub-block of B from global memory into B_shared + T.copy(B[bx * block_N, k * block_K], B_shared) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + policy=policy, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +if __name__ == "__main__": + # Parse command-line arguments for matrix dimensions + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument( + "--with_roller", + action="store_true", + help="Whether to enable BitBLAS roller for search space", + ) + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + with_roller = args.with_roller + + # Compute total floating-point operations to measure throughput + total_flops = 2 * M * N * K + + # matmul(...) returns (best_latency, best_config, ref_latency) + best_result = matmul(M, N, K, with_roller) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + + # Print out the benchmark results + print(f"Best latency (s): {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") + print(f"Best config: {best_config}") + + if ref_latency is not None: + print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef860c21039b2b3288a43f5d09bae9ed30e7a36 --- /dev/null +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -0,0 +1,316 @@ +import argparse +import logging +from tilelang import tvm as tvm +from tvm import DataType +import tilelang as tl +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang.autotuner import autotune +import itertools + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + stage=2, + enable_rasteration=False, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + # chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M, + block_N, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10, enable=enable_rasteration) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Load B into fragment + mma_emitter.ldmatrix_b(B_local, B_shared, ki) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix(C_local, C_shared) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[i, j] + + return main + + +def ref_program(A, B): + """Reference matrix multiplication program.""" + return A @ B.T + + +def get_configs(args, kwargs): + """ + Generate a list of configuration dictionaries that will be used for tuning. + + Parameters + ---------- + with_roller : bool + Whether to enable bitblas roller to deduce search spaces + + Returns + ------- + list of dict + Each configuration dict includes various block sizes, pipeline stages, + thread numbers, and other parameters to explore during autotuning. + """ + M, N, K = args[:3] + with_roller = args[6] + + if with_roller: + from tilelang.carver.template import MatmulTemplate + from tilelang.carver.arch import CUDA + from tilelang.carver.arch import CDNA + from tilelang.carver.roller.rasterization import NoRasterization + import torch + + arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + topk = 10 + + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + roller_hints = carve_template.recommend_hints(topk=topk) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + config["block_row_warps"] = block_m // warp_m + config["block_col_warps"] = block_n // warp_n + config["warp_row_tiles"] = warp_m + config["warp_col_tiles"] = warp_n + config["chunk"] = hint.rstep[0] + config["stage"] = hint.pipeline_stage + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + for config in configs: + print(config) + else: + iter_params = dict( + block_row_warps=[1, 2, 4], + block_col_warps=[1, 2, 4], + warp_row_tiles=[16, 32, 64, 128], + warp_col_tiles=[16, 32, 64, 128], + chunk=[32, 64, 128, 256], + stage=[0, 2], + enable_rasteration=[True, False], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + return configs + + +@autotune( + configs=get_configs, + warmup=3, + rep=5, + ref_prog=ref_program, + skip_check=True, +) +@tl.jit( + out_idx=[2], +) +def matmul( + M, + N, + K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, + with_roller=False, + block_row_warps=None, + block_col_warps=None, + warp_row_tiles=None, + warp_col_tiles=None, + chunk=None, + stage=None, + enable_rasteration=None, +): + """Create an autotuned tensor core matrix multiplication kernel.""" + + def kernel(): + return tl_matmul( + M, + N, + K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + stage=stage, + enable_rasteration=enable_rasteration, + ) + + return kernel() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned TensorCore MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") + parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + in_dtype = T.dtype(args.dtype) + out_dtype = T.float32 if in_dtype == T.int8 else T.float16 + accum_dtype = T.float32 if in_dtype == T.int8 else T.float16 + with_roller = args.with_roller + with_roller = True + # Compute total floating-point operations + total_flops = 2 * M * N * K + + # Run autotuning + best_result = matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_roller) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + + # Print benchmark results + print(f"Best latency (s): {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") + print(f"Best config: {best_config}") + print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..7ecffc26a28b9d171e70fcd1b25395708310be2e --- /dev/null +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -0,0 +1,288 @@ +import argparse +import itertools +import logging +import torch +from triton.testing import do_bench + +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune +from tilelang import jit +from tilelang.contrib import nvcc +from tilelang.layout import make_cutlass_metadata_layout + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +arch = nvcc.get_target_compute_version() + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +def ref_program(A, B): + """ + A reference matrix multiplication program, used to compare performance. + + Parameters + ---------- + A : numpy.ndarray + The matrix with shape (M, K). + B : numpy.ndarray + The matrix with shape (N, K). + + Returns + ------- + np.ndarray + The result of A @ B.T, shape (M, N). + """ + return A @ B.T + + +def get_configs(M, N, K): + """ + Generate a list of configuration dictionaries that will be used for tuning. + + Parameters + ---------- + with_roller : bool + Whether to enable bitblas roller to deduce search spaces + + Returns + ------- + list of dict + Each configuration dict includes various block sizes, pipeline stages, + thread numbers, and other parameters to explore during autotuning. + """ + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [64, 128] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + policy = [T.GemmWarpPolicy.Square] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + policy, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "policy": c[5], + "enable_rasterization": c[6], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def matmul_sp(M, N, K, in_dtype, accum_dtype): + """ + Create an autotuned matrix multiplication kernel for matrices of shape: + - A: (M, K) + - B: (K, N) + - C: (M, N) + + Parameters + ---------- + M : int + The dimension M of the matrix multiplication. + N : int + The dimension N of the matrix multiplication. + K : int + The dimension K of the matrix multiplication. + + Returns + ------- + (best_latency, best_config, ref_latency) + best_latency : float + The best latency found among the tuned configurations. + best_config : dict + The parameter configuration that yielded best_latency. + ref_latency : float + The baseline latency of the reference program (for computing speedup). + """ + + # Decorate the kernel with autotune & jit, specifying: + # - Tuning config list + # - Profiling keys + # - Warmup and repetition counts for better measurement + # - A reference program for correctness verification + # - The "tvm" profiler backend + # - HIP as the compilation target (modify as needed for your hardware) + + @autotune( + configs=get_configs(M, N, K), + warmup=3, + rep=20, + ) + @jit( + out_idx=[2], + ) + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + policy=None, + enable_rasterization=None, + ): + """ + The actual kernel to compute C = A @ B^T. + + Parameters + ---------- + block_M : int + Block size in M dimension. + block_N : int + Block size in N dimension. + block_K : int + Block size in K dimension. + num_stages : int + Number of pipelined stages (for asynchronous load). + thread_num : int + Number of threads to use per block. + k_pack : int + K dimension packing factor to improve memory coalescing. + + Returns + ------- + Function + A TVM Tensor Language function (T.prim_func) that computes matmul. + """ + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + e_factor, e_dtype = ARCH_INFO[arch] + + @T.prim_func + def main( + A_sparse: T.Tensor((M, K // 2), in_dtype), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), in_dtype), + C: T.Tensor((M, N), accum_dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_K, block_N), in_dtype) + # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Allocate a shared memory for C sub-block of shape (block_M, block_N) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + + # Clear out the accumulation buffer + T.clear(C_local) + T.disable_warp_group_reg_alloc() + + T.use_swizzle(panel_size=10, enable=enable_rasterization) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), + } + ) + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + # Load a sub-block of E from global memory into E_shared + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + # Load a sub-block of B from global memory into B_shared + T.copy(B[k * block_K, bx * block_N], B_shared) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared + T.gemm_sp_v2( + A_shared, + E_shared, + B_shared, + C_local, + transpose_B=False, + policy=policy, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + return kernel() + + +if __name__ == "__main__": + # Parse command-line arguments for matrix dimensions + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--disable_cache", action="store_true") + parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument( + "--bench_torch_sparse", + type=str, + choices=["cutlass", "cusparselt"], + default=None, + help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported", + ) + args = parser.parse_args() + + if args.disable_cache: + tilelang.disable_cache() + + M, N, K = args.m, args.n, args.k + + # Compute total floating-point operations to measure throughput + total_flops = 2 * M * N * K + + # matmul(...) returns (best_latency, best_config, ref_latency) + best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype) + best_latency = best_result.latency + best_config = best_result.config + A = torch.randn(M, K, dtype=torch.float16, device="cuda") + B = torch.randn(K, N, dtype=torch.float16, device="cuda") + ref_latency = do_bench(lambda: A @ B) + + if args.bench_torch_sparse is not None: + from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor + + if args.bench_torch_sparse == "cutlass": + SparseSemiStructuredTensor._FORCE_CUTLASS = True + A_sp = to_sparse_semi_structured(A, transposed=False) + torch_sparse_latency = do_bench(lambda: A_sp @ B) + + # Print out the benchmark results + print(f"Best latency (s): {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") + print(f"Best config: {best_config}") + + if args.bench_torch_sparse is not None: + print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}") + + print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/benchmark/matmul_fp8/README.md b/benchmark/matmul_fp8/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fe181e2b3105d5e596185340f887afed67cae6f4 --- /dev/null +++ b/benchmark/matmul_fp8/README.md @@ -0,0 +1,36 @@ +# FP8 Matmul Benchmark (8192ร—8192) + +This document records the throughput achieved by `benchmark_matmul.py` when multiplying FP8 matrices sized `M = N = 8192` across different `K` dimensions. Each measurement relies on the default autotuning search space bundled with the benchmark. + +## Environment + +- Repository commit: `6b1faf71faf18c564f5f77e0f5c1671cd91dfbc3` +- GPUs: `NVIDIA H800 SXM` on driver `560.35.05` + +## How to Reproduce + +```bash +cd benchmark/matmul_fp8 +python - <<'PY' +from benchmark_matmul import matmul + +M = 8192 +N = 8192 +for K in [256, 512, 1024, 2048, 4096, 8192, 16384]: + res = matmul(M, N, K, False) + tflops = 2 * M * N * K / res.latency * 1e-12 + print(f"K={K:5d} latency={res.latency:.6f}s TFlops={tflops:.3f}") +PY +``` + +## Results + +| K | Latency (s) | Throughput (TFLOPs) | +|-------|-------------|---------------------| +| 256 | 0.060352 | 569 | +| 512 | 0.080096 | 858 | +| 1024 | 0.121696 | 1129 | +| 2048 | 0.204672 | 1343 | +| 4096 | 0.374816 | 1467 | +| 8192 | 0.729664 | 1507 | +| 16384 | 1.427264 | 1541 | diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e62812fc53cafb11e6b7d433a1cb61c37076f8 --- /dev/null +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -0,0 +1,251 @@ +import argparse +import itertools +import torch +import logging +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune +from tilelang import jit + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def ref_program(A, B): + """ + A reference matrix multiplication program, used to compare performance. + + Parameters + ---------- + A : numpy.ndarray + The matrix with shape (M, K). + B : numpy.ndarray + The matrix with shape (N, K). + + Returns + ------- + np.ndarray + The result of A @ B.T, shape (M, N). + """ + return A.float() @ B.T.float() + + +def get_configs(args, kwargs): + """ + Generate a list of configuration dictionaries that will be used for tuning. + + Parameters + ---------- + with_roller : bool + Whether to enable bitblas roller to deduce search spaces + + Returns + ------- + list of dict + Each configuration dict includes various block sizes, pipeline stages, + thread numbers, and other parameters to explore during autotuning. + """ + M, N, K, with_roller = args[:4] + + if with_roller: + from tilelang.carver.template import MatmulTemplate + from tilelang.carver.arch import CUDA + from tilelang.carver.arch import CDNA + from tilelang.carver.roller.rasterization import NoRasterization + import torch + + arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda") + + topk = 10 + + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + roller_hints = carve_template.recommend_hints(topk=topk) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + # block_rows, block_cols represents warp partitioning + block_rows, block_cols = block_m // warp_m, block_n // warp_n + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage + config["thread_num"] = block_rows * block_cols * 32 + config["policy"] = T.GemmWarpPolicy.from_warp_partition(block_rows, block_cols) + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + for config in configs: + print(config) + else: + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[64, 128], + num_stages=[0, 1, 2, 3], + thread_num=[128, 256], + k_pack=[1, 2], + policy=[T.GemmWarpPolicy.Square], + enable_rasteration=[True, False], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + return configs + + +@autotune( + configs=get_configs, + warmup=3, + rep=20, +) +@jit( + out_idx=[2], +) +def matmul( + M, + N, + K, + with_roller, + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + k_pack=None, + policy=None, + enable_rasteration=None, +): + """ + Create an autotuned matrix multiplication kernel for matrices of shape: + - A: (M, K) + - B: (N, K) + - C: (M, N) + + Parameters + ---------- + M : int + The dimension M of the matrix multiplication. + N : int + The dimension N of the matrix multiplication. + K : int + The dimension K of the matrix multiplication. + + Returns + ------- + (best_latency, best_config, ref_latency) + best_latency : float + The best latency found among the tuned configurations. + best_config : dict + The parameter configuration that yielded best_latency. + ref_latency : float + The baseline latency of the reference program (for computing speedup). + """ + + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Allocate a shared memory for C sub-block of shape (block_M, block_N) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasteration) + # to utilize swizzle tma layout + T.annotate_layout({C_shared: tilelang.layout.make_swizzled_layout(C_shared)}) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy(A[by * block_M, k * block_K], A_shared) + # Load a sub-block of B from global memory into B_shared + T.copy(B[bx * block_N, k * block_K], B_shared) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + policy=policy, + k_pack=k_pack, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +if __name__ == "__main__": + # Parse command-line arguments for matrix dimensions + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument( + "--with_roller", + action="store_true", + help="Whether to enable BitBLAS roller for search space", + ) + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + with_roller = args.with_roller + + # Compute total floating-point operations to measure throughput + total_flops = 2 * M * N * K + + # matmul(...) returns (best_latency, best_config, ref_latency) + best_result = matmul(M, N, K, with_roller) + best_latency = best_result.latency + best_config = best_result.config + + # Print out the benchmark results + print(f"Best latency (s): {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") + print(f"Best config: {best_config}") diff --git a/cmake/load_tvm.cmake b/cmake/load_tvm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..cb21be95f649174d829778855a4849ae1d87a006 --- /dev/null +++ b/cmake/load_tvm.cmake @@ -0,0 +1,30 @@ +# todo: support prebuilt tvm + +set(TVM_BUILD_FROM_SOURCE TRUE) +set(TVM_SOURCE ${CMAKE_SOURCE_DIR}/3rdparty/tvm) + +if(DEFINED ENV{TVM_ROOT}) + if(EXISTS $ENV{TVM_ROOT}/cmake/config.cmake) + set(TVM_SOURCE $ENV{TVM_ROOT}) + message(STATUS "Using TVM_ROOT from environment variable: ${TVM_SOURCE}") + endif() +endif() + +message(STATUS "Using TVM source: ${TVM_SOURCE}") + +set(TVM_INCLUDES + ${TVM_SOURCE}/include + ${TVM_SOURCE}/src + ${TVM_SOURCE}/3rdparty/dlpack/include + ${TVM_SOURCE}/3rdparty/dmlc-core/include +) + +if(EXISTS ${TVM_SOURCE}/ffi/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include) +elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include) +endif() + +if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include) +endif() diff --git a/cmake/pypi-z3/FindZ3.cmake b/cmake/pypi-z3/FindZ3.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d7920f8f9ccc66f300ac2b1f92b5d8044e457839 --- /dev/null +++ b/cmake/pypi-z3/FindZ3.cmake @@ -0,0 +1,30 @@ +if(Z3_FOUND) + return() +endif() +find_package(Python3 COMPONENTS Interpreter REQUIRED) +execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])" + OUTPUT_VARIABLE Z3_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE Z3_PYTHON_RESULT +) +if(NOT Z3_PYTHON_RESULT EQUAL 0 OR Z3_PATH STREQUAL "") + message(FATAL_ERROR "Failed to locate z3 Python package. Ensure z3-solver>=4.13.0 is installed.") +endif() +message("-- Find Z3 in path: ${Z3_PATH}") +find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS ${Z3_PATH}/include) +find_library(Z3_LIBRARY NO_DEFAULT_PATH NAMES z3 libz3 PATHS ${Z3_PATH}/bin ${Z3_PATH}/lib ${Z3_PATH}/lib64) +message("-- Found Z3 include dir: ${Z3_INCLUDE_DIR}") +message("-- Found Z3 library: ${Z3_LIBRARY}") +add_library(z3::libz3 SHARED IMPORTED GLOBAL) +set_target_properties(z3::libz3 + PROPERTIES + IMPORTED_LOCATION ${Z3_LIBRARY} + INTERFACE_INCLUDE_DIRECTORIES ${Z3_INCLUDE_DIR} +) +if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY) + message(FATAL_ERROR "Could not find Z3 library or include directory") +endif() +set(Z3_CXX_INCLUDE_DIRS ${Z3_INCLUDE_DIR}) +set(Z3_C_INCLUDE_DIRS ${Z3_INCLUDE_DIR}) +set(Z3_FOUND TRUE) diff --git a/docker/Dockerfile.cu118 b/docker/Dockerfile.cu118 new file mode 100644 index 0000000000000000000000000000000000000000..be8274461a40e1803580ca7f6524feea31e65ada --- /dev/null +++ b/docker/Dockerfile.cu118 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:22.12-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu120 b/docker/Dockerfile.cu120 new file mode 100644 index 0000000000000000000000000000000000000000..7ca1d931fcde2f8a99bf1f0c89a21f158c67d0ee --- /dev/null +++ b/docker/Dockerfile.cu120 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:23.01-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu121 b/docker/Dockerfile.cu121 new file mode 100644 index 0000000000000000000000000000000000000000..f91029d751b7cee0a09e9d0c5d2de5ea353d48a1 --- /dev/null +++ b/docker/Dockerfile.cu121 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:23.07-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu123 b/docker/Dockerfile.cu123 new file mode 100644 index 0000000000000000000000000000000000000000..b3d1217fdd134a3f883f64b19e6fc9987fee56ee --- /dev/null +++ b/docker/Dockerfile.cu123 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:24.02-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu124 b/docker/Dockerfile.cu124 new file mode 100644 index 0000000000000000000000000000000000000000..335f52565d04adf6a66d447ea3ebc788f4fbbcc3 --- /dev/null +++ b/docker/Dockerfile.cu124 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:24.05-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu125 b/docker/Dockerfile.cu125 new file mode 100644 index 0000000000000000000000000000000000000000..148e44b41df820c705f22482dc491caa32455514 --- /dev/null +++ b/docker/Dockerfile.cu125 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:24.07-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu126 b/docker/Dockerfile.cu126 new file mode 100644 index 0000000000000000000000000000000000000000..c031c2bc98f97a02ab35afa31610e4c3d3d110bf --- /dev/null +++ b/docker/Dockerfile.cu126 @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:24.12-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.cu128 b/docker/Dockerfile.cu128 new file mode 100644 index 0000000000000000000000000000000000000000..2b895ecd8a02038f12cbb2c951c7a5046cf81a9c --- /dev/null +++ b/docker/Dockerfile.cu128 @@ -0,0 +1,31 @@ +FROM nvcr.io/nvidia/pytorch:25.01-py3 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \ + bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh + +ENV PATH="/opt/conda/bin:${PATH}" + +ENV LIBGL_ALWAYS_INDIRECT=1 + +RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && conda clean --all + +RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev \ + build-essential cmake libedit-dev libxml2-dev cython3 + +RUN pip install cython + +RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \ + && cd TileLang && USE_CUDA=1 pip install -e . -v + +CMD bash diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm new file mode 100644 index 0000000000000000000000000000000000000000..5f61f0e2e8c41f40ea2d213e07ae1e282a9f0102 --- /dev/null +++ b/docker/Dockerfile.rocm @@ -0,0 +1,51 @@ +FROM rocm/pytorch:rocm6.3.2_ubuntu22.04_py3.10_pytorch_release_2.4.0 + +WORKDIR /root + +RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git wget \ + libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \ + rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \ + && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* + +ENV PATH="/opt/conda/bin:${PATH}" +ENV LIBGL_ALWAYS_INDIRECT=1 +ENV USE_ROCM=1 +ENV USE_CUDA=0 +ENV ROCM_HOME=/opt/rocm +ENV HIP_PLATFORM=amd +ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" + + +RUN conda run -n py_3.10 conda install pip cmake -y && \ + conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \ + conda clean --all + +RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \ + apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/* + +# Copy local tilelang directory instead of cloning from git +# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest . +COPY . /root/tilelang + +RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \ + conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \ + conda run -n py_3.10 bash -c "cd /root/tilelang && \ + # Backup and modify pyproject.toml to remove torch from dependencies \ + cp pyproject.toml pyproject.toml.bak && \ + sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \ + # Install tilelang with all dependencies except torch \ + USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \ + # Restore original pyproject.toml \ + mv pyproject.toml.bak pyproject.toml" + +RUN conda init bash && \ + echo "conda activate py_3.10" >> /root/.bashrc + +SHELL ["/bin/bash", "-l", "-c"] + +ENTRYPOINT ["/bin/bash", "--login", "-i"] diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..74060193a0831bbaf19498f985a0cf1b6f845da2 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,15 @@ +To ease the process of installing all the dependencies, we provide a Dockerfile and a simple guideline to build a Docker image with all of above installed. The Docker image is built on top of Ubuntu 20.04, and it contains all the dependencies required to run the experiments. We only provide the Dockerfile for NVIDIA GPU, and the Dockerfile for AMD GPU will be provided upon request. + +```bash +git clone --recursive https://github.com/tile-ai/tilelang TileLang +cd TileLang/docker +# build the image, this may take a while (around 10+ minutes on our test machine) +# replace the version number cu124 with the one you want to use +# replace .cu** with .rocm for AMD GPU +docker build -t tilelang_workspace -f Dockerfile.cu124 . +# run the container +# if it's nvidia +docker run -it --cap-add=SYS_ADMIN --network=host --gpus all --cap-add=SYS_PTRACE --shm-size=4G --security-opt seccomp=unconfined --security-opt apparmor=unconfined --name tilelang_test tilelang_workspace bash +# if it's amd +docker run -it --cap-add=SYS_ADMIN --network=host --device=/dev/kfd --device=/dev/dri --cap-add=SYS_PTRACE --shm-size=4G --security-opt seccomp=unconfined --security-opt apparmor=unconfined --name tilelang_test tilelang_workspace bash +``` diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4d8eb40499da61a00de91503a87038940f8a95d6 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,2 @@ +_build/ +autoapi/ \ No newline at end of file diff --git a/docs/CNAME b/docs/CNAME new file mode 100644 index 0000000000000000000000000000000000000000..ca903c694a195b577524d38b2b26cc577ab76bf9 --- /dev/null +++ b/docs/CNAME @@ -0,0 +1 @@ +tilelang.com \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..157adfb90438fa9f3bf061f876877d606acc4d0d --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,25 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= python -m sphinx +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile clean + +# The "clean" target is updated to remove the autoapi generated files as well. +# Run "make clean" to ensure a completely fresh build. +clean: + rm -rf $(BUILDDIR) autoapi + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..349c0eccc5e7e030456d4712241ae1d72282ffa6 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,30 @@ +# Tile Language Documentation + +The documentation was built upon [Sphinx](https://www.sphinx-doc.org/en/master/). + +## Dependencies + +Run the following command in this directory to install dependencies first: + +```bash +pip3 install -r requirements.txt +``` + +## Build the Documentation + +Then you can build the documentation by running: + +```bash +make html +``` + +## View the Documentation + +Run the following command to start a simple HTTP server: + +```bash +cd _build/html +python3 -m http.server +``` + +Then you can view the documentation in your browser at `http://localhost:8000` (the port can be customized by appending ` -p PORT_NUMBER` in the python command above). diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 0000000000000000000000000000000000000000..0ef6b48cb8b08d17ac582728af5a6f040de06ee1 --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,11 @@ +/* Reduce the displayed size of the sidebar logo in Furo */ +.sidebar-logo { + max-height: 125px; + width: auto; +} + +/* Optional: keep container from growing too tall due to spacing */ +.sidebar-logo-container { + line-height: 0; +} + diff --git a/docs/_static/img/LayoutInference.png b/docs/_static/img/LayoutInference.png new file mode 100644 index 0000000000000000000000000000000000000000..d44e4100d013329365036f355d32f400084456d9 Binary files /dev/null and b/docs/_static/img/LayoutInference.png differ diff --git a/docs/_static/img/MatmulExample.png b/docs/_static/img/MatmulExample.png new file mode 100644 index 0000000000000000000000000000000000000000..555ae30a75b2486bffb8acf27f72802d2c96ec3d Binary files /dev/null and b/docs/_static/img/MatmulExample.png differ diff --git a/docs/_static/img/Parallel.png b/docs/_static/img/Parallel.png new file mode 100644 index 0000000000000000000000000000000000000000..656d4cc01089ccc374d8890c431ee7d9ae096fb5 Binary files /dev/null and b/docs/_static/img/Parallel.png differ diff --git a/docs/_static/img/ir_transform_diagram.png b/docs/_static/img/ir_transform_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..3bd86891394c90db12f98c5dcc43f02330aa3f93 Binary files /dev/null and b/docs/_static/img/ir_transform_diagram.png differ diff --git a/docs/_static/img/logo-row.svg b/docs/_static/img/logo-row.svg new file mode 100644 index 0000000000000000000000000000000000000000..633243f3a9a003a903b859e8d8da5273b0f4cbf3 Binary files /dev/null and b/docs/_static/img/logo-row.svg differ diff --git a/docs/_static/img/logo-v2.png b/docs/_static/img/logo-v2.png new file mode 100644 index 0000000000000000000000000000000000000000..410773f60a0d6ddf9bb86186ecb70529ff1d4667 Binary files /dev/null and b/docs/_static/img/logo-v2.png differ diff --git a/docs/_static/img/logo.png b/docs/_static/img/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..5d04697ce4cd98d1aa6d9edc0f492601ef92c575 Binary files /dev/null and b/docs/_static/img/logo.png differ diff --git a/docs/_static/img/mla_hopper/bs128_float16.png b/docs/_static/img/mla_hopper/bs128_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..3cf24c84b82532bf422efee26afe61b4ae0e1948 Binary files /dev/null and b/docs/_static/img/mla_hopper/bs128_float16.png differ diff --git a/docs/_static/img/mla_hopper/bs64_float16.png b/docs/_static/img/mla_hopper/bs64_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..15807c3d2e57f5a2848b792d0fe746db31be455d Binary files /dev/null and b/docs/_static/img/mla_hopper/bs64_float16.png differ diff --git a/docs/_static/img/mla_hopper/pv_layout.jpg b/docs/_static/img/mla_hopper/pv_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..79b0c8cf301d9c04eef050c893156c71549ce03d Binary files /dev/null and b/docs/_static/img/mla_hopper/pv_layout.jpg differ diff --git a/docs/_static/img/mla_hopper/qk_layout.jpg b/docs/_static/img/mla_hopper/qk_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d5bd923d0d8ab1fe5edece222f31777ccd0d746 Binary files /dev/null and b/docs/_static/img/mla_hopper/qk_layout.jpg differ diff --git a/docs/_static/img/op_benchmark_consistent_gemm_fp16.png b/docs/_static/img/op_benchmark_consistent_gemm_fp16.png new file mode 100644 index 0000000000000000000000000000000000000000..840e423e7199a96e8127cfe2750f7ebb60058bb3 Binary files /dev/null and b/docs/_static/img/op_benchmark_consistent_gemm_fp16.png differ diff --git a/docs/_static/img/overview.png b/docs/_static/img/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..0aa477701b3dbb8eac60988c5e46dbafc8acf8aa Binary files /dev/null and b/docs/_static/img/overview.png differ diff --git a/docs/_static/img/software_pipeline_inference.png b/docs/_static/img/software_pipeline_inference.png new file mode 100644 index 0000000000000000000000000000000000000000..b1b3fd667eb612ea01cafd14c16ecdd599d42c02 Binary files /dev/null and b/docs/_static/img/software_pipeline_inference.png differ diff --git a/docs/_static/img/sparse_mma_storage_example.png b/docs/_static/img/sparse_mma_storage_example.png new file mode 100644 index 0000000000000000000000000000000000000000..0b16398197b28f10b5681cfd27a9f2fe061be5cd Binary files /dev/null and b/docs/_static/img/sparse_mma_storage_example.png differ diff --git a/docs/compiler_internals/inject_fence_proxy.md b/docs/compiler_internals/inject_fence_proxy.md new file mode 100644 index 0000000000000000000000000000000000000000..7a89456ac809d6d534cce9f6167a4e4ba52a9c59 --- /dev/null +++ b/docs/compiler_internals/inject_fence_proxy.md @@ -0,0 +1,113 @@ +# InjectFenceProxy Pass + +`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations. + +## Why Fences Are Needed + +Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behavior. + +## What the Pass Does + +- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements reset the state, such as an explicit fence). +- Automatically lowers `tma_store` intrinsics into the required `arrive`/`wait` handshake so that TMA stores participate correctly in synchronization. +- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier. + +The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted. + +### Timeline View + +``` +generic initialize_wgmma_descriptor โ†’ generic shared-store โ†’ async wgmma + โ”‚ โ”‚ โ”‚ + โ””โ”€ generic proxy โ”ดโ”€ generic proxy โ”ดโ”€ async proxy + โ”‚ fence inserted here โ†‘ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +The proxy tracker scans the sequence from left to right. The moment it detects a transition from generic to async (between the store and `cp.async` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs. + +## Coverage of Intrinsics + +The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow. + +## Usage + +The pass is part of the default TileLang lowering pipeline. To apply it manually: + +```python +from tilelang import tl +from tvm import IRModule + +mod = IRModule({"main": prim_func}) +with tvm.transform.PassContext(): + mod = tl.transform.InjectFenceProxy()(mod) +``` + +## End-to-End Example + +Before the pass: + +```python +@T.prim_func +def kernel(): + with T.Kernel(1): + desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") + smem = T.decl_buffer((128,), "float16", scope="shared") + T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32) + smem[0] = T.float16(0) + T.ptx_wgmma_ss( + "float16", + "m64n64k16", + T.bool(True), + T.bool(True), + "fp16", + "fp16", + "fp16", + desc.data, + T.int32(0), + desc.data, + T.int32(0), + smem.data, + T.int32(0), + T.bool(True), + 1, + 1, + ) +``` + +After `tl.transform.InjectFenceProxy`: + +```python +@T.prim_func +def kernel(): + with T.Kernel(1): + desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") + smem = T.decl_buffer((128,), "float16", scope="shared") + T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32) + smem[0] = T.float16(0) + T.fence_proxy_async() + T.ptx_wgmma_ss( + "float16", + "m64n64k16", + T.bool(True), + T.bool(True), + "fp16", + "fp16", + "fp16", + desc.data, + T.int32(0), + desc.data, + T.int32(0), + smem.data, + T.int32(0), + T.bool(True), + 1, + 1, + ) +``` + +The only change is the `fence_proxy_async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches. + +## Extending the Pass + +If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly. diff --git a/docs/compiler_internals/letstmt_inline.md b/docs/compiler_internals/letstmt_inline.md new file mode 100644 index 0000000000000000000000000000000000000000..012af9020d0e228959588cbbdb704ccd5ba34cda --- /dev/null +++ b/docs/compiler_internals/letstmt_inline.md @@ -0,0 +1,163 @@ +# LetStmt Inlining in TileLang + +This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance. + +## Overview + +A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined. + +## When Does LetStmt Get Inlined? + +The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met: + +### 1. The value satisfies `CanInlineLetStmt` + +The `CanInlineLetStmt` helper returns `true` when: + +- **The value is a constant** (`is_const_number(op->value)` returns true) +- **The value is a variable** (`op->value.as()` returns a node) +- **The value is an integer expression without side effects**: + - The value has `int` dtype + - The side effect level is `kPure` or lower (no observable side effects) + +```cpp +bool CanInlineLetStmt(const LetStmtNode *op) { + if (is_const_number(op->value)) + return true; + if (op->value.as()) + return true; + // Won't face the deep expression explosion problem as in Let expression. + // attempt to inline as much as possible if the value integer type(can be + // index). + if (!op->value.dtype().is_int()) + return false; + return SideEffect(op->value) <= CallEffectKind::kPure; +} +``` + +### 2. The variable is NOT used in buffer definitions + +Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields). + +This protection exists because: +- Buffer definitions are not updated during the simplification pass +- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition +- This would cause compilation errors or incorrect behavior + +The mutator checks this before dropping the binding: + +```cpp +bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); + +if (can_inline && !used_in_buffer_def) { + return body; // Inline: remove LetStmt and return body directly +} +``` + +## Example: Why Buffer Definition Variables Are Protected + +Consider this code: + +```python +let stride = M * 16 +let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1]) +buffer_a[i, j] = ... +``` + +- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects) +- However, `stride` is used in `buffer_a`'s `strides` field +- If we inline it, the buffer definition becomes `strides=[M*16, 1]` +- But the Buffer object's fields are not updated during simplification +- Later code accessing `buffer_a` would fail to find the `stride` variable + +Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined. + +## How Variables Are Collected + +The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions: + +```cpp +void VisitBuffer(const Buffer &buf) { + // Collect variables that should remain defined + VarUseDefAnalyzer usage(Array{}); + usage(buf->data); + for (const auto &dim : buf->shape) { + usage(dim); + } + for (const auto &dim : buf->strides) { + usage(dim); + } + usage(buf->elem_offset); + + // Track for use in LetStmtNode mutator + for (const auto &var : usage.undefined_) { + used_in_buffer_def_.insert(var.get()); + } +} +``` + +## Practical Example: Temporary Variable Issue + +Consider this TileLang code: + +```python +for i in T.Parallel(block_N): + idx = bx * block_N + i + tmp = T.max(A[idx], 1) + B[idx] = tmp / 2 + A[idx] = tmp * 2 +``` + +In this case: +- `tmp` is an integer-like temporary variable +- It satisfies `CanInlineLetStmt` (pure int expression) +- It's **not** used in any buffer definition +- Therefore, `tmp` **will be inlined** + +This means the IR becomes: + +```python +for i in T.Parallel(block_N): + idx = bx * block_N + i + B[idx] = T.max(A[idx], 1) / 2 + A[idx] = T.max(A[idx], 1) * 2 +``` + +If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern. + +## Controlling Let Inlining via Pass Config + +TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments. + +```python +from tilelang import transform +from tilelang.engine.phase import LowerAndLegalize + +with transform.PassContext( + config={transform.PassConfigKey.TL_FORCE_LET_INLINE: True} +): + lowered_mod = LowerAndLegalize(input_mod, target) +``` + +If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it. + +## Summary + +The LetStmt inlining mechanism is a **conservative optimization** that: +1. Aggressively inlines simple, pure integer expressions to simplify the IR +2. Protects variables used in buffer definitions to avoid breaking buffer access +3. Helps reduce IR complexity and improve code generation +4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required + +Understanding when inlining happens is crucial for: +- Debugging compilation issues +- Understanding generated code +- Writing efficient TileLang programs +- Identifying potential optimization opportunities or bugs + +## Related Files + +- `src/transform/simplify.cc`: Main Simplify implementation +- `src/transform/frontend_legalize.cc`: Standalone LetInline pass +- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining +- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass diff --git a/docs/compiler_internals/tensor_checks.md b/docs/compiler_internals/tensor_checks.md new file mode 100644 index 0000000000000000000000000000000000000000..b4d2a0b3c03048455b2e2a77e3536292d5f5a202 --- /dev/null +++ b/docs/compiler_internals/tensor_checks.md @@ -0,0 +1,387 @@ +# Tensor Checks (Host-Side Auto-Validation) + +This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more โ€” so you donโ€™t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind. + +## Why Host-Side Checks +- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars. +- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches. +- Focused error reporting: assertions are raised close to the call site with precise โ€œwhich field failedโ€ messages. + +## How To Inspect Host Source +You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging: + +```python +print(matmul_relu_kernel.get_host_source()) +``` + +--- + +## What The Host Checks + +### 1) Argument count and pointer kind +- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message. +- Each argumentโ€™s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise youโ€™ll see errors like `Expect arg[i] to be pointer` or a scalar type error. + +### 2) Tensor checks (per tensor, after nullability decision) +- Nullability + - If the tensor is โ€œstatically reachable/usedโ€ by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`. + - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`. +- Rank (`ndim`) + - Runtime `ndim` must equal the compile-time rank. +- Data type (`dtype`) + - Match the triple `(code, bits, lanes)` with tolerance: + - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`. + - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`. + - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match). + - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped. +- Shape + - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency. + - Linear equations among symbolic dims can be solved on the fly (when thereโ€™s only one unknown at a given check point), enabling cross-tensor constraints. +- Strides + - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality. + - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`). +- `byte_offset` + - Must be 0 (non-zero raises an error) to keep addressing simple and aligned. +- Device info + - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend. + - When multiple tensors participate, assert that `device_id` matches across them. +- Data pointer + - Must be non-NULL when the tensor is required to be non-null by the nullability rule. + +### 3) Scalar checks +- `T.int*` family: require integer; error: `Expect arg[i] to be int`. +- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`. + +--- + +## Shapes and Symbolic Equations: Linear Solving +When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example: + +```python +@T.prim_func +def main( + A: T.Tensor((m,), dtype), + B: T.Tensor((m + n,), dtype), + C: T.Tensor((n * k,), dtype), +): + ... +``` + +This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime. + +--- + +## Nullability Rules and Examples +Which tensors may be NULL? + +- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL. +- Examples: + +1) Must be non-NULL (used) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + A[0] = 1 +``` +Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`. + +2) Still must be non-NULL (constant-true branch) +```python +some_cond: bool = True +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +3) Nullable (constant-false branch, statically unreachable) +```python +some_cond: bool = False +@T.prim_func +def main(A: T.Tensor((M, K), dtype)): + if some_cond: + A[0] = 1 +``` + +4) Must be non-NULL (runtime condition) +```python +@T.prim_func +def main(A: T.Tensor((M, K), dtype), some_cond: T.bool): + if some_cond: + A[0] = 1 +``` +Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable. + +--- + +## Device Type Codes (DLPack) +Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`. +Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors. + +--- + +## Common Error Examples (What youโ€™ll see) +- Argument count mismatch (num_args) + - Trigger: missing/extra argument + - Error: `: num_args should be N; expected: , got: N` + +- Pointer-typed argument expected + - Trigger: scalar passed where a tensor is expected + - Error: `: Expect arg[i] to be pointer` + +- Rank (ndim) mismatch + - Trigger: runtime rank differs from compile-time rank + - Error: `..ndim is expected to equal R, but got mismatched ndim` + +- Dtype mismatch + - Trigger: dtype not equal to the compiled dtype and not within the tolerance set + - Error: `..dtype is expected to be , but got incompatible dtype` + +- Shape constraint violation + - Trigger: a dimension doesnโ€™t match a constant/symbol binding + - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == ` + +- Strides check failed (e.g., non-contiguous layout) + - Trigger: transposed/sliced tensors that violate expected strides + - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == ` + +- Device type mismatch + - Trigger: calling a CUDA kernel with CPU tensors, etc. + - Error: `..device_type mismatch [expected: ()] ...` + +- Device id mismatch + - Trigger: mixing tensors from different GPUs + - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...` + +- NULL data pointer + - Trigger: tensor required to be non-null has a NULL data pointer + - Error: `. is expected to have non-NULL data pointer, but got NULL` + +- Scalar type mismatch + - Trigger: passing float to `T.int32`, or non-boolean to `T.bool` + - Error: `: Expect arg[i] to be int/boolean` + +--- + +## Troubleshooting Tips +- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields. +- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions. +- Align devices: ensure all participating tensors share the same `device_type` and `device_id`. +- Align dtype: use `.to()` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance. +- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time). + +--- + +## FAQ +- Can I disable the checks? + - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call. +- Is the overhead noticeable? + - The checks are lightweight (branches and field reads). Compared to Python-side checks, itโ€™s faster; the dominating cost remains the Pythonโ†’C boundary. Overall itโ€™s cheaper than equivalent checks in Python. + +--- + +## Reference Example (Matmul + ReLU) + +```python +@T.prim_func +def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), +): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + +# For debugging, print the host source +print(matmul_relu_kernel.get_host_source()) +``` + +The host will insert all checks described above for this example. + +--- + +## Quick Error Reference (Short List) +- Argument count + - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`. +- Pointer kind + - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`. +- Rank (ndim) + - Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`. +- Dtype + - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `. +- Shape + - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `. +- Strides + - Trigger: layout mismatch; Error: `strides[j] ... == `. +- Device type + - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`. +- Device id + - Trigger: tensors on different GPUs; Error: `device_id ... == ...`. +- Data pointer + - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`. +- Scalar types + - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`. + +--- + +## Host Error Troubleshooting (Minimal Repros) + +Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with: + +```python +# Convention: +# A: float16 [M, K] +# B: float16 [K, N] +# C: float16 [M, N] +# Target: CUDA (device_type=2) +fn = matmul_relu_kernel # your compiled function +M = N = K = 1024 +``` + +Adjust dtype/device if your kernel differs. + +### 0. Tip: print the host source +```python +print(fn.get_host_source()) +``` + +### 1. num_args mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +# Missing C +fn(A, B) +``` +Expected: `: num_args should be 3; expected: , got: 3`. + +Fix: pass all arguments per the signature. + +### 2. Expect pointer (tensor) but got scalar +```python +import torch + +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(1, B, C) +``` +Expected: `: Expect arg[0] to be pointer`. + +Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor). + +### 3. ndim mismatch +```python +import torch + +A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`. + +Fix: ensure runtime rank equals compiled rank. + +### 4. dtype mismatch +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16 +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`. + +Fix: `A = A.to(torch.float16)` or create with the correct dtype. + +### 5. Shape constant/symbol mismatch +```python +import torch + +A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `. + +Fix: satisfy linear constraints and constants across tensors. + +### 6. Strides check failure (non-contiguous) +```python +import torch + +A = torch.empty((M, K), device='cuda', dtype=torch.float16) +A_nc = A.t() # transpose -> non-contiguous +B = torch.empty((K, N), device='cuda', dtype=torch.float16) +C = torch.empty((M, N), device='cuda', dtype=torch.float16) +fn(A_nc, B, C) +``` +Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`. + +Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel. + +### 7. device_type mismatch +```python +import torch + +A = torch.empty((M, K), device='cpu', dtype=torch.float16) +B = torch.empty((K, N), device='cpu', dtype=torch.float16) +C = torch.empty((M, N), device='cpu', dtype=torch.float16) +fn(A, B, C) # CUDA-targeted kernel +``` +Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`. + +Fix: move tensors to the CUDA device. + +### 8. device_id mismatch (multi-GPU) +```python +import torch + +A = torch.empty((M, K), device='cuda:0', dtype=torch.float16) +B = torch.empty((K, N), device='cuda:1', dtype=torch.float16) +C = torch.empty((M, N), device='cuda:0', dtype=torch.float16) +fn(A, B, C) +``` +Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`. + +Fix: place all tensors on the same GPU (e.g., `cuda:0`). + +### 9. NULL data pointer (advanced) +This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this. + +Expected: `. is expected to have non-NULL data pointer, but got NULL`. + +Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles. + +### 10. Scalar type mismatch (int / bool) +```python +import tilelang.language as T + +@T.prim_func +def scalar_check(x: T.int32, flag: T.bool()): + T.evaluate(0) + +scalar_check(1.0, True) # x is float -> Expect arg[0] to be int +scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean +``` + +Fix: pass correct scalar types, e.g., `scalar_check(1, True)`. + +--- + +## Closing Notes +- Cross-check โ€œshape / strides / device / dtypeโ€ against the kernel signature to localize issues efficiently. +- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly. + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..877b5582e1e28ee75704d5c75a8ff900a61c4cd3 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,79 @@ +# General information about the project. +project = "TileLang
" +author = "Tile Lang Contributors" +copyright = f"2025-2025, {author}" + +# Version information. +with open("../VERSION") as f: + version = f.read().strip() +release = version + +extensions = [ + "sphinx_tabs.tabs", + "sphinx_toolbox.collapse", + "sphinxcontrib.httpdomain", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx_reredirects", + "sphinx.ext.mathjax", + "myst_parser", + "autoapi.extension", +] + +autoapi_type = "python" +autoapi_dirs = ["../tilelang"] + +autoapi_options = [ + "members", + "undoc-members", + "show-inheritance", + "show-module-summary", + "special-members", +] +autoapi_keep_files = False # Useful for debugging the generated rst files + +autoapi_generate_api_docs = True + +autodoc_typehints = "description" + +autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] + +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} + +myst_enable_extensions = ["colon_fence", "deflist"] + +redirects = {"get_started/try_out": "../index.html#getting-started"} + +language = "en" + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md", "**/*libinfo*", "**/*version*"] + +pygments_style = "sphinx" +todo_include_todos = False + +# -- Options for HTML output ---------------------------------------------- + +html_theme = "furo" +templates_path = [] +html_static_path = ["_static"] +html_css_files = ["custom.css"] +footer_copyright = "ยฉ 2025-2026 TileLang" +footer_note = " " + +html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"} + +header_links = [ + ("Home", "https://github.com/tile-ai/tilelang"), + ("Github", "https://github.com/tile-ai/tilelang"), +] + +html_context = { + "footer_copyright": footer_copyright, + "footer_note": footer_note, + "header_links": header_links, + "display_github": True, + "github_user": "tile-ai", + "github_repo": "tilelang", + "github_version": "main/docs/", + "theme_vcs_pageview_mode": "edit", +} diff --git a/docs/deeplearning_operators/deepseek_mla.md b/docs/deeplearning_operators/deepseek_mla.md new file mode 100644 index 0000000000000000000000000000000000000000..08175778f0cc80c91aa4bf12023bacd6284fa59c --- /dev/null +++ b/docs/deeplearning_operators/deepseek_mla.md @@ -0,0 +1,200 @@ +# ๐Ÿš€ Write High Performance FlashMLA with TileLang on Hopper + + +
+ Author: Yu Cheng + Author: Lei Wang +
+ +TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang. + +## Introduction to MLA + +DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance. + +## Benchmark Results + +We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below. + +```{figure} ../_static/img/mla_hopper/bs64_float16.png +:width: 50% +:alt: Overview +:align: center + +Figure 1: Performance under batch size=64 +``` + +```{figure} ../_static/img/mla_hopper/bs128_float16.png +:width: 50% +:alt: Overview +:align: center + +Figure 2: Performance under batch size=128 +``` + +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. + +## Implementation + +First, let's review the core computation logic of traditional FlashAttention: + +```python +# acc_s: [block_M, block_N] +# scores_max: [block_M] +# scores_scale: [block_M] +# acc_o: [block_M, dim] + +for i in range(loop_range): + acc_s = Q @ K[i] + scores_max_prev = scores_max + scores_max = max(acc_s, dim=1) + scores_scale = exp(scores_max_prev - scores_max) + acc_o *= scores_scale + acc_s = exp(acc_s - scores_max) + acc_o = acc_s @ V[i] + ... +``` + +Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency. + +Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance. + +This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. + +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. + +Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. + +### Layout Inference + +While the above process may seem complex, but don't worry - TileLang will handle all these intricacies for you. + +Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (e.g., split the result matrix in vertical dimension). `T.copy` represents buffer-to-buffer copying operations. + +```{figure} ../_static/img/mla_hopper/qk_layout.jpg +:width: 50% +:alt: Overview +:align: center + +Figure 3: Buffer shapes in Q @ K +``` + +```{figure} ../_static/img/mla_hopper/pv_layout.jpg +:width: 50% +:alt: Overview +:align: center + +Figure 4: Buffer shapes in acc_s @ V +``` + +The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA. + +For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`. + +It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance. + +### Threadblock Swizzling + +Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions. + +In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code: + +```python +T.use_swizzle(panel_size: int, order: str = "row") +``` + +Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". + + +### Shared Memory Swizzling + +In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. + +One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency. + +Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code: + +```python +T.annotate_layout({ + S_shared: TileLang.layout.make_swizzled_layout(S_shared), +}) +``` + +Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. + + +### Warp-Specialization + +The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. + +In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. + + +### Pipeline + + +Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: + +```python +T.pipelined(range: int, stage: int) +``` + +Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. + + +### Split-KV + +We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. + +In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. + + +## ๐Ÿš€ On AMD MI300X Accelerators + +Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms. + +### Architectural Considerations and Optimization Strategies + +Key implementation differences between Hopper and MI300X architectures include: + +1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Access (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, resulting in identical source code manifestations. + +2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes: + - Reducing software pipeline stages + - Register-based caching of Q matrices instead of shared memory utilization: + ```python + # Original shared memory allocation + Q_shared = T.alloc_shared([block_H, dim], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + + # Optimized register allocation + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + ``` + +3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64. + +4. **Memory Bank Conflict Swizzling**: MI300x has different memory bank conflict rules compared to NVIDIA, so we need to use different swizzling strategies. This is also automatically handled by TileLang, resulting in no visible differences in the code. + +### Performance Evaluation + +We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate: + +
+ + AMD FlashMLA Performance Comparison + +
Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)
+
+ +Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) in most test cases, while significantly outperforming both Triton (1.98ร—) and PyTorch (3.76ร—) implementations. This performance is achieved through a concise 80-line Python implementation, demonstrating TileLang's efficiency and programmability advantages. + +### Future Optimization Opportunities + +1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction. + +2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to: + - Reduce shared memory pressure + - Improve compute-to-memory access ratios + - Enhance parallelism through dimension-wise task distribution diff --git a/docs/deeplearning_operators/elementwise.md b/docs/deeplearning_operators/elementwise.md new file mode 100644 index 0000000000000000000000000000000000000000..f3543c02f5ed4a7d95708de488dba0309ca7bf93 --- /dev/null +++ b/docs/deeplearning_operators/elementwise.md @@ -0,0 +1,332 @@ +# ElementWise Operators + +
+ Author: Chenghua Wang +
+ +:::{warning} +:class: myclass1 myclass2 +:name: a-tip-reference + + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouragedโ€”please submit a PR! +::: + +Elementwise operators are widely used in deep learning and often serve as the first example encountered by those beginning to explore parallel programming. This tutorial will analyze several implementations of the elementwise addition operator using TileLang and compare them with the corresponding CUDA implementation. By the end of this tutorial, you will learn: + +1. How to implement an elementwise operator using TileLang. +2. How to compile operators with dynamic shapes. +3. How TileLang addresses boundary-related issues. +4. The similarities and differences between operators implemented in TileLang and those implemented in CUDA/CuTe. + +Please note that this tutorial does not delve deeply into the design principles of TileLang. For a broader understanding of TileLang, we recommend consulting the [Overview](../get_started/overview.md). + +## Elementwise add in TileLang + +```python +def elementwise_add(N, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads), threads=threads) as (b_x): + # vector add. + for i in T.Parallel(threads): + C[b_x * threads + i] = A[b_x * threads + i] + B[b_x * threads + i] + + return main +``` + +All logic for TileLang kernels must be implemented within the `T.Kernel(...)` scope. In this example, initializing `T.kernel(...)` requires specifying both the grid size and the number of threads per block. The returned value `bx` corresponds to `blockIdx.x` in CUDA. In the provided implementation, `T.Parallel` is used to process the data tile (of size `1 x threads`) assigned to the block for computation. + +Those familiar with CUDA programming might wonder where `threadIdx` fits into this. Note that the code inside `T.Kernel` operates at the **block level**, not the **thread level**. In this example, your focus is solely on defining the block-level logic. During compilation, TileLang automatically maps computations to the corresponding threads and applies further optimizations. The optimized code generated by TileLang may closely align with carefully handcrafted computational logic, as demonstrated in Section 2 with a concrete example. While TileLang also supports thread-level programming semantics, this will be covered in subsequent discussions. + +The program can be compiled using the following code: + +```python +program = elementwise_add(1024, threads=256, dtype=T.bfloat16) +kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") +``` +Launching the kernel is straightforward, just call it directly like a function: + +```python +C = kernel(A, B) +``` + +The vector add operation can also be extended to two-dimensional cases, where both implementations demonstrate comparable efficiency in practice. Below is an example from the test section that readers can refer to: [example](https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py). The code for this kernel is provided below: + +```python +import tilelang.language as T +def elementwise_add( + M, + N, + block_M, + block_N, + in_dtype, + out_dtype, + threads, +): + @T.prim_func + def main( + A: T.Tensor((M, N), in_dtype), + B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + + C[y, x] = A[y, x] + B[y, x] + + return main +``` + +### How to compile operators with dynamic shapes? + +In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this: + +```python +program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16) +kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") +``` + +The resulting CUDA code for the kernel will include an additional `int N` parameter after the `bfloat16_t* __restrict__ A`, `bfloat16_t* __restrict__ B`, and `bfloat16_t* __restrict__ C` parameters. + +### How TileLang addresses boundary-related issues. + +TileLang automatically incorporates boundary-checking conditions; however, this comes at a cost. These boundary conditions may prevent TileLang from performing more advanced optimizations. I will introduce an example from the next section in advance. The corresponding code is also provided below, but note that it involves the associated CUDA code. Readers are encouraged to first review the next section before returning to this paragraph for a clearer understanding. + +When compiling the example below, let's set `N` to 2047: + +```python +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x): + # vector add. + for i, j in T.Parallel(threads, num_per_thread): + offsets = (b_x * threads + i) * num_per_thread + C[offsets + j] = A[offsets + j] + B[offsets + j] + + return main +``` + +TileLang will generate the following CUDA code: + +```c++ +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + #pragma unroll + for (int i = 0; i < 8; ++i) { + if (((i * 256) + ((int)threadIdx.x)) < 2047) { + C[((i * 256) + ((int)threadIdx.x))] = (A[((i * 256) + ((int)threadIdx.x))] + B[((i * 256) + ((int)threadIdx.x))]); + } + } +} +``` + +We can observe that TileLang did not apply optimizations such as vectorization or coalesced memory access. In fact, except for the tail group of data, all other threads could have executed more optimized code. + +## Comparison of TileLang, CUDA, and CuTe + +For the subsequent examples, this tutorial will use the vector add operation for simplicity and brevity. + +Typically, those new to CUDA programming often write CUDA code in a style similar to this: + +```c++ +// vector add +__global__ void elementwise_add(float* a, float* b, float* c, int N) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < N) { + c[idx] = a[idx] + b[idx]; + } +} +``` + +The code above assigns each thread to compute a single element, which is evidently inefficient since common acceleration techniques like coalesced memory access and vectorization are not utilized. However, TileLang code written with similar logic (e.g., loop-based traversal) can be optimized by the compiler into highly efficient implementations, making it more accessible for beginners. Additionally, the final generated code from the compiler remains observable, providing transparency into the optimization process. + +The CUDA code generated by TileLang for the compiled kernel can be retrieved using the `kernel.get_kernel_source()` method. Below is the CUDA code produced for the vector addition example from Section 1: + +```cu +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + if (((int)threadIdx.x) < 32) { + uint4 __1; + uint4 v_ = *(uint4*)(A + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))); + uint4 v__1 = *(uint4*)(B + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))); + ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x); + ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y); + ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x); + ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y); + ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x); + ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y); + ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x); + ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y); + *(uint4*)(C + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))) = __1; + } +} +``` + +In the code above, TileLang not only automatically maps block-level parallelism to threads but also applies optimizations such as vectorization and coalesced memory access. + +While TileLang incorporates various optimizations for the aforementioned case, its behavior may sometimes appear counterintuitive. For example, when targeting 256 threads for task processing, applying vectorization can result in each thread computing 8 data elementsโ€”effectively utilizing only 32 active threads. Interestingly, the kernel launch configuration still retains the original allocation of 256 threads. + +In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design. + +```python +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x): + # vector add. + for i, j in T.Parallel(threads, num_per_thread): + offsets = (b_x * threads + i) * num_per_thread + C[offsets + j] = A[offsets + j] + B[offsets + j] + + return main +``` + +The corresponding CUDA code generated for the above example is presented below: + +```c++ +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + uint4 __1; + uint4 v_ = *(uint4*)(A + (((int)threadIdx.x) * 8)); + uint4 v__1 = *(uint4*)(B + (((int)threadIdx.x) * 8)); + ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x); + ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y); + ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x); + ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y); + ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x); + ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y); + ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x); + ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y); + *(uint4*)(C + (((int)threadIdx.x) * 8)) = __1; +} +``` +Aha, this CUDA code aligns closely with conventional programming practices, making it more familiar and intuitive. + +But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations. + +```python +def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16): + + @T.prim_func + def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): + with T.Kernel(T.ceildiv(N, threads * NUM_ELE_PER_THREAD), threads=threads) as (b_x): + A_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype) + B_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype) + C_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype) + + s_start = b_x * threads * NUM_ELE_PER_THREAD + s_end = (b_x + 1) * threads * NUM_ELE_PER_THREAD + + # LDG. 128 + T.copy( + A[s_start:s_end], + A_register, + ) + T.copy( + B[s_start:s_end], + B_register, + ) + + # vector add. + for tid, i in T.Parallel(threads, NUM_ELE_PER_THREAD): + C_register[tid * NUM_ELE_PER_THREAD + i] = ( + A_register[tid * NUM_ELE_PER_THREAD + i] + + B_register[tid * NUM_ELE_PER_THREAD + i]) + + # STG. 128 + T.copy( + C_register, + C[s_start:s_end], + ) + + return main +``` + +In the example above, each thread is responsible for computing 8 elements. The `T.copy(...)` method functions at the block level, and TileLang automatically maps data movement operations to individual threads. This design may resonate more intuitively with CUDA developers. Let us now analyze the CUDA code generated from this implementation. + +```c++ +// N is set to 8192 * 8192 when compiling +extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) { + bfloat16_t A_register[8]; + bfloat16_t B_register[8]; + *(uint4*)(A_register + 0) = *(uint4*)(A + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))); + *(uint4*)(B_register + 0) = *(uint4*)(B + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))); + uint4 __1; + uint4 v_ = *(uint4*)(A_register + 0); + uint4 v__1 = *(uint4*)(B_register + 0); + ((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x); + ((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y); + ((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x); + ((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y); + ((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x); + ((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y); + ((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x); + ((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y); + *(uint4*)(A_register + 0) = __1; + *(uint4*)(C + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))) = *(uint4*)(A_register + 0); +} +``` + +We observed the emergence of two additional registers, `A_register` and `B_register`. However, during the actual computation, these registers are simply reassigned to `v_` and `v__1`, respectively. + +To evaluate complexity, one could implement the same elementwise addition operator using CuTe and compare it with the TileLang version. The corresponding CuTe code is provided below: + +```c++ +template +__global__ void elementwise_add(nv_bfloat16* C, + const nv_bfloat16* A, + const nv_bfloat16* B, + int N) { + using namespace cute; + + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + Tensor t_C = make_tensor(make_gmem_ptr(C), make_shape(N)); + Tensor t_A = make_tensor(make_gmem_ptr(A), make_shape(N)); + Tensor t_B = make_tensor(make_gmem_ptr(B), make_shape(N)); + + Tensor t_C_tile = local_tile(t_C, make_shape(Int{}), make_coord(idx)); + Tensor t_A_tile = local_tile(t_A, make_shape(Int{}), make_coord(idx)); + Tensor t_B_tile = local_tile(t_B, make_shape(Int{}), make_coord(idx)); + + Tensor reg_buffer_A = make_tensor_like(t_A_tile); + Tensor reg_buffer_B = make_tensor_like(t_B_tile); + Tensor reg_buffer_C = make_tensor_like(t_C_tile); + + // LDG. 128 + copy(t_A_tile, reg_buffer_A); + copy(t_B_tile, reg_buffer_B); + + auto reg_C_vector = recast(reg_buffer_C); + auto reg_A_vector = recast(reg_buffer_A); + auto reg_B_vector = recast(reg_buffer_B); + + // Perform vectorized addition +#pragma unroll + for (int vec_idx = 0; vec_idx < size(reg_C_vector); ++vec_idx) { + reg_C_vector(vec_idx) = reg_A_vector(vec_idx) + reg_B_vector(vec_idx); + } + + auto reg_C_flat = recast(reg_C_vector); + + // STG. 128 + copy(reg_C_flat, t_C_tile); +} +``` + +## Conclusion + +This tutorial showcases the implementation of the elementwise addition operator using TileLang, while also comparing various design approaches. TileLang significantly reduces the complexity of CUDA programming, enabling high performance with minimal code. Nevertheless, working with TileLang demands careful attention to specific implementation details. To ensure computational efficiency, it is essential to thoroughly examine the generated CUDA kernels. + +--- + +**Reference:** + +[1] The CuTe code implementation draws inspiration from the techniques discussed in this blog: https://zhuanlan.zhihu.com/p/690703999 diff --git a/docs/deeplearning_operators/gemv.md b/docs/deeplearning_operators/gemv.md new file mode 100644 index 0000000000000000000000000000000000000000..c75a961b8079b75d4a813658b1cae1899a873353 --- /dev/null +++ b/docs/deeplearning_operators/gemv.md @@ -0,0 +1,464 @@ +# General Matrix-Vector Multiplication (GEMV) +=========================================== + +
+ Contributor: @botbw +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouragedโ€”please submit a PR! +::: + +:::{tip} +Example code can be found at `examples/gemv/example_gemv.py`. +::: + +General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`. + +## Triton Implementation +When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`. + +A simple Triton kernel for GEMV might look like this: +```python +@triton.jit +def _gemv_naive( + x_ptr, A_ptr, y_ptr, + N, K, + BLOCK_SIZE_K: tl.constexpr, +): + n = tl.program_id(0) + offs_k = tl.arange(0, BLOCK_SIZE_K) + mask = offs_k < K + a_ptrs = A_ptr + n * K + offs_k + a_vals = tl.load(a_ptrs, mask=mask, other=0.0) + x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0) + dot = tl.sum(a_vals * x_vals, axis=0) + tl.store(y_ptr + n, dot) +``` + +`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control. + +## Naive Implementation in TileLang +If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example: + +```python +def naive_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: + tn = T.get_thread_binding(0) # tn = threadIdx.x + A_shared = T.alloc_shared((BLOCK_K,), dtype) + B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype) + C_reg = T.alloc_local((1,), accum_dtype) + T.clear(C_reg) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for tk in T.serial(BLOCK_K): + A_shared[tk] = A[bk * BLOCK_K + tk] + B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + for tk in T.serial(BLOCK_K): + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, + tk].astype(accum_dtype) + C[bn * BLOCK_N + tn] = C_reg[0] + + return main +``` + +And your kernel will be compiled into CUDA by `TileLang` (in `~/.tilelang/cache`): + +```C++ +extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + float C_reg[1]; + __shared__ uint64_t _mbarrier[2]; + if (((int)threadIdx.x) == 0) { + tl::mbarrier_init(_mbarrier[0], 128); + tl::mbarrier_init(_mbarrier[1], 128); + } + __syncthreads(); + if (128 <= ((int)threadIdx.x)) { + tl::warpgroup_reg_dealloc<24>(); + for (int bk = 0; bk < 8; ++bk) { + tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1)); + for (int tk = 0; tk < 128; ++tk) { + ((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)]; + ((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)]; + } + tl::fence_proxy_async(); + tl::mbarrier_cp_async_arrive(_mbarrier[0]); + tl::mbarrier_arrive(_mbarrier[0]); + } + } else { + tl::warpgroup_reg_alloc<240>(); + C_reg[0] = 0.000000e+00f; + for (int bk_1 = 0; bk_1 < 8; ++bk_1) { + tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1)); + for (int tk_1 = 0; tk_1 < 128; ++tk_1) { + C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)]))); + } + tl::fence_proxy_async(); + tl::mbarrier_arrive(_mbarrier[1]); + } + C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]); + } +} +``` + +In this design, the first 128 threads act as the data producer and the last 128 threads as the consumer within a block (assuming a 1D block). + +At this level, we only gain very little computation power from our GPU with around **~0.17 ms** compared to torch/cuBLAS's **~0.008 ms**, which is around 20x slower. + +## More Concurrency + +To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA. + +Hereโ€™s a simplified version: +```python +def naive_splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((1,), dtype) + B_local = T.alloc_local((1,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + A_local[0] = A[bk * BLOCK_K + tk] + B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main +``` + +By introducing parallelism along K dimension, our kernel now achieves **~0.024 ms**, an improvement, but still not on par with torch/cuBLAS. + +### Customizing Parallelism in K Dimension +If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter. This way, each thread handles multiple elements per iteration: + +```python +def splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + reduce_threads: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + TILE_K = T.ceildiv(BLOCK_K, reduce_threads) + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.serial(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main +``` + + +## Vectorized Reads + +GEMV is less computation intensive than GEMM as the computation intensity and memory throughput will be the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`: + +```python +def splitk_gemv_vectorized( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main +``` + +With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance. + + +## `tvm_thread_allreduce` Instead of `atomicAdd` + +[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) has implemented optimization when making an all-reduce across a number of threads, which should outperfrom out plain smem + `atomidAdd`: + +```python +def splitk_gemv_vectorized_tvm( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: str = "float16", + accum_dtype: str = "float", +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + )) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main +``` + +With this optimization, the kernel latency now reduces from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS! + +## Autotune + +`BLOCK_N`, `BLOCK_K`, `reduce_threads` are hyperparameters in our kernel, which can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations: + +```python +def get_best_config(N, K): + + def get_configs(): + BLOCK_N = [2, 4, 8, 32, 64, 128] + reduce_threads = [4, 8, 32] + _configs = list(itertools.product( + BLOCK_N, + reduce_threads, + )) + configs = [{ + "BLOCK_N": c[0], + "reduce_threads": c[1], + } for c in _configs] + return configs + + @autotune( + configs=get_configs(), + warmup=3, + rep=20, + ) + @jit( + out_idx=[-1], + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=False, + target="auto", + ) + def kernel( + BLOCK_N=None, + reduce_threads=None, + ): + dtype = "float16" + accum_dtype = "float" + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Buffer((K,), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + )) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main + + return kernel() +``` + +After autotuning, now our kernel gets **~0.0067 ms**, the final generated CUDA kernel might like this: + +```C++ +extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) { + float C_accum[1]; + half_t A_local[8]; + half_t B_local[8]; + __shared__ float red_buf0[64]; + C_accum[0] = 0.000000e+00f; + for (int bk = 0; bk < 4; ++bk) { + *(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8))); + *(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8))); + for (int k = 0; k < 8; ++k) { + C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k]))); + } + } + tl::fence_proxy_async(); + __syncthreads(); + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0]; + __syncthreads(); + if (((int)threadIdx.y) < 16) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]); + } + __syncthreads(); + if (((int)threadIdx.y) < 8) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]); + } + __syncthreads(); + if (((int)threadIdx.y) < 4) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 4)]); + } + __syncthreads(); + if (((int)threadIdx.y) < 2) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]); + } + __syncthreads(); + if (((int)threadIdx.y) < 1) { + red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]); + } + __syncthreads(); + C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]); +} +``` + +This corresponds closely to our `TileLang` program, with necessary synchronization and low-level optimizations inserted automatically. + +## Conclusion + +### Benchmark Table on Hopper GPU + +| Kernel Name | Latency | +|------------|------------| +| torch/cuBLAS | 0.00784 ms | +| Triton | 0.00773 ms | +| naive_gemv | 0.16607 ms | +| splitk_gemv | 0.02419 ms | +| splitk_gemv_vectorized | 0.00809 ms | +| splitk_gemv_vectorized_tvm | 0.00675 ms | + + +Triton Time: 0.0077344514429569244 +In this tutorial, we implemented a simple GEMV kernel and learn that `TileLang` exposes low level control to user such as thread-level programming and CUDA primitives. \ No newline at end of file diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md new file mode 100644 index 0000000000000000000000000000000000000000..fea036ebe48429d8ce40b46a9f5220f5e2d4e828 --- /dev/null +++ b/docs/deeplearning_operators/matmul.md @@ -0,0 +1,259 @@ +# General Matrix-Matrix Multiplication with Tile Library + +
+ Author: Lei Wang +
+ +:::{warning} +:class: myclass1 myclass2 +:name: a-tip-reference + + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouragedโ€”please submit a PR! +::: + +TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction: + +* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM. + +* **Level 2:** A user is aware of GPU architecture conceptsโ€”such as shared memory, tiling, and thread blocksโ€”but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc. + +* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc. + +```{figure} ../_static/img/overview.png +:width: 50% +:alt: Overview +:align: center + +Figure 1: High-level overview of the TileLang compilation flow. +``` + +In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance. + +## Why Another GPU DSL? + +TileLang emerged from the need for a DSL that: + +1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed. +2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPU, etc.). +3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, `T.gemm`), yet retains options for expert-level tuning. + +While Level 1 in TileLang can be very comfortable for general usersโ€”since it requires no scheduling or hardware-specific knowledgeโ€”it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints. + +## Matrix Multiplication Example + +```{figure} ../_static/img/MatmulExample.png +:alt: Matmul Example +:align: center + +``` + +### Basic Structure + +Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses: + +* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.). +* **`T.alloc_shared(...)`** to allocate GPU shared memory. +* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation. +* **`T.Pipelined(...)`** to express software pipelining across the K dimension. +* **`T.Parallel(...)`** to parallelize data copy loops. +* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs). + +```python +import tilelang +import tilelang.language as T +from tilelang.intrinsics import make_mma_swizzle_layout + +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Optional layout hints (commented out by default) + # T.annotate_layout({ + # A_shared: make_mma_swizzle_layout(A_shared), + # B_shared: make_mma_swizzle_layout(B_shared), + # }) + + # Optional: Enabling swizzle-based rasterization + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A from global to shared memory + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Parallel copy tile of B from global to shared memory + for k, j in T.Parallel(block_K, block_N): + B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] + + # Perform a tile-level GEMM + T.gemm(A_shared, B_shared, C_local) + + # Copy result from local (register fragment) to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + +# 1. Create the TileLang function +func = matmul(1024, 1024, 1024, 128, 128, 32) + +# 2. JIT-compile the kernel for NVIDIA GPU +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda") + +import torch + +# 3. Prepare input tensors in PyTorch +a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) +b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) + +# 4. Invoke the JIT-compiled kernel +c = jit_kernel(a, b) +ref_c = a @ b + +# 5. Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 6. Inspect generated CUDA code (optional) +cuda_source = jit_kernel.get_kernel_source() +print("Generated CUDA kernel:\n", cuda_source) + +# 7. Profile performance +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +``` + +### Key Concepts + +1. **Kernel Context**: + +```python +with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + ... +``` + +- This sets up the block grid dimensions based on N/block_N and M/block_M. +- `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads. + + +```{figure} ../_static/img/Parallel.png +:alt: Parallel +:align: center + +``` + + +2. **Shared & Fragment Memory**: + +```python +A_shared = T.alloc_shared((block_M, block_K), dtype) +B_shared = T.alloc_shared((block_K, block_N), dtype) +C_local = T.alloc_fragment((block_M, block_N), accum_dtype) +``` + +- `T.alloc_shared` allocates shared memory across the entire thread block. +- `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compilerโ€™s layout inference assigns slices of this space to each thread. + +3. **Software Pipelining**: + +```python +for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + ... +``` + +- `T.Pipelined` automatically arranges asynchronous copy and compute instructions to overlap memory operations with arithmetic. +- The argument `num_stages=3` indicates the pipeline depth. + +```{figure} ../_static/img/software_pipeline_inference.png +:alt: Software Pipeline Inference +:align: center + +``` + + +4. **Parallel Copy**: + +```python +for k, j in T.Parallel(block_K, block_N): + B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] +``` + +- `T.Parallel` marks the loop for thread-level parallelization. +- The compiler will map these loops to the available threads in the block. + +5. **Tile-Level GEMM**: + +```python +T.gemm(A_shared, B_shared, C_local) +``` + +- A single call that performs a tile-level matrix multiplication using the specified buffers. +- Under the hood, for NVIDIA targets, it can use CUTLASS/Cute or WMMA instructions. On AMD GPUs, TileLang uses a separate HIP or composable kernel approach. + +6. **Copying Back Results**: + +```python +T.copy(C_local, C[by * block_M, bx * block_N]) +``` + +- After computation, data in the local register fragment is written back to global memory. + +## Comparison with Other DSLs + +TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization, while letting the compiler handle many low-level details. However, TileLang also: + +- Allows explicit memory layout annotations (e.g. `make_mma_swizzle_layout`). +- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined. +- Enables mixing different levels in a single programโ€”for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline-assembly and keep the rest in Level 2. + +## Performance on Different Platforms + +```{figure} ../_static/img/op_benchmark_consistent_gemm_fp16.png +:alt: Performance on Different Platforms +:align: center + +``` + +When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks, for an FP16 matrix multiply (e.g., 4090, A100, H100, MI300X), TileLang has shown: + +- ~1.1x speedup over cuBLAS on RTX 4090 +- ~0.97x on A100 (on par with cuBLAS) +- ~1.0x on H100 +- ~1.04x on MI300X +- Compared to Triton, speedups range from 1.08x to 1.25x depending on the hardware. + +These measurements will vary based on tile sizes, pipeline stages, and the hardwareโ€™s capabilities. + +## Conclusion + +This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code: + +1. We allocated shared memory and register fragments. +2. We pipelined the loading and computation along the K dimension. +3. We used parallel copying to efficiently load tiles from global memory. +4. We invoked `T.gemm` to dispatch a tile-level matrix multiply. +5. We verified correctness against PyTorch and examined performance. + +By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) with the ability to annotate layouts or drop to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs. + +For more advanced usageโ€”including partial lowering, explicitly controlling thread primitives, or using inline assemblyโ€”you can explore Level 3. Meanwhile, for purely functional expressions and high-level scheduling auto-tuning, consider Level 1. + +## Further Resources + +* [TileLang GitHub](https://github.com/tile-ai/tilelang) +* [BitBLAS](https://github.com/tile-ai/bitblas) +* [Triton](https://github.com/openai/triton) +* [Cutlass](https://github.com/NVIDIA/cutlass) +* [PyCUDA](https://documen.tician.de/pycuda/) diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md new file mode 100644 index 0000000000000000000000000000000000000000..5910bd6f8c25943ee18bbd65b7ed7fa0b060de5a --- /dev/null +++ b/docs/deeplearning_operators/matmul_sparse.md @@ -0,0 +1,262 @@ +# Sparse Matrix-Matrix Multiplication with Tile Library + +
+ Author: botbw +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + + This feature is still **experimental** and need further optimization. + + Suggestions and improvements are highly encouragedโ€”please submit a PR! +::: + +:::{tip} +It's suggested to go through `docs/deeplearning_operators/matmul.md` first. + +Example code can be found at `examples/gemm_sp`. +::: + +## Structured sparsity in the NVIDIA Ampere architecture + +Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation. + +:::{warning} + This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X. +::: + +```{figure} ../_static/img/sparse_mma_storage_example.png +:align: center + +Figure: Sparse MMA storage example (from PTX doc) +``` + +## Compress a dense tensor + +To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata. + +Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`โ€™s built-in compressor (or reimplementing it in `PyTorch`). + +A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensorโ€”along with other required arguments (e.g., block_K for sm90, transpose options)โ€”can be passed in to perform the compression. + +```python +from tilelang.utils.sparse import compress +A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) +``` + +Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. + +> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor) +The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads). +For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**. + + +## `T.gemm_sp` with CUTLASS's compressor + +:::{warning} + +It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time. + +::: + +A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata. + +Check comments in below kernel code for required modification. + +```python +def matmul_sp_sm80( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + is_8_bit = "8" in in_dtype + metadata_dtype = 'int32' if is_8_bit else 'int16' + E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ # Annotate reordered cutlass metadata layout + E: + make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), + E_shared: + make_cutlass_metadata_layout( + E_shared, mma_dtype=in_dtype, arch="8.0"), + }) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main +``` + +Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`. + +## `T.gemm_sp_v2` with a custom compressor + +To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`. + +Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors. + +The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs. + +Suppose we have the following row vector: +```python +t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten() +``` + +The non-zero elements and their corresponding indices are: + +```python +t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten() +indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten() +``` + +The corresponding uint16 metadata is: +```python +# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000]) +# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16) +# Note: the above code is not runnable in python as the interpreter won't take the binary +# as 2's complement +metadata_int16 = tensor(-29107) +``` + +You can decode an int16 metadata tensor using the following utility: +```python +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) +``` + +The compressor can be implement at either `PyTorch`/`NumPy` level or kernel level. + +For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASSโ€™s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch exampleโ€”but without using the `_calculate_meta_reordering_scatter_offsets` function. + +If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel. + +```python + +@tilelang.jit(out_idx=[1, 2], pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, +}) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: # NOTE: Make sure compressor metadata layout + T.annotate_layout({ # is same with your computation kernel + E: + make_cutlass_metadata_layout( + E, mma_dtype="float16", arch="8.0", block_k=block_K), + E_shared: + make_cutlass_metadata_layout( + E_shared, + mma_dtype="float16", + arch="8.0", + block_k=block_K), + }) + T.clear(A_sp_shared) + T.clear(E_shared) + non_zero_cnt = T.alloc_local((1, ), dtype="uint8") + non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8") + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + T.clear(non_zero_cnt) + T.clear(non_zero_elt_log_idx) + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel +``` + +## A note on `gemm_sp` and `gemm_sp_v2` + +Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout. + +However, fixing a specific layout introduces several potential issues: + +1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling. + +2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically. + +3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) + +`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout. \ No newline at end of file diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md new file mode 100644 index 0000000000000000000000000000000000000000..8fa41c023ad82fcf6004b230c9b556f87aaa32a4 --- /dev/null +++ b/docs/get_started/Installation.md @@ -0,0 +1,260 @@ +# Installation Guide + +## Installing with pip + +**Prerequisites for installation via wheel or PyPI:** + +- **glibc**: 2.28 (Ubuntu 20.04 or later) +- **Python Version**: >= 3.8 +- **CUDA Version**: 12.0 <= CUDA < 13 + +The easiest way to install tilelang is directly from PyPI using pip. To install the latest version, run the following command in your terminal: + +```bash +pip install tilelang +``` + +Alternatively, you may choose to install tilelang using prebuilt packages available on the Release Page: + +```bash +pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl +``` + +To install the latest version of tilelang from the GitHub repository, you can run the following command: + +```bash +pip install git+https://github.com/tile-ai/tilelang.git +``` + +After installing tilelang, you can verify the installation by running: + +```bash +python -c "import tilelang; print(tilelang.__version__)" +``` + +## Building from Source + +**Prerequisites for building from source:** + +- **Operating System**: Linux +- **Python Version**: >= 3.8 +- **CUDA Version**: >= 10.0 + +If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment. + +First, install the OS-level prerequisites on Ubuntu/Debian-based systems using the following commands: + +```bash +apt-get update +apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev +``` + +Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process. + +> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies). + +```bash +git clone --recursive https://github.com/tile-ai/tilelang.git +cd tilelang +pip install . -v +``` + +If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation. + +```bash +pip install -e . -v +``` + +> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below. + +(working-from-source-via-pythonpath)= + +### Working from Source via `PYTHONPATH` (Recommended for Developers) + +If you prefer to work directly from the source tree via `PYTHONPATH` instead of using pip, make sure the native extension (`libtilelang.so`) is built first: + +```bash +mkdir -p build +cd build +cmake .. -DUSE_CUDA=ON +make -j +``` + +We also recommend using `ninja` to speed up compilation: + +```bash +cmake .. -DUSE_CUDA=ON -G Ninja +ninja +``` + +Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example: + +```bash +export PYTHONPATH=/path/to/tilelang:$PYTHONPATH +python -c "import tilelang; print(tilelang.__version__)" +``` + +Some useful CMake options you can toggle while configuring: +- `-DUSE_CUDA=ON|OFF` builds against NVIDIA CUDA (default ON when CUDA headers are found). +- `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs. +- `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`. + +(using-existing-tvm)= + +### Building with Customized TVM Path + +If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang: + +```bash +TVM_ROOT= pip install . -v +``` + +> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). And this method often leads to some path issues. Check `env.py` to see some environment variables which are not set properly. + +(install-using-docker)= + +## Install Using Docker + +For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems. + +**Prerequisites:** +- Docker installed on your system +- NVIDIA Docker runtime or GPU is not necessary for building tilelang, you can build on a host without GPU and use that built image on other machine. + +1. **Clone the Repository**: + +```bash +git clone --recursive https://github.com/tile-ai/tilelang +cd tilelang +``` + +2. **Build Docker Image**: + +Navigate to the docker directory and build the image for your desired CUDA version: + +```bash +cd docker +docker build -f Dockerfile.cu120 -t tilelang-cu120 . +``` + +Available Dockerfiles: +- `Dockerfile.cu120` - For CUDA 12.0 +- Other CUDA versions may be available in the docker directory + +3. **Run Docker Container**: + +Start the container with GPU access and volume mounting: + +```bash +docker run -itd \ + --shm-size 32g \ + --gpus all \ + -v /home/tilelang:/home/tilelang \ + --name tilelang_b200 \ + tilelang-cu120 \ + /bin/zsh +``` + +**Command Parameters Explanation:** +- `--shm-size 32g`: Increases shared memory size for better performance +- `--gpus all`: Enables access to all available GPUs +- `-v /home/tilelang:/home/tilelang`: Mounts host directory to container (adjust path as needed) +- `--name tilelang_b200`: Assigns a name to the container for easy management +- `/bin/zsh`: Uses zsh as the default shell + +4. **Access the Container and Verify Installation**: + +```bash +docker exec -it tilelang_b200 /bin/zsh +# Inside the container: +python -c "import tilelang; print(tilelang.__version__)" +``` + +## Install with Nightly Version + +For users who want access to the latest features and improvements before official releases, we provide nightly builds of tilelang. + +```bash +pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/ +# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/ +``` + +> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet. + +## Install Configs + +### Build-time environment variables +`USE_CUDA`: If to enable CUDA support, default: `ON` on Linux, set to `OFF` to build a CPU version. By default, we'll use `/usr/local/cuda` for building tilelang. Set `CUDAToolkit_ROOT` to use different cuda toolkit. + +`USE_ROCM`: If to enable ROCm support, default: `OFF`. If your ROCm SDK does not located in `/opt/rocm`, set `USE_ROCM=` to enable build ROCm against custom sdk path. + +`USE_METAL`: If to enable Metal support, default: `ON` on Darwin. + +`TVM_ROOT`: TVM source root to use. + +`NO_VERSION_LABEL` and `NO_TOOLCHAIN_VERSION`: +When building tilelang, we'll try to embed SDK and version information into package version as below, +where local version label could look like `.git`. Set `NO_VERSION_LABEL=ON` to disable this behavior. +``` +$ python -mbuild -w +... +Successfully built tilelang-0.1.6.post1+cu116.git0d4a74be-cp38-abi3-linux_x86_64.whl +``` + +where `={cuda,rocm,metal}`. Specifically, when `=cuda` and `CUDA_VERSION` is provided via env, +`=cu`, similar with this part in pytorch. +Set `NO_TOOLCHAIN_VERSION=ON` to disable this. + +### Run-time environment variables + +Please refer to the `env.py` file for a full list of supported run-time environment variables. + +## Other Tips + +### IDE Configs + +Building tilelang locally will automatically generate a `compile_commands.json` file in `build` dir. +VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration. + +### Compile Cache + +The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found. + +### Repairing Wheels + +If you plan to use your wheel in other environment, +it's recommended to use auditwheel (on Linux) or delocate (on Darwin) +to repair them. + +(faster-rebuild-for-developers)= + +### Faster Rebuild for Developers + +`pip install` introduces extra [un]packaging and takes ~30 sec to complete, +even if no source change. + +Developers who needs to recompile frequently could use: + +```bash +pip install -r requirements-dev.txt + +# For first time compilation +pip install -e . -v --no-build-isolation + +# Or manually compile with cmake/ninja. Remember to set PYTHONPATH properly. +mkdir build +cd build +cmake .. -G Ninja +ninja + +# Rebuild when you change the cpp code +cd build; ninja +``` + +When running in editable/developer mode, +you'll see logs like below: + +```console +$ python -c 'import tilelang' +2025-10-14 11:11:29 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /Users/yyc/repo/tilelang/build +``` diff --git a/docs/get_started/overview.md b/docs/get_started/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..18fa9f1936fcb5f9b5dedb9efd394992acf243f6 --- /dev/null +++ b/docs/get_started/overview.md @@ -0,0 +1,91 @@ +# The Tile Language: A Brief Introduction + +## Programming Interface + +The figure below depicts how **TileLang** programs are progressively lowered from a high-level description to hardware-specific executables. We provide three different programming interfacesโ€”targeted at **Beginner**, **Developer**, and **Expert** usersโ€”that each reside at different levels in this lowering pipeline. The **Tile Language** also allows mixing these interfaces within the same kernel, enabling users to work at whichever level of abstraction best suits their needs. + +```{figure} ../_static/img/overview.png +:width: 50% +:alt: Overview +:align: center + +Figure 1: High-level overview of the TileLang compilation flow. +``` + +## Programming Interfaces + +1. **Beginner Level (Hardware-Unaware)** + - Intended for users who need to write code that is independent of specific hardware details. + - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations. + - *Note:* This interface is not yet fully implemented. + +2. **Developer Level (Hardware-Aware with Tile Library)** + - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations. + - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures. + - Users at this level can leverage these ready-made primitives without diving into low-level threading details. + +3. **Expert Level (Hardware-Aware with Thread Primitives)** + - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing). + - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels. + - This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures. + +## Compilation Flow + +1. **Tile Program** + A high-level specification of the computation. Depending on the userโ€™s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives. + +2. **Tile Program with Tile Library** + When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations. + +3. **Tile Program with Thread Primitives** + Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage. + +4. **IRModule** + After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details. + +5. **Source Code Generation (C/CUDA/HIP/LLVM/โ€ฆ)** + From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD). + +6. **Hardware-Specific Executable/Runtime** + Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures. + +## Tile-based Programming Model + +[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``, +illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining, +and operator calls to manage data movement and computation with fine-grained control. +In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling +leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization +and reduce latency. +Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang`` +allows developers to reason about performance-critical optimizations within a user-friendly programming model. + +```{figure} ../_static/img/MatmulExample.png +:align: center +:width: 100% +:alt: GEMM with Multi-Level Tiling on GPUs +:name: fig-overview-gemm + +Figure 2: Optimizing GEMM with Multi-Level Tiling on GPUs via ``TileLang``. +``` + +### Tile declarations + +At the heart of our approach is the notion of *tiles* as first-class objects in the programming model. A tile represents a shaped portion of data, which can be owned and manipulated by a warp, thread block, or equivalent parallel unit. In the `Matmul` example, the `A` and `B` buffers are read in tiled chunks (determined by `block_M`, `block_N`, `block_K`) inside the kernel loop. With `T.Kernel`, `TileLang` defines the execution context, which includes the thread block index (`bx` and `by`) and the number of threads. These contexts can help compute the index for each thread block and make it easier for `TileLang` to automatically infer and optimize memory access and computation. Additionally, these contexts allow users to manually control the behavior of each independent thread within a thread block. + +### Explicit Hardware Memory Allocation + +A hallmark of `TileLang` is the ability to explicitly place these tile buffers in the hardware memory hierarchy. Rather than leaving it to a compiler's opaque optimization passes, `TileLang` exposes user-facing intrinsics that map directly to physical memory spaces or accelerator-specific constructs. In particular: + +- `T.alloc_shared`: Allocates memory in a fast, on-chip storage space, which corresponds to shared memory on NVIDIA GPUs. Shared memory is ideal for caching intermediate data during computations, as it is significantly faster than global memory and allows for efficient data sharing between threads in the same thread block. For example, in matrix multiplication, tiles of matrices can be loaded into shared memory to reduce global memory bandwidth demands and improve performance. + +- `T.alloc_fragment`: Allocates accumulators in fragment memory, which corresponds to register files on NVIDIA GPUs. By keeping inputs and partial sums in registers or hardware-level caches, latency is further minimized. Note that in this tile program, each tile allocates the same local buffers as shared memory, which might seem counterintuitive, as shared memory is generally faster but more abundant, whereas register file space is limited. This is because the allocation here refers to the register files for an entire thread block. `TileLang` uses a Layout Inference Pass during compilation to derive a Layout object `T.Fragment`, which determines how to allocate the corresponding register files for each thread. This process will be discussed in detail in subsequent sections. + +Data transfer between global memory and hardware-specific memory can be managed using `T.copy`. Furthermore, hardware-specific buffers can be initialized using `T.clear` or `T.fill`. For data assignments, operations can also be performed in parallel using `T.Parallel`, as demonstrated in Layout Inference Pass in the following sections. + +```{figure} ../_static/img/LayoutInference.png + :align: center + :width: 100% + :alt: GEMM with Multi-Level Tiling on GPUs + +``` diff --git a/docs/get_started/targets.md b/docs/get_started/targets.md new file mode 100644 index 0000000000000000000000000000000000000000..c2b3f2fb5ac7b119e1b084bb8694b99765eab40b --- /dev/null +++ b/docs/get_started/targets.md @@ -0,0 +1,120 @@ +# Understanding Targets + +TileLang is built on top of TVM, which relies on **targets** to describe the device you want to compile for. +The target determines which code generator is used (CUDA, HIP, Metal, LLVM, โ€ฆ) and allows you to pass +device-specific options such as GPU architecture flags. This page summarises how to pick and customise a target +when compiling TileLang programs. + +## Common target strings + +TileLang ships with a small set of common targets; each accepts the full range of TVM options so you can fine-tune +the generated code. The most frequent choices are listed below: + +| Base name | Description | +| --------- | ----------- | +| `auto` | Detects CUDA โ†’ HIP โ†’ Metal in that order. Useful when running the same script across machines. | +| `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. | +| `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. | +| `metal` | Apple Silicon GPUs (arm64 Macs). | +| `llvm` | CPU execution; accepts the standard TVM LLVM switches. | +| `webgpu` | Browser / WebGPU runtimes. | +| `c` | Emit plain C source for inspection or custom toolchains. | + +To add options, append them after the base name, separated by spaces. For example: + +```python +target = "cuda -arch=sm_90" +kernel = tilelang.compile(func, target=target, execution_backend="cython") +# or +@tilelang.jit(target=target) +def compiled_kernel(*args): + return func(*args) +``` + +The same convention works for HIP or LLVM (e.g. `hip -mcpu=gfx940`, `llvm -mtriple=x86_64-linux-gnu`). + +### Advanced: Specify Exact Hardware + +When you already know the precise GPU model, you can encode it in the target stringโ€”either via `-arch=sm_XX` or by +using one of TVMโ€™s pre-defined target tags such as `nvidia/nvidia-h100`. Supplying this detail is optional for +TileLang in general use, but it becomes valuable when the TVM cost model is enabled (e.g. during autotuning). The +cost model uses the extra attributes to make better scheduling predictions. If you skip this step (or do not use the +cost model), generic targets like `cuda` or `auto` are perfectly fine. + +All CUDA compute capabilities recognised by TVMโ€™s target registry are listed below. Pick the one that matches your +GPU and append it to the target string or use the corresponding target tagโ€”for example `nvidia/nvidia-a100`. + +| Architecture | GPUs (examples) | +| ------------ | ---------------- | +| `sm_20` | `nvidia/tesla-c2050`, `nvidia/tesla-c2070` | +| `sm_21` | `nvidia/nvs-5400m`, `nvidia/geforce-gt-520` | +| `sm_30` | `nvidia/quadro-k5000`, `nvidia/geforce-gtx-780m` | +| `sm_35` | `nvidia/tesla-k40`, `nvidia/quadro-k6000` | +| `sm_37` | `nvidia/tesla-k80` | +| `sm_50` | `nvidia/quadro-k2200`, `nvidia/geforce-gtx-950m` | +| `sm_52` | `nvidia/tesla-m40`, `nvidia/geforce-gtx-980` | +| `sm_53` | `nvidia/jetson-tx1`, `nvidia/jetson-nano` | +| `sm_60` | `nvidia/tesla-p100`, `nvidia/quadro-gp100` | +| `sm_61` | `nvidia/tesla-p4`, `nvidia/quadro-p6000`, `nvidia/geforce-gtx-1080` | +| `sm_62` | `nvidia/jetson-tx2` | +| `sm_70` | `nvidia/nvidia-v100`, `nvidia/quadro-gv100` | +| `sm_72` | `nvidia/jetson-agx-xavier` | +| `sm_75` | `nvidia/nvidia-t4`, `nvidia/quadro-rtx-8000`, `nvidia/geforce-rtx-2080` | +| `sm_80` | `nvidia/nvidia-a100`, `nvidia/nvidia-a30` | +| `sm_86` | `nvidia/nvidia-a40`, `nvidia/nvidia-a10`, `nvidia/geforce-rtx-3090` | +| `sm_87` | `nvidia/jetson-agx-orin-32gb`, `nvidia/jetson-agx-orin-64gb` | +| `sm_89` | `nvidia/geforce-rtx-4090` | +| `sm_90a` | `nvidia/nvidia-h100` (DPX profile) | +| `sm_100a` | `nvidia/nvidia-b100` | + +Refer to NVIDIAโ€™s [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) page or the TVM source +(`3rdparty/tvm/src/target/tag.cc`) for the latest mapping between devices and compute capabilities. + +## Creating targets programmatically + +If you prefer working with TVMโ€™s `Target` objects, TileLang exposes the helper +`tilelang.utils.target.determine_target` (returns a canonical target string by default, or the `Target` +object when `return_object=True`): + +```python +from tilelang.utils.target import determine_target + +tvm_target = determine_target("cuda -arch=sm_80", return_object=True) +kernel = tilelang.compile(func, target=tvm_target) +``` + +You can also build targets directly through TVM: + +```python +from tvm.target import Target + +target = Target("cuda", host="llvm") +target = target.with_host(Target("llvm -mcpu=skylake")) +``` + +TileLang accepts either `str` or `Target` inputs; internally they are normalised and cached using the canonical +string representation. **In user code we strongly recommend passing target strings rather than +`tvm.target.Target` instancesโ€”strings keep cache keys compact and deterministic across runs, whereas constructing +fresh `Target` objects may lead to slightly higher hashing overhead or inconsistent identity semantics.** + +## Discovering supported targets in code + +Looking for a quick reminder of the built-in base names and their descriptions? Use: + +```python +from tilelang.utils.target import describe_supported_targets + +for name, doc in describe_supported_targets().items(): + print(f"{name:>6}: {doc}") +``` + +This helper mirrors the table above and is safe to call at runtime (for example when validating CLI arguments). + +## Troubleshooting tips + +- If you see `Target cuda -arch=sm_80 is not supported`, double-check the spellings and that the option is valid for + TVM. Any invalid switch will surface as a target-construction error. +- Runtime errors such as โ€œno kernel image is availableโ€ usually mean the `-arch` flag does not match the GPU you are + running on. Try dropping the flag or switching to the correct compute capability. +- When targeting multiple environments, use `auto` for convenience and override with an explicit string only when + you need architecture-specific tuning. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000000000000000000000000000000000000..55804259a46a857b3589919e0536e565950ffa2d --- /dev/null +++ b/docs/index.md @@ -0,0 +1,74 @@ +# ๐Ÿ‘‹ Welcome to Tile Language + +[GitHub](https://github.com/tile-ai/tilelang) + +Tile Language (tile-lang) is a concise domain-specific language designed to streamline +the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). +By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, +tile-lang allows developers to focus on productivity without sacrificing the +low-level optimizations necessary for state-of-the-art performance. + +:::{toctree} +:maxdepth: 2 +:caption: GET STARTED + +get_started/Installation +get_started/overview +get_started/targets +::: + + +:::{toctree} +:maxdepth: 1 +:caption: TUTORIALS + +tutorials/debug_tools_for_tilelang +tutorials/auto_tuning +tutorials/logging +::: + +:::{toctree} +:maxdepth: 1 +:caption: PROGRAMMING GUIDES + +programming_guides/overview +programming_guides/language_basics +programming_guides/instructions +programming_guides/control_flow +programming_guides/autotuning +programming_guides/type_system +::: + +:::{toctree} +:maxdepth: 1 +:caption: DEEP LEARNING OPERATORS + +deeplearning_operators/elementwise +deeplearning_operators/gemv +deeplearning_operators/matmul +deeplearning_operators/matmul_sparse +deeplearning_operators/deepseek_mla +::: + +:::{toctree} +:maxdepth: 1 +:caption: COMPILER INTERNALS + +compiler_internals/letstmt_inline +compiler_internals/inject_fence_proxy +compiler_internals/tensor_checks +::: + +:::{toctree} +:maxdepth: 1 +:caption: API Reference + +autoapi/tilelang/index +::: + +:::{toctree} +:maxdepth: 1 +:caption: Privacy + +privacy +::: diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..51d365274a7f0d8fe96b9a730935939f96944c4d --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/privacy.md b/docs/privacy.md new file mode 100644 index 0000000000000000000000000000000000000000..3fb712bc28772d60b1d5eaf2597308716958b418 --- /dev/null +++ b/docs/privacy.md @@ -0,0 +1,3 @@ +# Privacy + +All data stays in users' device and is not collected by the app. diff --git a/docs/programming_guides/autotuning.md b/docs/programming_guides/autotuning.md new file mode 100644 index 0000000000000000000000000000000000000000..66d46889fe88e4b1de874d552e2f1c922c534660 --- /dev/null +++ b/docs/programming_guides/autotuning.md @@ -0,0 +1,308 @@ +# Autotuning + +TileLang includes a builtโ€‘in autotuner that searches configuration spaces +for the best performing kernel, compiles candidates in parallel, validates +correctness, benchmarks them, and caches the best result for reuse. + +This guide covers two workflows: +- Decoratorโ€‘based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit` +- Programmatic: `AutoTuner.from_kernel(...).set_*().run()` + +It also explains input tensor supply, validation, caching, and environment +variables that affect parallelism and cache behavior. + +## 1) Decoratorโ€‘based Autotune + +Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as +function arguments with defaults. The autotuner overrides these parameters with +values from your config space. + +```python +import tilelang +import tilelang.language as T + +def matmul_configs(M, N, K): + # Example space โ€” tailor to your target + tiles = [64, 128] + stages = [2, 3] + threads = [128, 256] + return [ + dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH) + for BM in tiles + for BN in tiles + for BK in [32, 64] + for S in stages + for TH in threads + ] + +@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60) +@tilelang.jit(out_idx=[-1]) +def matmul(M: int, N: int, K: int, + block_M: int = 128, block_N: int = 128, block_K: int = 32, + threads: int = 128, num_stages: int = 3, + dtype: str = 'float16', accum_dtype: str = 'float32'): + + @T.prim_func + def kernel(A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_s = T.alloc_shared((block_M, block_K), dtype) + B_s = T.alloc_shared((block_K, block_N), dtype) + C_f = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_s) + T.copy(B[ko * block_K, bx * block_N], B_s) + T.gemm(A_s, B_s, C_f) + + T.copy(C_f, C[by * block_M, bx * block_N]) + + return kernel + +# Usage +# Provide inputs via context (recommended for reproducibility across configs) +import torch +M = N = K = 1024 +A = torch.randn(M, K, device='cuda', dtype=torch.float16) +B = torch.randn(K, N, device='cuda', dtype=torch.float16) +C = torch.empty(M, N, device='cuda', dtype=torch.float16) + +from tilelang.autotuner import set_autotune_inputs +with set_autotune_inputs(A, B, C): + tuned_kernel = matmul(M, N, K) # compiles, tunes, returns best kernel + tuned_kernel(A, B, C) # run best kernel +``` + +Notes +- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each + dictโ€™s keys must match the tunable function arguments (e.g., `block_M`). +- The decorator returns a callable that runs autotune once per argument tuple + and caches the resulting best kernel inโ€‘process. +- For explicit input control during tuning, wrap the call with + `set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used. + +## 2) Programmatic Autotune + +Use the `AutoTuner` class to manage configs and arguments more explicitly. + +```python +from tilelang.autotuner import AutoTuner + +kernel_factory = matmul # the function above (already @tilelang.jit) +tuner = AutoTuner.from_kernel(kernel_factory(M, N, K), configs=matmul_configs(M, N, K)) + +tuner.set_profile_args( + warmup=25, rep=100, timeout=60, + supply_type=tilelang.TensorSupplyType.Auto, # or provide supply_prog/ref_prog + ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2), +) + +tuner.set_compile_args( + target='auto', # or 'cuda'/'hip'/'metal' + execution_backend='auto', # resolves per-target + out_idx=[-1], # which outputs to return if multiple + pass_configs={ # optional TVM passes/flags + # tilelang.PassConfigKey.EXAMPLE_KEY: value, + }, +) + +artifact = tuner.run() # compiles + runs + validates all configs +best_kernel = artifact.kernel # JITKernel +best_latency = artifact.latency +best_config = artifact.config + +# Reuse best kernel +best_kernel(A, B, C) +``` + +### Example Gallery (in repo) +- examples/gdn/example_chunk_delta_h.py:101 โ€” uses `@autotune` to sweep configs +- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 โ€” uses `@tilelang.autotune` +- examples/quickstart.py:84 โ€” profiles a tuned kernel with `get_profiler` +- examples/hadamard_transform/example_hadamard.py:152 โ€” profiler with custom warmup +- examples/dynamic_shape/example_dynamic.py:94 โ€” profiler for dynamic shapes +- examples/gemm/example_gemm_persistent.py:135 โ€” compare persistent vs nonโ€‘persistent + +Click any path to open the code and compare patterns. + +## Input Tensor Supply + +The tuner needs inputs to compile and benchmark kernels. Provide them in one of +three ways (priority order): + +1) Context manager (fixed inputs across configs) +```python +with set_autotune_inputs(A, B, C): + tuned = matmul(M, N, K) +``` + +2) Custom supplier program +```python +def supply_prog(signature): + # signature holds KernelParam objects describing shapes/dtypes + # Return a list of torch tensors matching the kernelโ€™s arguments + return [A, B, C] + +tuner.set_profile_args(supply_prog=supply_prog) +``` + +3) Builtโ€‘in generators via `supply_type` +- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges) +- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One` + +Important +- Builtโ€‘in generators require static shapes; if your PrimFunc uses symbolic + dimensions (T.dyn), supply concrete inputs via (1) or (2). +- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support. + +## Correctness Checking and Tolerances + +Use one of the following validation methods: +- `ref_prog`: Provide a reference program that receives the same inputs and + checks results. You can return a boolean or raise on mismatch. +- `manual_check_prog`: A callable that inspects outputs and raises on mismatch. +- `skip_check=True`: Skip correctness checks (faster, use with caution). + +Control numeric drift via: +- `rtol` and `atol` (defaults 1eโ€‘2) +- `max_mismatched_ratio` (default 1%) + +## Configuration Spaces and Best Practices + +What to tune +- Tile sizes: `block_M`, `block_N`, `block_K` +- Software pipelining: `num_stages` +- Threads per block: `threads` (or (x, y) tuple) +- Optional: dtype variants, epilogues, small scheduling knobs + +Tips +- Start from a working baseline. Tune a small, meaningful space first. +- Respect hardware limits (shared memory bytes, registers per thread/block, + max threads per block). Eliminate impossible configs upโ€‘front. +- Keep block sizes multiples of vector widths and warp sizes when relevant. +- Use `set_autotune_inputs` to ensure each config is measured on identical data. +- Record your best configs and bake them as defaults when stable. + +## Parallel Compilation/Benchmarking and Timeouts + +The tuner compiles configurations in parallel using a thread pool and benchmarks +them with a perโ€‘config timeout. On CUDA, each worker sets the current device to +avoid context issues. + +Notes +- `timeout` uses POSIX signals; on nonโ€‘Unix systems, it may not take effect. +- Logs are written to `autotuner.log` in the working directory. + +## Caching + +The autotuner caches best artifacts both inโ€‘memory (per process) and on disk under +`$TILELANG_CACHE_DIR/autotuner`. The cache key includes: +- TileLang version, function source, closure freeโ€‘vars +- Config list, compile args, profile args + +Disk cache contents (per key) +- Best config and latency: `best_config.json`, `latency.json` +- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend) +- Function and params: `function.pkl`, `params.pkl` + +Control via env vars (tilelang.env) +- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`) +- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`) +- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1` +- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1` + +CPU worker control +- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9) +- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto) +- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited) + +Backend notes +- NVRTC backend persists `.cubin` and a Python launcher. +- Torch/DLPack backend may not save artifacts to disk; in this case, only + inโ€‘memory caching applies and a warning is logged. + +## Alternative: Manual Sweeps with par_compile + +If you prefer manual control, use `JITImpl.par_compile` to compile a batch of +configs and drive your own benchmarking: + +```python +@tilelang.jit +def factory(M, N, K, block_M=128, block_N=128, block_K=32): + @T.prim_func + def k(A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16')): + ... + return k + +impl = factory # JITImpl +cfgs = [ + dict(block_M=64, block_N=128, block_K=32), + dict(block_M=128, block_N=128, block_K=64), +] +kernels = impl.par_compile(cfgs, num_workers=4) +# Now benchmark kernels[i](A, B, C) yourself +``` + +## Recording and Reusing Best Configs + +The programmatic path returns an `AutotuneResult` that can be saved and later +reloaded. This is useful for CI, multiโ€‘host workflows, or shipping tuned configs. + +```python +artifact = tuner.run() # AutotuneResult + +# Save to disk +from pathlib import Path +save_dir = Path('out/best/matmul_1024') +artifact.save_to_disk(save_dir, verbose=True) + +# Reload later +from tilelang.autotuner.param import AutotuneResult, CompileArgs +restored = AutotuneResult.load_from_disk(save_dir, CompileArgs()) +best = restored.kernel +best(A, B, C) +``` + +Notes +- DLPack/Torch execution backend may not persist compiled binaries; in that + case, reโ€‘compilation is needed on load or use a different backend. +- The directory contains humanโ€‘readable JSONs (best config/latency) and sources. + +## Advanced: Config Space Callables + +Derive config spaces from problem sizes to keep searches targeted and legal: + +```python +def matmul_configs(M, N, K): + large = min(M, N, K) >= 1024 + tiles = [128] if large else [64, 128] + for BM in tiles: + for BN in tiles: + for BK in [32, 64]: + for S in [2, 3]: + for TH in [128, 256]: + yield dict(block_M=BM, block_N=BN, block_K=BK, + num_stages=S, threads=TH) +``` + +## Device and Backend Selection + +Tune compileโ€‘time options explicitly: +- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target) +- `execution_backend='auto'|'tvm_ffi'|'ctypes'|'cython'|'nvrtc'|'torch'` +- `pass_configs={...}` to toggle TileLang/TVM passes for experiments + +On CUDA with multiple GPUs, the tuner sets the current device per worker thread +to avoid context mixups. + +## Troubleshooting +- โ€œNo configurations to tuneโ€: Ensure `configs` is a nonโ€‘empty list or callable. +- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that + your reference check isnโ€™t the bottleneck. +- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom + `supply_prog`. +- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend. diff --git a/docs/programming_guides/control_flow.md b/docs/programming_guides/control_flow.md new file mode 100644 index 0000000000000000000000000000000000000000..158c51166e501a6628618e028ed0bbd904f7a47d --- /dev/null +++ b/docs/programming_guides/control_flow.md @@ -0,0 +1,145 @@ +# Control Flow + +This guide covers the controlโ€‘flow primitives in TileLang and how they lower to +efficient GPU code. You will use these to structure loops, handle boundaries, +and express pipelined compute. + +## Overview +- Conditionals: `if` / `elif` / `else`, ternary (`x if c else y`) +- Loops: `T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined` +- While loops: `while` with a TIR condition +- Flow control: Python `break` / `continue` +- Safety: automatic OOB guards via the LegalizeSafeMemoryAccess pass + +The examples assume `import tilelang.language as T`. + +## Conditionals + +Standard Python `if`/`elif`/`else` is supported inside `@T.prim_func` kernels. +Conditions should be TIR expressions (e.g., `i < N`). Python plain booleans are +treated as compileโ€‘time constants and will be folded. + +```python +for i in T.serial(N): + if i < N: # TIR condition + C[i] = A[i] + B[i] + else: + pass + +# Ternary +x = (A[i] if i < N else 0) +``` + +Shortโ€‘circuit boolean ops are supported. For multiโ€‘dimensional bounds, use +`T.any_of` / `T.all_of` for clarity: + +```python +if T.all_of(i < M, j < N): + C[i, j] = A[i, j] + B[i, j] +``` + +Boundary handling note +- The LegalizeSafeMemoryAccess pass automatically inserts guards when an access + may be outโ€‘ofโ€‘bounds, and elides them when proven safe. You can often omit + explicit `if` checks for simple edge handling, but keep them when you need + custom logic or clarity. + +## Loops + +### Serial + +`T.serial` creates a plain forโ€‘loop. Common forms: + +```python +for i in T.serial(N): + ... # 0..N-1 + +for i in T.serial(0, N, 2): + ... # 0, 2, 4, ... +``` + +### Unroll + +`T.unroll` requests loop unrolling for small trip counts. + +```python +for k in T.unroll(K_TILE): + acc += a[k] * b[k] +``` + +Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are +available for expert tuning. + +### Parallel (elementwise) + +`T.Parallel(ext0, ext1, ...)` builds nested loops that map well to elementwise +operations. The body receives all indices in one `for` header: + +```python +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] +``` + +Optional: `coalesced_width=` can hint memory coalescing for the innermost loop. + +### Pipelined (software pipelining) + +`T.Pipelined(iters, num_stages=...)` overlaps producer/consumer stages (e.g., +Globalโ†’Shared copies with compute). This is the backbone of GEMM/attention +pipelines. + +```python +for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile + T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile + T.gemm(A_s, B_s, C_f) # stage: compute +``` + +### Persistent (advanced) + +`T.Persistent(domain, wave_size, index, group_size=...)` exposes persistent +threadโ€‘block style looping. It is an advanced construct that TileLang lowers in +later passes and is typically used by specialized templates. + +## While Loops + +`while` is supported when the condition is a TIR expression. Avoid infinite +loops; TileLang will error if it detects a constantโ€‘true condition. + +```python +i = 0 +while i < N: + ... + if done: + break + i += 1 +``` + +## Break and Continue + +Use Python `break`/`continue` to exit or skip within `T.serial`/`T.unroll`/ +`T.Parallel`/`while` loops. Keep the body clean after a `break`/`continue` for +readability; the compiler will ignore the dead path. + +## Putting It Together: Residual Tile Handling + +Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess, +the explicit guard can be omitted when you donโ€™t need a custom edge path. + +```python +for i, j in T.Parallel(M, N): + gi = by * BM + i + gj = bx * BN + j + if T.all_of(gi < M, gj < N): # optional in many cases + C[gi, gj] = A[gi, gj] + B[gi, gj] +``` + +## Debugging Conditions + +Use `T.print` to inspect values under predicates. For buffers, TileLang prints +from a single thread to avoid duplicate outputs. + +```python +if i == 0: + T.print(C, msg='C tile:') +``` diff --git a/docs/programming_guides/instructions.md b/docs/programming_guides/instructions.md new file mode 100644 index 0000000000000000000000000000000000000000..84bd9217990003044a97e6c59007486cae64566f --- /dev/null +++ b/docs/programming_guides/instructions.md @@ -0,0 +1,182 @@ +# Instructions + +This page summarizes the core TileLang โ€œinstructionsโ€ available at the DSL +level, how they map to hardware concepts, and how to use them correctly. + +## Quick Categories +- Data movement: `T.copy`, `T.c2d_im2col`, staging Global โ†” Shared โ†” Fragment +- Compute primitives: `T.gemm`/`T.gemm_sp`, elementwise math (`T.exp`, `T.max`), + reductions (`T.reduce_sum`, `T.cumsum`, warp reducers) +- Control helpers: `T.clear`/`T.fill`, `T.reshape`/`T.view` +- Diagnostics: `T.print`, `T.device_assert` +- Advanced: atomics, memory barriers, warpโ€‘group ops + +## Data Movement + +Use `T.copy(src, dst, coalesced_width=None, disable_tma=False, eviction_policy=None)` +to move tiles between memory scopes. It accepts `tir.Buffer`, `BufferLoad`, or +`BufferRegion`; extents are inferred or broadcast when possible. + +```python +# Global โ†’ Shared tiles (extents inferred from dst) +T.copy(A[by * BM, ko * BK], A_s) +T.copy(B[ko * BK, bx * BN], B_s) + +# Fragment/Register โ†’ Global (store result) +T.copy(C_f, C[by * BM, bx * BN]) +``` + +Semantics +- Extents are deduced from arguments; missing sides broadcast to the otherโ€™s rank. +- Access patterns are legalized and coalesced during lowering. Explicit + vectorization is not required in HL mode. +- Safety: the LegalizeSafeMemoryAccess pass inserts boundary guards when an + access may be outโ€‘ofโ€‘bounds and drops them when proven safe. + +Other helpers +- `T.c2d_im2col(img, col, ...)`: convenience for convโ€‘style transforms. + +## Compute Primitives + +GEMM and sparse GEMM +- `T.gemm(A_shared, B_shared, C_fragment)`: computes a tile GEMM using shared + inputs and a fragment accumulator; lowered to targetโ€‘specific tensor cores. +- `T.gemm_sp(...)`: 2:4 sparse tensor core variant (see examples and README). + +Reductions and scans +- `T.reduce_sum`, `T.reduce_max`, `T.reduce_min`, `T.cumsum`, plus warp + reducers (`T.warp_reduce_sum`, etc.). +- Allocate and initialize accumulators via `T.alloc_fragment` + `T.clear` or + `T.fill`. + +Elementwise math +- Most math ops mirror TVM TIR: `T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, + `T.sigmoid`, etc. Compose freely inside loops. + +Reshape/view (no copy) +- `T.reshape(buf, new_shape)` and `T.view(buf, shape=None, dtype=None)` create + new views that share storage, with shape/dtype checks enforced. + +## Synchronization (HL usage) + +In HL pipelines, you usually donโ€™t need to write explicit barriers. Passes such +as PipelinePlanning/InjectSoftwarePipeline/InjectTmaBarrier orchestrate +producer/consumer ordering and thread synchronization behind the scenes. + +If you need debugging or explicit checks: +- `T.device_assert(cond, msg='')` emits deviceโ€‘side asserts on CUDA targets. +- `T.print(obj, msg='...')` prints scalars or buffers safely from one thread. + +## Putting It Together: GEMM Tile + +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) # Global โ†’ Shared + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # compute into fragment + + T.copy(C_f, C[by * BM, bx * BN]) # store back +``` + +## Instruction Reference (Concise) + +Below is a concise list of TileLang instructions grouped by category. For full +signatures, behaviors, constraints, and examples, refer to API Reference +(`autoapi/tilelang/index`). + +Data movement +- `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment. +- `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv. + +Memory allocation and descriptors +- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer. +- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment. +- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem). +- `T.alloc_barrier(arrive_count)`: Shared barrier buffer. +- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+). +- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf. +- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator. + - `T.alloc_wgmma_desc(dtype='uint64')` + - `T.alloc_tcgen05_smem_desc(dtype='uint64')` + - `T.alloc_tcgen05_instr_desc(dtype='uint32')` +- `T.empty(shape, dtype='float32')`: Declare function output tensors. + +Compute primitives +- `T.gemm(A_s, B_s, C_f)`: Tile GEMM into fragment accumulator. +- `T.gemm_sp(...)`: Sparse (2:4) tensor core GEMM. +- Reductions: `T.reduce_sum/max/min/abssum/absmax`, bitwise `and/or/xor`. +- Scans: `T.cumsum`, finalize: `T.finalize_reducer`. +- Warp reducers: `T.warp_reduce_sum/max/min/bitand/bitor`. +- Elementwise math: TIR ops (`T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, ...). +- Fast math: `T.__log/__log2/__log10/__exp/__exp2/__exp10/__sin/__cos/__tan`. +- IEEE math: `T.ieee_add/sub/mul/fmaf` (configurable rounding). +- Helpers: `T.clear(buf)`, `T.fill(buf, value)`. +- Views: `T.reshape(buf, shape)`, `T.view(buf, shape=None, dtype=None)`. + +Diagnostics +- `T.print(obj, msg='')`: Print scalar/buffer from one thread. +- `T.device_assert(cond, msg='')`: Device-side assert (CUDA). + +Logical helpers +- `T.any_of(a, b, ...)`, `T.all_of(a, b, ...)`: Multi-term predicates. + +Annotation helpers +- `T.use_swizzle(panel_size=..., enable=True)`: Rasterization hint. +- `T.annotate_layout({...})`: Attach explicit layouts to buffers. +- `T.annotate_safe_value(var, ...)`: Safety/const hints. +- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint. + +Atomics +- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`. +- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`. +- `T.atomic_max(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_min(dst, value, memory_order=None, return_prev=False)`. +- `T.atomic_load(dst)`, `T.atomic_store(dst, value)`. + +Custom intrinsics +- `T.dp4a(A, B, C)`: 4โ€‘element dotโ€‘product accumulate. +- `T.clamp(x, lo, hi)`: Clamp to [lo, hi]. +- `T.loop_break()`: Break from current loop via intrinsic. + +Barriers, TMA, warpโ€‘group +- Barriers: `T.create_list_of_mbarrier(...)`, `T.get_mbarrier(i)`. +- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`. +- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`. +- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`, + `T.tma_store_arrive(...)`, `T.tma_store_wait(...)`. +- Proxy/fences: `T.fence_proxy_async(...)`, `T.warpgroup_fence_operand(...)`. +- Warpโ€‘group: `T.warpgroup_arrive()`, `T.warpgroup_commit_batch()`, + `T.warpgroup_wait(num_mma)`, `T.wait_wgmma(id)`. + +Lane/warp index +- `T.get_lane_idx(warp_size=None)`: Lane id in warp. +- `T.get_warp_idx_sync(warp_size=None)`: Canonical warp id (sync). +- `T.get_warp_idx(warp_size=None)`: Canonical warp id (no sync). +- `T.get_warp_group_idx(warp_size=None, warps_per_group=None)`: Group id. + +Register control +- `T.set_max_nreg(reg_count, is_inc)`, `T.inc_max_nreg(n)`, `T.dec_max_nreg(n)`. +- `T.annotate_producer_reg_dealloc(n=24)`, `T.annotate_consumer_reg_alloc(n=240)`. +- `T.no_set_max_nreg()`, `T.disable_warp_group_reg_alloc()`. + + + +## Notes on Dtypes + +Dtypes accept three equivalent forms: +- String: `'float32'` +- TileLang dtype: `T.float32` +- Framework dtype: `torch.float32` +All are normalized internally. See Type System for details. diff --git a/docs/programming_guides/language_basics.md b/docs/programming_guides/language_basics.md new file mode 100644 index 0000000000000000000000000000000000000000..1152680c970460f3c91f871d2b0e82ec73034918 --- /dev/null +++ b/docs/programming_guides/language_basics.md @@ -0,0 +1,234 @@ +# Language Basics + +This page introduces the core TileLang (tileโ€‘lang) DSL that youโ€™ll use to write +highโ€‘performance kernels. It focuses on how to define a kernel, express +iteration, move data across memory scopes, and run it with JIT. + +The examples use the conventional aliases: + +```python +import tilelang +import tilelang.language as T +from tilelang import jit +``` + +## 1. Defining a Kernel with `@T.prim_func` + +TileLang kernels are TIR (TVM IR) functions produced by the `@T.prim_func` +decorator. Arguments are annotated with shapes and dtypes via `T.Tensor` or +`T.Buffer`. + +Note on dtypes +- You can pass dtypes as a string (e.g., 'float32'), a TileLang dtype (e.g., `T.float32`), + or a framework dtype (e.g., `torch.float32`). TileLang normalizes all of these. + See Type System for details. + +```python +@T.prim_func +def add_kernel( + A: T.Tensor((N,), dtype), # dtype could be 'float32' | T.float32 | torch.float32 + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), +): + ... # kernel body +``` + +- Shapes may be concrete integers or symbolic. For symbolic, you can pass + Python ints through the outer `@jit` wrapper (shown below), or annotate with + `T.dyn` when you want a named symbolic dimension. + +```python +# Named symbolic dimension (optional) +K = T.dyn['K'] +@T.prim_func +def uses_dyn(A: T.Tensor((K,), 'float32')): + ... +``` + +### Dynamic symbolic dimensions: two ways + +TileLang supports two complementary ways to introduce symbolic (dynamic) dims: + +- Type-level annotations via `T.dyn[...]` (recommended for function signatures) + - Use in `T.Tensor((T.dyn['K'], ...), dtype)` or bind once then reuse (as above). + - Inside the kernel body, prefer reading from the bufferโ€™s shape, e.g. `M = A.shape[0]`. + +- Term-level variables via `T.dynamic(name, dtype)` + - Creates a TIR `tir.Var` you can use directly in expressions/loops. + - Handy when you need to reference the dimension symbol in the body. + +```python +# 1) Annotation-only symbol; read the bound size via shape +K = T.dyn['K'] # dtype defaults to int32 +@T.prim_func +def foo(A: T.Tensor((K,), 'float32')): + N = A.shape[0] + for i in T.serial(N): + ... + +# 2) Explicit Var symbol usable in the body +K = T.dynamic('K', 'int32') # or T.dynamic('K') defaults to int32 +@T.prim_func +def bar(A: T.Tensor((K,), 'float32')): + for i in T.serial(K): + ... +``` + +Notes +- `T.symbolic(name, dtype)` is a deprecated alias of `T.dynamic`; prefer `T.dynamic`. +- Under `@jit`, concrete sizes come from the actual tensor arguments at the first call. +- Symbols in annotations do not need to be separate kernel arguments; TileLang binds them from argument shapes. + +## 2. Launching Work with `T.Kernel` + +`with T.Kernel(...)` declares a launch context and creates block/thread +bindings. For GPU backends, specify a grid and threads per block. + +```python +with T.Kernel(grid_x, grid_y, threads=128) as (bx, by): + ... # bx/by are blockIdx.x/y +``` + +You rarely need raw thread indices; most kernels use structured loops +(`T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`) inside a `T.Kernel`. + +## 3. Loops and Control Flow + +Core loop constructs map to familiar hardware patterns: + +- `T.serial(start, stop[, step])`: plain forโ€‘loop +- `T.unroll(start, stop[, step])`: unrolled loop +- `T.Parallel(ext0, ext1, ...)`: nested parallel loops (elementwiseโ€‘friendly) +- `T.Pipelined(iters, num_stages=N)`: software pipelining for producer/consumer + +```python +for i in T.serial(N): + ... + +for i, j in T.Parallel(M, N): + C[i, j] = A[i, j] + B[i, j] + +for k in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + # overlap copy/compute across stages + ... +``` + +Conditionals use standard Python `if`/`else`. Guard edges with predicates when +tile sizes do not divide problem sizes evenly. + +## 4. Memory Scopes and Allocation + +TileLang exposes key softwareโ€‘managed scopes: + +- Global: device memory (default for `T.Tensor` arguments) +- Shared: onโ€‘chip, blockโ€‘visible (`T.alloc_shared(shape, dtype)`) +- Fragment and scalars: perโ€‘thread fragments and scalar vars but in Shared View + (`T.alloc_fragment`, `T.alloc_var`) + +```python +A_shared = T.alloc_shared((BM, BK), 'float16') +B_shared = T.alloc_shared((BK, BN), 'float16') +C_local = T.alloc_fragment((BM, BN), 'float32') +T.clear(C_local) # zero accumulators +``` + +## 5. Moving Data: `T.copy` + +Use `T.copy(src, dst)` to move tiles between scopes. It accepts buffers, +buffer regions, or buffer loads; extents are inferred or can be broadcast. + +```python +# Global -> Shared (tile copy), extents inferred from dst +T.copy(A[by * BM, ko * BK], A_shared) +T.copy(B[ko * BK, bx * BN], B_shared) + +# Fragment -> Global (store back) +T.copy(C_local, C[by * BM, bx * BN]) +``` + +`T.copy` performs coalescing and scopeโ€‘specific lowering during compilation. + +## 6. A Minimal Endโ€‘toโ€‘End Example (Vector Add) + +```python +import tilelang +import tilelang.language as T +from tilelang import jit + +@jit # infers target from tensors at first call +def add(N: int, block: int = 256, dtype: str = 'float32'): + + @T.prim_func + def add_kernel( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block), threads=block) as bx: + for i in T.Parallel(block): + gi = bx * block + i + # Optional โ€” LegalizeSafeMemoryAccess inserts a guard when an access may be OOB + C[gi] = A[gi] + B[gi] + + return add_kernel + +# Host side (PyTorch shown; NumPy/DLPack also supported) +import torch +N = 1 << 20 +A = torch.randn(N, device='cuda', dtype=torch.float32) +B = torch.randn(N, device='cuda', dtype=torch.float32) +C = torch.empty(N, device='cuda', dtype=torch.float32) + +kernel = add(N) +kernel(A, B, C) # runs on GPU +torch.testing.assert_close(C, A + B) +``` + +Notes +- The `@jit` wrapper returns a callable kernel after the first compilation. +- You can pass compileโ€‘time tunables (tile sizes, dtypes) through the outer + Python function and bake them into the generated TIR. + +## 7. Tiled GEMM Skeleton + +Below is a minimal pattern for a tiled GEMM using shared memory staging and a +fragment accumulator. It mirrors the quickstart style found in the repository. + +```python +@T.prim_func +def gemm( + A: T.Tensor((M, K), 'float16'), + B: T.Tensor((K, N), 'float16'), + C: T.Tensor((M, N), 'float16'), +): + with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by): + A_s = T.alloc_shared((BM, BK), 'float16') + B_s = T.alloc_shared((BK, BN), 'float16') + C_f = T.alloc_fragment((BM, BN), 'float32') + T.clear(C_f) + + for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3): + T.copy(A[by * BM, ko * BK], A_s) + T.copy(B[ko * BK, bx * BN], B_s) + T.gemm(A_s, B_s, C_f) # lowered to tensorโ€‘core/ISA specific kernels + + T.copy(C_f, C[by * BM, bx * BN]) +``` + +## 8. Debugging and Printing + +Use `T.print` inside a kernel for quick introspection. TileLang emits printing +from a single thread for shared/fragment scopes to avoid floods. + +```python +T.print(C_f, msg='accumulator:') +T.print(A_s, msg='A tile:') +T.print(C[0], msg='C[0] = ') +``` + +## 9. Where to Go Next + +- Control flow details: see Programming Guides โ†’ Control Flow +- Memory topics: see Programming Guides โ†’ (removed cache/layout); basics are covered inline +- Autotuning tile sizes and mappings: Programming Guides โ†’ Autotuning +- Operator examples (GEMM, GEMV, attention): see Deep Learning Operators diff --git a/docs/programming_guides/overview.md b/docs/programming_guides/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..64b6d20390bd7350001a9882abda10906cea2874 --- /dev/null +++ b/docs/programming_guides/overview.md @@ -0,0 +1,27 @@ +# Programming Guides Overview + +This section provides a practical guide to writing highโ€‘performance kernels with Tile Language (tileโ€‘lang). +It mirrors the structure of a similar guide in another project and adapts it to tileโ€‘lang concepts and APIs. + +- Audience: Developers implementing custom GPU/CPU kernels with tileโ€‘lang +- Prereqs: Basic Python, NumPy/Tensor concepts, and familiarity with GPU programming notions +- Scope: Language basics, control flow, instructions, autotuning, and type system + +## What Youโ€™ll Learn +- How to structure kernels with TileLangโ€™s core DSL constructs +- How to move data across global/shared/fragment and pipeline compute +- How to apply autotuning to tile sizes and schedules +- How to specify and work with dtypes in kernels + +## Suggested Reading Order +1. Language Basics +2. Control Flow +3. Instructions +4. Autotuning +5. Type System + +## Related Docs +- Tutorials: see existing guides in `tutorials/` +- Operators: examples in `deeplearning_operators/` + +> NOTE: This is a draft scaffold. Fill in code snippets and benchmarks as APIs evolve. diff --git a/docs/programming_guides/type_system.md b/docs/programming_guides/type_system.md new file mode 100644 index 0000000000000000000000000000000000000000..32b9274d7c48bcd3b7ebfcdc4b35a56862b86f10 --- /dev/null +++ b/docs/programming_guides/type_system.md @@ -0,0 +1,42 @@ +# Type System + +This page lists the data types supported by TileLang and how to specify them in +kernels. For full details and the authoritative list, see the API Reference +(`autoapi/tilelang/index`) and `tilelang.language.v2.dtypes`. + +How to specify dtypes +- Use any of the following forms; TileLang normalizes them internally: + - String: `'float32'`, `'int8'`, `'bfloat16'`, ... + - TileLang dtype object: `T.float32`, `T.int8`, `T.bfloat16`, ... + - Framework dtype: `torch.float32`, `torch.int8`, `torch.bfloat16`, ... + +Common scalar types +- Boolean: `bool` +- Signed integers: `int8`, `int16`, `int32`, `int64` +- Unsigned integers: `uint8`, `uint16`, `uint32`, `uint64` +- Floatingโ€‘point: `float16` (half), `bfloat16`, `float32`, `float64` + +Float8 and lowโ€‘precision families +- Float8: `float8_e3m4`, `float8_e4m3`, `float8_e4m3b11fnuz`, `float8_e4m3fn`, + `float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`, `float8_e8m0fnu` +- Float6: `float6_e2m3fn`, `float6_e3m2fn` +- Float4: `float4_e2m1fn` + +Vectorized element types (SIMD packs) +- For many base types, vectorโ€‘packed variants are available by lane count: + `x2`, `x4`, `x8`, `x16`, `x32`, `x64`. +- Examples: + - Integers: `int8x2`, `int8x4`, ..., `int32x2`, `int32x4`, ... + - Unsigned: `uint8x2`, `uint8x4`, ... + - Floats: `float16x2`, `float16x4`, `float32x2`, `float32x4`, ... + - Float8/6/4 families also provide `x2/x4/x8/x16/x32/x64` where applicable, + e.g., `float8_e4m3x2`, `float8_e4m3x4`, `float6_e2m3fnx8`, `float4_e2m1fnx16`. + +Notes +- Availability of certain lowโ€‘precision formats (float8/6/4) depends on target + architecture and backend support. +- Choose accumulation dtypes explicitly for mixedโ€‘precision compute (e.g., + GEMM with `float16` inputs and `float32` accumulators). +- The complete, upโ€‘toโ€‘date list is exposed in + `tilelang.language.v2.dtypes` and rendered in the API Reference. + diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..63b64db21c52b066942f05082c71ae78c10fd2b3 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,13 @@ +fastapi +pydantic +sphinx +sphinx-reredirects +sphinx-tabs +sphinx-toolbox +sphinxcontrib-napoleon +sphinxcontrib_httpdomain +furo +uvicorn +myst-parser +sphinx-autoapi == 3.6.0 +astroid < 4 diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt new file mode 100644 index 0000000000000000000000000000000000000000..e859d0e7b109baafea24b390c4c0331393950123 --- /dev/null +++ b/docs/spelling_wordlist.txt @@ -0,0 +1,8 @@ +cancelled +hsa +ist +LOD +nd +NotIn +offen +te diff --git a/docs/tutorials/auto_tuning.md b/docs/tutorials/auto_tuning.md new file mode 100644 index 0000000000000000000000000000000000000000..3f3cad832232898017a439deff4fb84cc499a412 --- /dev/null +++ b/docs/tutorials/auto_tuning.md @@ -0,0 +1,148 @@ +Auto-Tuning Techniques for Performance Optimization +=================================================== +
+Author: yyttt6 +
+ +## Overview + +Auto-tuning a Tile Language program involves three main steps: + +1. Implement the target program using Tile Language with reserved optimization parameters +2. โ€‹Provide candidate configurations through manual search or [auto-generation using Carver](#using-carver-to-auto-generate-candidate-configurations) +3. Parallel compile and benchmark candidate configurations to identify the best performance + +## Matrix Multiplication Example + +The following example demonstrates auto-tuning matrix multiplication. Code has been simplified for readability - see `examples/gemm/example_gemm.py` for complete implementation. + +### Step 1: Implement with Reserved Parameters +Users can implement matrix multiplication in Tile Language while reserving parameters for optimization: +```python +# Reserved parameters for optimization +def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, +): + dtype = "float16" + accum_dtype = "float" + + # Matrix multiplication implementation + @T.prim_func + def main( + A: T.Buffer((M, K), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((M, N), dtype), + ): + # ...existing code... + + return main +``` +### Step 2: Generate Candidate Configurations +Manually define configurations or use combinatorial generation: +```python +configs = [ + { + "block_M": 128, + "block_N": 128, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "enable_rasteration": True + }, + { + "block_M": 32, + "block_N": 32, + "block_K": 32, + "num_stages": 0, + "thread_num": 32, + "enable_rasteration": False + }, + # ...additional configurations... +] +``` +It can also be given by combinatorial traversal of different parameters +```python +import itertools + +block_M = [64, 128, 256] +block_N = [64, 128, 256] +block_K = [32, 64] +num_stages = [0, 1, 2, 3] +thread_num = [128, 256] +enable_rasterization = [True, False] +_configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + )) + +configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5] + } for c in _configs +] +``` +### Step 3: Compile and Benchmark +Configure JIT compilation and benchmarking settings: +```python +autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( + out_idx=[-1], + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=False, + target="auto", + ) +result = autotuner.run(warmup=3, rep=20) +out_c = result.kernel(a, b) +``` +The result object contains optimized kernel implementation which can be used by users directly + +## Using Carver to Auto-Generate Candidate Configurations + +Carver is a lightweight framework for generating and ranking tile configurations (also known as tiling strategies, blocking schemes, or scheduling hints) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels. + +or common operators, Carver provides pre-built templates (e.g., `MatmulTemplate`): + +```python +# Configure Matmul template +arch = CUDA("cuda") +carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float", +).with_arch(arch) + +# Generate top-k optimization hints (topk=10 recommended) +roller_hints = carve_template.recommend_hints(topk=10) + +# Configure candidate parameters +for hint in roller_hints: + + # ...existing code... + + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage + config["thread_num"] = block_rows * block_cols * 32 + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + +``` \ No newline at end of file diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md new file mode 100644 index 0000000000000000000000000000000000000000..f8dfaab826e0c2896a922a915cce44b2f9ab8153 --- /dev/null +++ b/docs/tutorials/debug_tools_for_tilelang.md @@ -0,0 +1,204 @@ +# Debugging Tile Language Programs + +
+Author: Lei Wang +
+ +## Overview + +A Tile Language program (hereafter referred to as a *program*) is transformed into a hardware-executable file through several stages: + +1. The user writes a Tile Language program. +2. The program undergoes multiple *Passes* for transformation and optimization (the *lower* stage, see `tilelang/engine/lower.py`), finally producing an intermediate representation (e.g., LLVM or C for CPU, CUDA for NVIDIA GPUs, etc.). +3. The generated code is compiled by the respective compiler (e.g., nvcc) into a hardware-executable file. + + +```{figure} ../_static/img/overview.png +:width: 300 +:alt: Overview of the compilation process +:align: center + +``` + +During this process, users may encounter roughly three categories of issues: + +* **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process). +* **Correctness issues**: The resulting executable runs, but produces incorrect results. +* **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits. + +This tutorial focuses on the first two issuesโ€”how to debug generation and correctness problems. Performance tuning often requires using vendor-provided profiling tools (e.g., **Nsight Compute**, **rocProf**, etc.) for further hardware-level analysis, which we will address in future materials. + +Below, we take matrix multiplication (GEMM) as an example to demonstrate how to write and debug a Tile Language program. + +## Matrix Multiplication Example + +In **Tile Language**, you can use the **Tile Library** to implement matrix multiplication. Here's a complete example: + +```python +import tilelang +import tilelang.language as T + +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + # ...existing code... + +# 1. Define the kernel (matmul) with the desired dimensions +func = matmul(1024, 1024, 1024, 128, 128, 32) + +# 2. Compile the kernel into a torch function +# ...existing code... +``` + +## Debugging Generation Issues + +TileLang essentially performs *progressive lowering*. For example, a `T.copy` may first be expanded into `T.Parallel` (see the pass `LowerTileOP`), which is then expanded again, eventually resulting in lower-level statements that can be translated to CUDA C code. + + +```{figure} ../_static/img/ir_transform_diagram.png +:width: 400 +:alt: IR transformation diagram +:align: center + +``` + +When the code fails to generate (for instance, a compilation error occurs), you do **not** necessarily need to jump directly into C++ passes to debug. Instead, you can first inspect the intermediate representations (IR) in Python by printing them. + +For example, consider a case where a simple `T.copy` in 1D causes the lowering process to fail. The snippet below illustrates a simplified version of the problem (based on community Issue #35): + +```python +@T.prim_func +def main(Q: T.Tensor(shape_q, dtype)): + # ...existing code... +``` + +The TileLang lower process might yield an error such as: + +```text +File "/root/TileLang/src/target/codegen_cuda.cc", line 1257 +ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed. +``` + +This indicates that somewhere during code generation, an unsupported vectorization pattern was introduced (a ramp of 8 lanes). Before diving into the underlying C++ code, it is helpful to print the IR right before code generation. For instance: + +```python +device_mod = tir.transform.Filter(is_device_call)(mod) +# ...existing code... +``` + +## Debugging Correctness Issues + +Sometimes, the kernel compiles and runs but produces incorrect results. In such cases, there are two main strategies to help debug: + +1. **Use post-processing callbacks to inspect or modify the generated CUDA code.** +2. **Use the built-in `T.print` debugging primitive to inspect values at runtime.** + +### Post-Processing Callbacks for Generated Source + +After code generation (in the codegen pass), TileLang calls a callback function (if registered) to allow post-processing of the generated source code. In `src/target/rt_mod_cuda.cc`: + +```cpp +std::string code = cg.Finish(); +if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) { + code = (*f)(code, target).operator std::string(); +} +``` + +Hence, by registering a Python function named `tilelang_callback_cuda_postproc`, you can intercept the final CUDA code string. For example: + +```python +import tilelang +import tilelang.language as T +from tilelang import tvm +from tilelang.engine.callback import register_cuda_postproc_callback + +@register_cuda_postproc_callback +def tilelang_callback_cuda_postproc(code, _): + print(code) # print the final CUDA code + code = "// modified by tilelang_callback_cuda_postproc\n" + code + return code + +kernel = tilelang.compile(matmul, target="cuda") +kernel_source = kernel.get_kernel_source() +print(kernel_source) +''' +// modified by tilelang_callback_cuda_postproc +#include "cuda_runtime.h" +... +''' +``` + +### Runtime Debug Prints with `T.print` + +TileLang provides a built-in debugging primitive called `T.print` for printing within kernels. Be mindful of concurrency and thread synchronization when using it in GPU code. Below are some examples showing how to print buffers, variables, and other data inside TileLang programs. + +1. **Printing an Entire Buffer** + +```python +def debug_print_buffer(M=16, N=16): + # ...existing code... +``` + +2. **Conditional Printing** + +```python +def debug_print_buffer_conditional(M=16, N=16): + # ...existing code... +``` + +3. **Printing Thread Indices or Scalar Values** + +```python +def debug_print_value_conditional(M=16, N=16): + # ...existing code... +``` + +4. **Printing Fragment (Register File) Contents** + +```python +def debug_print_register_files(M=16, N=16): + # ...existing code... +``` + +5. **Adding a Message Prefix** + +```python +def debug_print_msg(M=16, N=16): + # ...existing code... +``` + +The output messages will include something like: + +```text +msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0 +``` + +### Visual Layout Inference For TileLang + The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations. + +When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates: +1. **Textual output**: A human-readable description of the layout mapping +2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping + +The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configuration. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation. + +When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control output formats: +- "txt": Text output only (same as default) +- "all": Generates all formats (TXT, PDF, PNG, SVG) +- "png": Generate PNG format only +- "pdf": Generate PDF format only +- "svg": Generate SVG format only +- "txt,svg": Generate multiple formats (comma-separated) in addition to text output + +The output messages of "txt" will include something like: +``` +C_local inferenced layout: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] +``` + + +## Conclusion + +By carefully examining intermediate representations (IR) before final code generationโ€”and by leveraging runtime printing through `T.print`โ€”one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs. + +For advanced performance tuning (e.g., analyzing memory bandwidth or occupancy), more specialized profiling tools such as **Nsight Compute**, **rocProf**, or vendor-specific profilers may be required. Those aspects will be covered in future documents. diff --git a/docs/tutorials/logging.md b/docs/tutorials/logging.md new file mode 100644 index 0000000000000000000000000000000000000000..5caf432801cd7d0dfa1b5dfb2eef5c9e824e937c --- /dev/null +++ b/docs/tutorials/logging.md @@ -0,0 +1,118 @@ +Logging in Tilelang/TVM +=================================================== +
+Author: SiriusNEO +
+ +## TVM Logging Overview + +Tilelang currently utilizes the logging system from TVM. The implementation can be found in: + +- [include/tvm/runtime/logging.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h): Macro definitions +- [src/runtime/logging.cc](https://github.com/apache/tvm/blob/main/src/runtime/logging.cc): Logging logic implementation + +The design style is inspired by [Google's glog](https://google.github.io/glog/stable/). + +## Logging Categories + +There are three primary macro types: + +```c++ +LOG(INFO) << "aaa"; +DLOG(INFO) << "aaa"; +VLOG(1) << "aaa"; +``` + +- **LOG**: Standard logging preserved in code for displaying necessary information at different levels during runtime. Most Tilelang C++ error reporting is implemented via `LOG(FATAL) << "error msg"`. +- **DLOG**: Debug logging for developer debugging output. DLOG is controlled at build time by the TVM_LOG_DEBUG environment variable and is **eliminated in Release builds through dead code elimination**. + - The key difference between LOG(DEBUG) and DLOG is this build-time elimination. We recommend using DLOG over LOG(DEBUG), as the latter has overlapping functionality and gets compiled into the release runtime. +- **VLOG**: [Verbose logging](https://google.github.io/glog/stable/logging/#verbose-logging), primarily for debugging. Its main feature is customizable verbosity levels. For example, VLOG(n) where n can be 1, 2, 3, 4, 5, or 6, enabling complex tracing requirements. In contrast, LOG and DLOG typically use predefined verbose levels like INFO and DEBUG. + - In practical Tilelang development, VLOG is used less frequently. + - TVM's VLOG is implemented using DLOG, thus inheriting DLOG's characteristics. + +Additional useful macros include various **CHECK** variants: + +```c++ +CHECK(cond) << "error msg"; +DCHECK(cond) << "error msg"; +ICHECK(cond) << "error msg"; +``` + +The implementation routes errors to LogFatal: + +```c++ +#define CHECK(x) \ + if (!(x)) \ + ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ + << "Check failed: (" #x << ") is false: " +``` +- **DCHECK**: Debug mode CHECK, only compiled in debug builds +- **ICHECK**: Internal Check that should exist in Release builds. When ICHECK fails, the entire system should report an error. + +## Logging Verbose Levels + +TVM defines 5 levels for LOG and DLOG (adding DEBUG compared to glog): + +```c++ +#define TVM_LOG_LEVEL_DEBUG 0 +#define TVM_LOG_LEVEL_INFO 1 +#define TVM_LOG_LEVEL_WARNING 2 +#define TVM_LOG_LEVEL_ERROR 3 +#define TVM_LOG_LEVEL_FATAL 4 +``` + +## Using Logging in TileLang Development + +### Guidelines + +For temporary debugging output in your code, there are no restrictions (you can even use std::cout). Just remember to remove it before submitting a PR. + +For meaningful logging that should remain in the Tilelang codebase: + +- Critical correctness checks: Use ICHECK with sufficient error messages to facilitate debugging when issues arise. +- Complex Pass debugging: For passes requiring intermediate output that may need future review (e.g., LayoutInference), use DLOG. +- General INFO/WARNING messages: Use standard LOG. + +### Enabling Log Output in Tilelang + +To specify current log level at runtime, we need to set the environment variable `TVM_LOG_LEVEL`. An example usage is: + +```c++ +TVM_LOG_DEBUG=1 python3 code.py +``` + +which enables all DEBUG/INFO (level <= 1) logs for all files. + +#### Detailed Rules for TVM_LOG_DEBUG Specification + +The parsing logic is in `logging.cc`. Reference: [HyperAI Zhihu Article](https://zhuanlan.zhihu.com/p/1933106843468665163). + +Launch Python with `TVM_LOG_DEBUG=`, where `` is a comma-separated list of level assignments in the form `=`. Important notes: + +- The special filename DEFAULT sets the LOG level for all files. +- `` can be set to -1 to disable LOG for that file. +- `` is the C++ source filename (e.g., .cc, not .h) relative to the `src/` directory in the TVM repository. The `src/` prefix is optional when specifying file paths. + +### Enabling Debug Mode + +To enable DLOG/DCHECK, developers need to first build Tilelang in Debug mode: + +```bash +cmake .. -DCMAKE_BUILD_TYPE=Debug -DUSE_CUDA=ON +``` + +Tilelang's CMake logic automatically adds the `TVM_LOG_DEBUG` macro, compiling all DLOG statements: + +```cmake +target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") +``` + +Then you also need to specify the runtime environment variables. For example, to use `DLOG(INFO) << "xxx"` for debugging, run your code with INFO level (1): `TVM_LOG_DEBUG=1`. + +:::{note} + **Important**: There are two TVM_LOG_DEBUG variables. (1) Compile-time macro: Determines whether debug content (like DLOG) is compiled into the .so file. Referenced in C++ source via #ifdef TVM_LOG_DEBUG. This is automatically enabled when using Debug build mode in CMake. (2) Runtime environment variable: Controls logging level at runtime. TVM provides a specification for this variable, allowing control over per-file logging levels. + + These two should ideally have different names, but TVM uses the same name for both, which can cause confusion. +::: + + diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..788aec367c45a56bda3bd001c3c8b5e6c6ebfe45 --- /dev/null +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -0,0 +1,590 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.tileop.base import GemmWarpPolicy +import itertools +import argparse +from functools import partial +import numpy as np +import time + + +def ref_program(Q, K, V, is_causal, groups=1): + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + dim = Q.size(-1) + K_ref = K.repeat_interleave(groups, dim=2) + V_ref = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref) + lse = torch.logsumexp(scores, dim=-1).float() + return output, lse + + +def get_fwd_configs(): + block_M = [32, 64, 128, 256] + block_N = [32, 64, 128, 256] + threads = [128, 256, 512] + num_split_q = [64, 128, 256] + num_stages = [0, 1] + enable_rasterization = [True] + k_pack = [2] + panel_size = [7, 8, 9, 10] + qk_coalesced_width = [8] + v_coalesced_width = [4] + + valid_configs = [] + + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) + return valid_configs + + +@tilelang.autotune(configs=get_fwd_configs(), cache_input_tensors=True) +@tilelang.jit(out_idx=[3, 4]) +def fast_flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_split_q: int, + threads: int, + num_stages: int, + enable_rasterization: bool, + k_pack: int, + panel_size: int, + qk_coalesced_width: int, + v_coalesced_width: int, +): + scale = (1.0 / dim) ** 0.5 + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + vec_size = qk_coalesced_width + v_vec_size = v_coalesced_width + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + LSE: T.Tensor([batch, heads, seq_len], accum_dtype), + ): + with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): + T.use_swizzle(panel_size, enable=enable_rasterization) + + bz = byz_combined // heads + by = byz_combined % heads + + num_q_blocks = T.ceildiv(seq_len, block_M) + + bx_loop_var = T.alloc_var(T.int32) + bx_loop_var = b_split + + with T.While(bx_loop_var < num_q_blocks): + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + m_i = T.alloc_fragment([block_M], accum_dtype) + l_i = T.alloc_fragment([block_M], accum_dtype) + + T.fill(acc_o, 0) + T.fill(m_i, -T.infinity(accum_dtype)) + T.fill(l_i, 0) + + current_bx = bx_loop_var + q_block_offset = current_bx * block_M + + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + m_prev = T.alloc_fragment([block_M], accum_dtype) + scale_factor = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) + + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + + row_sum = T.alloc_fragment([block_M], accum_dtype) + + for k in T.Pipelined(loop_end_k, num_stages=num_stages): + kv_idx = k * block_N + + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + k_pack=k_pack, + policy=GemmWarpPolicy.FullRow, + ) + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = acc_s[i, j] * scale + + T.copy(m_i, m_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) + + for i in T.Parallel(block_M): + if m_prev[i] == -T.infinity(accum_dtype): + scale_factor[i] = 0.0 + else: + scale_factor[i] = T.exp(m_prev[i] - m_i[i]) + + l_i[i] *= scale_factor[i] + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scale_factor[i] + + for i, j in T.Parallel(block_M, block_N): + if acc_s[i, j] == -T.infinity(acc_s.dtype): + acc_s[i, j] = 0.0 + else: + acc_s[i, j] = T.exp(acc_s[i, j] - m_i[i]) + + T.reduce_sum(acc_s, row_sum, dim=1) + for i in T.Parallel(block_M): + l_i[i] += row_sum[i] + + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) + + l_inv = T.alloc_fragment([block_M], accum_dtype) + for i in T.Parallel(block_M): + safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0) + l_inv[i] = 1.0 / safe_l + + for i, j in T.Parallel(block_M, dim): + Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i] + + for i in T.Parallel(block_M): + if q_block_offset + i < seq_len: + lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype)) + LSE[bz, by, q_block_offset + i] = lse_val + + bx_loop_var = current_bx + num_split_q + + return main + + +def get_bwd_configs(): + block_M = [16, 32, 64, 128, 256] + block_N = [16, 32, 64, 128, 256] + threads = [64, 128, 256, 512, 1024] + num_stages = [0, 1, 2] + enable_rasterization = [True] + panel_size = [7, 8, 9, 10] + + configs = [] + for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + configs.append( + { + "block_M": m, + "block_N": n, + "num_stages": stages, + "threads": t, + "enable_rasterization": r, + "panel_size": p, + } + ) + + return configs + + +@tilelang.jit(out_idx=[2]) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)): + with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) +@tilelang.jit +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_stages: int, + threads: int, + enable_rasterization: bool, + panel_size: int, +): + sm_scale = (1.0 / dim) ** 0.5 + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + T.use_swizzle(panel_size, enable=enable_rasterization) + + K_shared = T.alloc_shared([block_M, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + q_shared = T.alloc_shared([block_N, dim], dtype) + do_shared = T.alloc_shared([block_N, dim], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta_shared = T.alloc_shared([block_N], accum_dtype) + ds_shared = T.alloc_shared([block_M, block_N], dtype) + + p_cast = T.alloc_fragment([block_M, block_N], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + P_acc = T.alloc_fragment([block_M, block_N], accum_dtype) + dP = T.alloc_fragment([block_M, block_N], accum_dtype) + + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared) + T.clear(qkT) + + T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + + for i, j in T.Parallel(block_M, block_N): + P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0) + + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared) + T.clear(dP) + + T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(P_acc, p_cast) + T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) + + for i, j in T.Parallel(block_M, block_N): + p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale + + T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(p_cast, ds_shared) + T.clear(dq) + T.gemm(ds_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + + for i, j in T.Parallel(block_M, dim): + T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j]) + T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j]) + + return flash_bwd_kernel + + +@tilelang.jit(out_idx=[1]) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.copy( + dQ_in[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +def debug_tensor_comparison(tensor1, tensor2, name, rtol=1e-3, atol=1e-3): + print(f"\n=== {name} Comparison ===") + print(f"Shape: {tensor1.shape} vs {tensor2.shape}") + print(f"Data type: {tensor1.dtype} vs {tensor2.dtype}") + print(f"Device: {tensor1.device} vs {tensor2.device}") + + diff = torch.abs(tensor1 - tensor2) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + std_diff = diff.std().item() + + print(f"Max difference: {max_diff:.6f}") + print(f"Mean difference: {mean_diff:.6f}") + print(f"Difference std: {std_diff:.6f}") + + if max_diff > atol: + max_idx = torch.argmax(diff) + max_idx = np.unravel_index(max_idx.cpu().numpy(), tensor1.shape) + print(f"Max difference position: {max_idx}") + print(f"Value1: {tensor1[max_idx].item():.6f}, Value2: {tensor2[max_idx].item():.6f}") + + nan_count1 = torch.isnan(tensor1).sum().item() + nan_count2 = torch.isnan(tensor2).sum().item() + inf_count1 = torch.isinf(tensor1).sum().item() + inf_count2 = torch.isinf(tensor2).sum().item() + + print(f"NaN count: {nan_count1} vs {nan_count2}") + print(f"Inf count: {inf_count1} vs {inf_count2}") + + relative_diff = diff / (torch.abs(tensor2) + 1e-8) + max_relative_diff = relative_diff.max().item() + mean_relative_diff = relative_diff.mean().item() + + print(f"Max relative difference: {max_relative_diff:.6f}") + print(f"Mean relative difference: {mean_relative_diff:.6f}") + + close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol) + print(f"Within tolerance (rtol={rtol}, atol={atol}): {close}") + + return close, max_diff, mean_diff + + +def benchmark_function(func, *args, warmup=10, repeat=100): + for _ in range(warmup): + func(*args) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + times = [] + for _ in range(repeat): + start = time.time() + func(*args) + if torch.cuda.is_available(): + torch.cuda.synchronize() + end = time.time() + times.append((end - start) * 1000) + + return np.median(times) + + +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): + device = "cuda" + dtype = torch.float16 + + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}") + + flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 5 * flops_per_gemm + + print(f"Total FLOPs: {total_flops / 1e12:.2f} TFlops") + + q = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) + k = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype) + v = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype) + dO = torch.randn_like(q) + + print("Starting autotuning for Fast FlashAttention-V2 Forward Pass...") + fwd_kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups) + if fwd_kernel is None or fwd_kernel.config is None: + print("Forward pass auto-tuning failed.") + return + print(f"Autotuning finished. Best Forward Configuration: {fwd_kernel.config}") + + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + + profiler = fwd_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + print("Verifying correctness...") + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("Forward pass is correct.") + + o_tl, lse_tl = fwd_kernel(q, k, v) + + bwd_prep = flashattn_bwd_preprocess(batch, heads, seq_len, dim) + delta_tl = bwd_prep(o_tl, dO) + + print("\nStarting FlashAttention-V2 backward pass autotuning...") + bwd_kernel = flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups) + if bwd_kernel is None or bwd_kernel.config is None: + print("Backward pass autotuning failed.") + return + print(f"Autotuning completed. Best backward pass configuration: {bwd_kernel.config}") + + dQ_accum = torch.zeros_like(q, dtype=torch.float32) + dK_tl = torch.zeros_like(k, dtype=torch.float32) + dV_tl = torch.zeros_like(v, dtype=torch.float32) + + bwd_kernel(q, k, v, dO, lse_tl, delta_tl, dQ_accum, dK_tl, dV_tl) + + post_kernel = flashattn_bwd_postprocess(batch, heads, seq_len, dim) + dQ_tl = post_kernel(dQ_accum) + + q_ref = q.clone().detach().requires_grad_() + k_ref = k.clone().detach().requires_grad_() + v_ref = v.clone().detach().requires_grad_() + + o_ref, _ = ref_program(q_ref, k_ref, v_ref, is_causal, groups) + o_ref.backward(dO) + + print("Verifying backward pass correctness...") + dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) + if dq_close: + print("dQ is correct.") + else: + print("dQ mismatch detected.") + + dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) + if dk_close: + print("dK is correct.") + else: + print("dK mismatch detected.") + + dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05) + if dv_close: + print("dV is correct.") + else: + print("dV mismatch detected.") + + print("\n=== Performance Benchmarking ===") + + def run_reference_fwd_bwd(): + q_ref_bench = q.clone().detach().requires_grad_() + k_ref_bench = k.clone().detach().requires_grad_() + v_ref_bench = v.clone().detach().requires_grad_() + + o_ref_bench, _ = ref_program(q_ref_bench, k_ref_bench, v_ref_bench, is_causal, groups) + + o_ref_bench.backward(dO) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) + print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops") + + def run_complete_fwd_bwd(): + o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) + + delta_tl_bench = bwd_prep(o_tl_bench, dO) + + dQ_bench = torch.zeros_like(q, dtype=torch.float32) + dK_bench = torch.zeros_like(k, dtype=torch.float32) + dV_bench = torch.zeros_like(v, dtype=torch.float32) + bwd_kernel(q, k, v, dO, lse_tl_bench, delta_tl_bench, dQ_bench, dK_bench, dV_bench) + + post_kernel(dQ_bench) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + tile_latency = benchmark_function(run_complete_fwd_bwd, warmup=10, repeat=100) + print( + f"Complete Flash Attention V2 Forward+Backward (Tile-lang): {tile_latency:.2f} ms | {total_flops / tile_latency * 1e-9:.2f} TFlops" + ) + + speedup = ref_latency / tile_latency + print(f"Speedup: {speedup:.2f}x") + + print("Forward output: Passed") + print(f"dQ: {'Passed' if dq_close else 'Failed'} (Max diff: {dq_max_diff:.6f})") + print(f"dK: {'Passed' if dk_close else 'Failed'} (Max diff: {dk_max_diff:.6f})") + print(f"dV: {'Passed' if dv_close else 'Failed'} (Max diff: {dv_max_diff:.6f})") + + if all([dq_close, dk_close, dv_close]): + print("All checks passed!") + else: + print("Some checks failed, may need further debugging.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=1024, help="sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") + args = parser.parse_args() + + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9c361ff1235a3f7f49b2900b5c5ee868d92a2e --- /dev/null +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -0,0 +1,246 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.tileop.base import GemmWarpPolicy +import itertools +import argparse +from functools import partial + + +# Custom supply function to ensure tensors are created on GPU +def supply_tensors_gpu(params): + """Supply function that creates tensors on GPU for ROCm/HIP.""" + tensors = [] + for param in params: + if hasattr(param, "shape") and hasattr(param, "dtype"): + # Force creation on GPU device + shape = [int(s) for s in param.shape] + tensor = torch.randn(shape, dtype=param.dtype, device="cuda") + tensors.append(tensor) + else: + tensors.append(param) + return tensors + + +def ref_program(Q, K, V, is_causal, groups=1): + assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def get_configs(): + """Generates configurations for the autotuner, tailored for FA-2 style parallelism.""" + block_M = [32, 64, 128, 256] + block_N = [32, 64, 128, 256] + threads = [128, 256, 512] + num_split_q = [64, 128, 256] + num_stages = [0, 1] + enable_rasterization = [True] + k_pack = [2] + panel_size = [7, 8] + qk_coalesced_width = [8] + v_coalesced_width = [4] + + valid_configs = [] + + for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( + block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width + ): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "num_split_q": s, + "threads": t, + "num_stages": stages, + "enable_rasterization": r, + "k_pack": k, + "panel_size": p, + "qk_coalesced_width": qkw, + "v_coalesced_width": vw, + } + ) + return valid_configs + + +@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu) +@tilelang.jit(out_idx=[3]) +def fast_flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups, + block_M: int, + block_N: int, + num_split_q: int, + threads: int, + num_stages: int, + enable_rasterization: bool, + k_pack: int, + panel_size: int, + qk_coalesced_width: int, + v_coalesced_width: int, +): + scale = (1.0 / dim) ** 0.5 + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + vec_size = qk_coalesced_width + v_vec_size = v_coalesced_width + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): + T.use_swizzle(panel_size, enable=enable_rasterization) + + bz = byz_combined // heads + by = byz_combined % heads + + num_q_blocks = T.ceildiv(seq_len, block_M) + + bx = T.alloc_var(T.int32) + bx = b_split + + with T.While(bx < num_q_blocks): + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + m_i = T.alloc_fragment([block_M], accum_dtype) + l_i = T.alloc_fragment([block_M], accum_dtype) + T.fill(acc_o, 0) + T.fill(m_i, -T.infinity(accum_dtype)) + T.fill(l_i, 0) + + current_bx = bx + q_block_offset = current_bx * block_M + + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + # Use register fragment for P instead of shared memory to reduce LDS usage + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + m_prev = T.alloc_fragment([block_M], accum_dtype) + scale_factor = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size) + + loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + + row_sum = T.alloc_fragment([block_M], accum_dtype) + + for k in T.Pipelined(loop_end_k, num_stages=num_stages): + kv_idx = k * block_N + + T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) + T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + k_pack=k_pack, + policy=GemmWarpPolicy.FullRow, + ) + + T.copy(m_i, m_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for i in T.Parallel(block_M): + m_i[i] = T.max(m_i[i], m_prev[i]) + + for i in T.Parallel(block_M): + sf = T.exp(m_prev[i] * scale - m_i[i] * scale) + l_i[i] *= sf + scale_factor[i] = sf + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scale_factor[i] + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale) + + T.reduce_sum(acc_s, row_sum, dim=1) + for i in T.Parallel(block_M): + l_i[i] += row_sum[i] + + # Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) + + l_inv = T.alloc_fragment([block_M], accum_dtype) + for i in T.Parallel(block_M): + safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0) + l_inv[i] = 1.0 / safe_l + + for i, j in T.Parallel(block_M, dim): + Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i] + + bx = current_bx + num_split_q + + return main + + +def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + print("Starting autotuning for FlashAttention-V2...") + kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups) + print(f"Autotuning finished. Best Configuration: {kernel.config}") + + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + print("Verifying correctness...") + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + + latency = profiler.do_bench(ref_program_processed, warmup=100) + print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") + + latency = profiler.do_bench(warmup=100) + print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=8, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--groups", type=int, default=1, help="groups") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) diff --git a/examples/analyze/README.md b/examples/analyze/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ec0a687547701d6e3a125620306d54fb4e85a9c --- /dev/null +++ b/examples/analyze/README.md @@ -0,0 +1,111 @@ +# TVM IR Performance Analyzer + +A performance analysis toolkit for TVM IR modules, Provides hardware-aware performance metrics including FLOPs, memory bandwidth utilization, and execution time estimation. + +## Features + +- โ€‹**Operation Analysis**: Supports arbitrary operations expressed in TVM IR (including GEMM and convolution) +- โ€‹**Memory Traffic Calculation**: Tracks global memory transfers +- โ€‹**Architecture-aware Metrics**: Pre-configured with NVIDIA GPU architectures (Ampere, Ada Lovelace) +- โ€‹**Performance Estimation**: Predicts execution time using roofline model +- โ€‹**TVM Integration**: Works with TVM IRModule and PrimFunc + +## Quick Start +### GEMM Analysis Example +```python +import tilelang.language as T +from tilelang.tools import Analyzer +from tilelang.carver.arch import CUDA + +M = N = K = 1024 + +def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128): + @T.prim_func + def main(A: T.Tensor((M, K), T.float16), + B: T.Tensor((N, K), T.float16), + C: T.Tensor((M, N), T.float)): + # ... (kernel definition) + return main + +cuda_device = CUDA("cuda") +result = Analyzer.analysis(kernel(), cuda_device) +print(result) +``` + +### Convolution Analysis Example +```python +import tilelang.language as T +from tilelang.tools import Analyzer +from tilelang.carver.arch import CUDA + +def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128): + @T.prim_func + def main(data: T.Tensor((N, H, W, C), T.float16), + kernel: T.Tensor((K, K, C, F), T.float16), + out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)): + # ... (convolution kernel definition) + return main + +cuda_device = CUDA("cuda") +result = Analyzer.analysis(kernel(), cuda_device) +print(result) +``` + +## API Documentation +### `AnalysisResult` Class +```python +@dataclass(frozen=True) +class AnalysisResult: + total_flops: int # Total floating-point operations + total_global_bytes: int # Global memory traffic in bytes + estimated_time: float # Predicted execution time (seconds) + tflops: float # Achieved TFLOPS + bandwidth_GBps: float # Memory bandwidth utilization +``` +### `Analyzer` Class Methods +#### `analysis(fn, device)` +* โ€‹Parameters: + * fn: TVM IRModule or PrimFunc + * device: Device configuration object +* Returns: AnalysisResult +#### Supported Architectures +```python +# Extendable to custom hardware via: "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count) +ARCH_CONFIGS = { + "80": (128, 1.41, 2, 108), # A100 + "86": (128, 1.70, 2, 84), # RTX 3080 + "89": (128, 2.52, 2, 128) # RTX 4090 +} +``` + +## Implementation Details + +### Performance Model +Uses roofline model with two constraints: +1. โ€‹**Compute Bound**: `Time = Total FLOPs / (SM Count ร— Cores/SM ร— Clock ร— FLOPs/Cycle)` +2. โ€‹**Memory Bound**: `Time = Memory Bytes / (Bandwidth ร— Utilization)` + +### IR Analysis Pass +1. โ€‹**Traversal**: Walks through TVM IR using `ir_transform` +2. โ€‹**Operation Detection**: + - Counts FLOPs for all compute operations + - Calculates memory traffic for all memory operations +3. โ€‹**Loop Handling**: + - Tracks nested loops for operation scaling + - Accounts for block/grid dimensions + +## Key Metrics Calculation + +| Metric | Formula | +|-------------------------|-----------------------------------------| +| FLOPs per GEMM | `2 ร— M ร— N ร— K` | +| Memory Traffic per Copy | `elements ร— dtype_size ร— loop_product` | +| Achieved TFLOPS | `total_flops / estimated_time / 1e12` | +| Memory Bandwidth | `total_global_bytes / estimated_time` | + +## Limitations +1. Requires memory operations to be properly annotated in the IR +2. Assumes perfect memory coalescing and no bank conflicts + +## Supported Operations +Any operation expressed in TVM IR diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..db21e02f62bc0ad281848a8085d6b661d2d4e93c --- /dev/null +++ b/examples/analyze/example_conv_analyze.py @@ -0,0 +1,89 @@ +import tilelang.language as T +from tilelang.tools import Analyzer +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +from tilelang.layout import make_swizzled_layout +import torch + +N = 64 +C = 256 +H = 512 +W = 512 +F = 512 +K = 3 +S = 1 +D = 1 +P = 1 + + +def check_hopper(): + # if not torch.cuda.is_available(): + # return None + # props = torch.cuda.get_device_properties(0) + # compute_capability = props.major, props.minor + # return compute_capability == (9, 0) + return False + + +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = T.float16 + accum_dtype = T.float32 + is_hopper = check_hopper() + + @T.prim_func + def conv( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout( + { + out_shared: make_swizzled_layout(out_shared), + data_shared: make_swizzled_layout(data_shared), + kernel_shared: make_swizzled_layout(kernel_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return conv + + +def main(): + my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256) + cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + result = Analyzer.analysis(my_func, cuda_device) + print(result) + print(f"Analyzed FLOPs: {result.total_flops}") + + +if __name__ == "__main__": + main() diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..0367af126e04d631d53c9d63eb6f269b807dcf2d --- /dev/null +++ b/examples/analyze/example_gemm_analyze.py @@ -0,0 +1,60 @@ +import tilelang.language as T +from tilelang.tools import Analyzer +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +import torch + +M = N = K = 1024 + + +def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, +): + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def matmul( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return matmul + + +def main(): + my_func = kernel(128, 128, 32, 3, 128, True) + + cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + result = Analyzer.analysis(my_func, cuda_device) + + print(f"Analyzed FLOPs: {result.total_flops}") + print(f"Expected FLOPs: {2 * M * N * K}") + + +if __name__ == "__main__": + main() diff --git a/examples/analyze/test_example_analyze.py b/examples/analyze/test_example_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..448844b900270b61ca452d3acc25104d061d1492 --- /dev/null +++ b/examples/analyze/test_example_analyze.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_gemm_analyze +import example_conv_analyze + + +def test_example_gemm_analyze(): + example_gemm_analyze.main() + + +def test_example_conv_analyze(): + example_conv_analyze.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/attention_sink/README.md b/examples/attention_sink/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ed4b7004e6283c7c2b7a5cdeffbf7da90e1dcca4 --- /dev/null +++ b/examples/attention_sink/README.md @@ -0,0 +1,46 @@ +# Attention Sink + +We compare with an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py). + + +## Algorithm +### Forward +The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage. + +### Backward +Based on detailed mathematical derivation, interestingly, the backward computation process of `dQ`, `dK`, `dv` is almost identical to that in vanilla FlashAttention, except for that the specific meanings of `lse` differ. We only need to compute `dsinks` additionally, which is given by: + +$$ +dsink_h=-\sum_{b}\sum_{q}P_{b, h, q}Delta_{b, h, q} +$$ + +where $P_{b, h, q}$ is the proportion of $sink_h$ in the softmax in the $b$-th block, $h$-th head and $q$-th query(row). + +## Benchmark of forward process + +### Benchmark Environment +- **Hardware**: NVIDIA H800 +- **CUDA version**: 12.9 +- **Triton Version**: 3.4.0 + +### Results + +- dtype=bfloat16 +- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B) +- Full attention is adopted. + +| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup | +|---------|---------|---------------|----------------------|---------| +| 2048 | 64 | 232.98 | **281.89** | 1.21x | +| 2048 | 128 | 321.55 | **417.98** | 1.30x | +| | | | | | +| 4096 | 64 | 280.70 | **349.47** | 1.25x | +| 4096 | 128 | 369.61 | **497.13** | 1.35x | +| | | | | | +| 8192 | 64 | 299.04 | **385.56** | 1.29x | +| 8192 | 128 | 399.39 | **507.93** | 1.27x | +| | | | | | +| 16384 | 64 | 309.46 | **400.62** | 1.29x | +| 16384 | 128 | 418.99 | **549.11** | 1.31x | + +> The backward performance will be further optimized in the future. \ No newline at end of file diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..211ef1d18cda28f20c6104b17f0330322a437d3f --- /dev/null +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -0,0 +1,211 @@ +import torch +import argparse +from tilelang.profiler import do_bench +from tilelang import language as T +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor +from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from typing import Optional + + +@triton.jit +def triton_kernel( + Q, + K, + V, + Sinks, + sm_scale, + Out, + Z, + H, + N_Q_CTX, + N_KV_CTX, + HEAD_DIM: tl.constexpr, + groups: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BANDWIDTH: tl.constexpr, + start_q: tl.constexpr, +): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + # load attention sinks + if Sinks is not None: # noqa: SIM108 + sink = tl.load(Sinks + off_h).to(tl.float32) + else: + sink = 0 + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) + + if BANDWIDTH: + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + else: + lo, hi = 0, start_q + (start_m + 1) * BLOCK_M + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None] + + if BANDWIDTH: + too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) + mask = mask | too_old + + k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T + qk = tl.dot(q, k, allow_tf32=False) + + qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + + p = tl.math.exp(qk) + alpha = tl.math.exp(m_i - m_ij) + l_ij = tl.sum(p, 1) + acc = acc * alpha[:, None] + + v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]) + # v = v.to(tl.float32) + p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core + acc = tl.dot(p, v, acc, allow_tf32=False) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + sink = tl.math.exp(sink - m_i) + z = l_i + sink + acc = acc / z[:, None] + # m_i += tl.math.log(l_i) + # m_ptrs = M + off_hz * N_Q_CTX + offs_m + # tl.store(m_ptrs, m_i) + acc = acc.to(Out.dtype)[None, None, :, :] + Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) + + +def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor: + bs, n_heads, seq_q, head_dim = Q.shape + _, n_heads_kv, seq_kv, _ = K.shape + BLOCK_M = 64 + BLOCK_N = 64 + groups = n_heads // n_heads_kv + + o = torch.empty_like(Q) + grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1) + triton_kernel[grid]( + TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]), + TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]), + TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]), + Sinks, + 1.0 / head_dim**0.5, + TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]), + bs, + n_heads, + N_Q_CTX=seq_q, + N_KV_CTX=seq_kv, + HEAD_DIM=head_dim, + groups=groups, + BANDWIDTH=window_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + start_q=seq_kv - seq_q, + ) + return o + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + + if torch.allclose( + triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ): + print("Checks for triton passed.โœ…") + else: + print("Checks for triton failed.โŒ") + + # Benchmark triton + latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) + print("Triton: {:.2f} ms".format(latency_triton)) + print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9)) + + # Benchmark tilelang + latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency_tilelang)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) + + print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..50747e6b09d902668ec99ee5c267c9dccadf208f --- /dev/null +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -0,0 +1,198 @@ +import torch +import argparse +from tilelang.profiler import do_bench +from tilelang import language as T +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor +from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from typing import Optional + + +@triton.jit +def triton_kernel( + Q, + K, + V, + Sinks, + sm_scale, + Out, + Z, + H, + N_Q_CTX, + N_KV_CTX, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BANDWIDTH: tl.constexpr, + start_q: tl.constexpr, +): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + # load attention sinks + if Sinks is not None: # noqa: SIM108 + sink = tl.load(Sinks + off_h).to(tl.float32) + else: + sink = 0 + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) + + if BANDWIDTH: + lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M + else: + lo, hi = 0, start_q + (start_m + 1) * BLOCK_M + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None] + + if BANDWIDTH: + too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) + mask = mask | too_old + + k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T + qk = tl.dot(q, k, allow_tf32=False) + + qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + + p = tl.math.exp(qk) + alpha = tl.math.exp(m_i - m_ij) + l_ij = tl.sum(p, 1) + acc = acc * alpha[:, None] + + v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]) + # v = v.to(tl.float32) + p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core + acc = tl.dot(p, v, acc, allow_tf32=False) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + sink = tl.math.exp(sink - m_i) + z = l_i + sink + acc = acc / z[:, None] + # m_i += tl.math.log(l_i) + # m_ptrs = M + off_hz * N_Q_CTX + offs_m + # tl.store(m_ptrs, m_i) + acc = acc.to(Out.dtype)[None, None, :, :] + Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc) + + +def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor: + bs, n_heads, seq_q, head_dim = Q.shape + seq_kv = K.shape[2] + BLOCK_M = 64 + BLOCK_N = 64 + + o = torch.empty_like(Q) + grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1) + triton_kernel[grid]( + TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]), + TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]), + TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]), + Sinks, + 1.0 / head_dim**0.5, + TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]), + bs, + n_heads, + N_Q_CTX=seq_q, + N_KV_CTX=seq_kv, + HEAD_DIM=head_dim, + BANDWIDTH=window_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + start_q=seq_kv - seq_q, + ) + return o + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.โœ…") + + latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) + print("Triton: {:.2f} ms".format(latency)) + print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..541baca0430a4378220ce8b48868e64d4014e5dc --- /dev/null +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -0,0 +1,512 @@ +# Adapted from tilelang/examples/flash_attention/example_gqa_bwd.py + +import torch +import tilelang +from tilelang.profiler import do_bench +import tilelang.language as T +import argparse +from typing import Optional + + +def get_bwd_configs(): + sm_major, sm_minor = torch.cuda.get_device_capability() + sm_version = sm_major * 10 + sm_minor + if sm_version == 80: + return 64, 32, 1, 128 + elif sm_version == 90: + return 128, 32, 2, 256 + else: + raise ValueError(f"Unsupported SM version: {sm_version}") + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd( + batch, + heads, + seq_len, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, heads, seq_len, dim] + kv_shape = [batch, head_kv, seq_len, dim] + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + Output: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([heads], dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, heads, seq_len, dim] + kv_shape = [batch, head_kv, seq_len, dim] + accum_dtype = T.float32 + + block_M, block_N, num_stages, threads = get_bwd_configs() + + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(kv_shape, dtype), # type: ignore + V: T.Tensor(kv_shape, dtype), # type: ignore + dO: T.Tensor(q_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(kv_shape, accum_dtype), # type: ignore + dV: T.Tensor(kv_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared) + T.clear(dv) + T.clear(dk) + + loop_st = T.floordiv(by * block_M, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + for i, j in T.Parallel(block_M, block_N): + if window_size is not None: + qkT[i, j] = T.if_then_else( + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) + else: + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) + + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(out_idx=-1) +def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len] + + @T.prim_func + def flash_bwd_dsink( + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): + sink = T.alloc_local([1], dtype) + lse_fragment = T.alloc_fragment([block], accum_dtype) + delta_fragment = T.alloc_fragment([block], accum_dtype) + dsink_fragment = T.alloc_fragment([block], dtype) + + sink[0] = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) + for i in T.Parallel(block): + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) + + return flash_bwd_dsink + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sinks, window_size, groups): + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)] + BATCH, H, N_CTX, D_HEAD = q.shape + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + o, lse = kernel(q, k, v, sinks) + ctx.save_for_backward(q, k, v, sinks, o, lse) + ctx.window_size = window_size + ctx.groups = groups + return o + + @staticmethod + def backward(ctx, do): + q, k, v, sinks, o, lse = ctx.saved_tensors + BATCH, H, N_CTX, D_HEAD = q.shape + groups = ctx.groups + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size, dtype=dtype) + q_shape = [BATCH, H, N_CTX, D_HEAD] + head_kv = H // groups + kv_shape = [BATCH, head_kv, N_CTX, D_HEAD] + dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd + dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device) + dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + + kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1) + return dq, dk, dv, dsinks, None, None + + +attention = _attention.apply + + +# Adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim) + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + + start_q = num_keys - num_queries + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def main( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= N_CTX + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 + total_flops = 5 * flops_per_matmul + + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + V = torch.randn_like(K).requires_grad_() + sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() + dO = torch.randn_like(Q) + + O = attention(Q, K, V, sinks, window_size, groups) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + dsinks, sinks.grad = sinks.grad.clone(), None + + O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + dsinks_ref, sinks.grad = sinks.grad.clone(), None + + # Checks + rtol, atol = { + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), + }[dtype] + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" + + print("All checks passed for tilelang kernels.โœ…") + + # Only benchmark backward here + def torch_bwd(): + O_ref.backward(dO, retain_graph=True) + + def tl_bwd(): + O.backward(dO, retain_graph=True) + + latency = do_bench(torch_bwd, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(tl_bwd, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--groups", type=int, default=8, help="Groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..df157cd0ff396c3f5e358c71334dc77695e72315 --- /dev/null +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,332 @@ +# Modified from tilelang/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl + +import torch +import tilelang +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import itertools +import argparse +from typing import Optional + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_configs(), + warmup=500, + rep=100, +) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups=1, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, head_kv, seq_kv, dim] + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined( + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +# Following functions are adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim) + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + + start_q = num_keys - num_queries + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") + return query, key, value, sinks + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.โœ…") + + # Benchmark tilelang + latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency_tilelang)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..be405e8bc3c986d0b3241d650c1ab1652e1081ec --- /dev/null +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -0,0 +1,505 @@ +# Adapted from tilelang/examples/flash_attention/example_mha_bwd_bhsd.py + +import torch +import tilelang +from tilelang.profiler import do_bench +import tilelang.language as T +import argparse +from typing import Optional + + +def get_bwd_configs(): + sm_major, sm_minor = torch.cuda.get_device_capability() + sm_version = sm_major * 10 + sm_minor + if sm_version == 80: + return 64, 32, 1, 128 + elif sm_version == 90: + return 128, 32, 2, 256 + else: + raise ValueError(f"Unsupported SM version: {sm_version}") + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd( + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention, + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + shape = [batch, heads, seq_len, dim] + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Sinks: T.Tensor([heads], dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([heads], dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) + start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd( + batch, + heads, + seq_len, + dim, + window_size=None, # None for full attention + sm_scale=None, + dtype: T.dtype = T.float16, +): + block_M, block_N, num_stages, threads = get_bwd_configs() + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + shape = [batch, heads, seq_len, dim] + accum_dtype = T.float32 + + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) + T.clear(dv) + T.clear(dk) + + loop_st = T.floordiv(by * block_M, block_N) + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N)) + if window_size is not None + else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + for i, j in T.Parallel(block_M, block_N): + if window_size is not None: + qkT[i, j] = T.if_then_else( + by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0 + ) + else: + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq) + + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) + + return flash_bwd + + +@tilelang.jit(out_idx=-1) +def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch, heads, seq_len] + + @T.prim_func + def flash_bwd_dsink( + Sinks: T.Tensor([heads], dtype), # type: ignore + Delta: T.Tensor(shape, accum_dtype), # type: ignore + lse: T.Tensor(shape, accum_dtype), # type: ignore + dsinks: T.Tensor(shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): + sink = T.alloc_local([1], dtype) + lse_fragment = T.alloc_fragment([block], accum_dtype) + delta_fragment = T.alloc_fragment([block], accum_dtype) + dsink_fragment = T.alloc_fragment([block], accum_dtype) + + sink[0] = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) + for i in T.Parallel(block): + dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i] + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) + + return flash_bwd_dsink + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sinks, window_size): + BATCH, H, N_CTX, D_HEAD = q.shape + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) + o, lse = kernel(q, k, v, sinks) + ctx.save_for_backward(q, k, v, sinks, o, lse) + ctx.window_size = window_size + return o + + @staticmethod + def backward(ctx, do): + q, k, v, sinks, o, lse = ctx.saved_tensors + BATCH, H, N_CTX, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)] + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size, dtype=dtype) + shape = [BATCH, H, N_CTX, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd + dk = torch.empty(shape, dtype=q.dtype, device=q.device) + dv = torch.empty(shape, dtype=q.dtype, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + + kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1) + return dq, dk, dv, dsinks, None + + +attention = _attention.apply + + +# Adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + start_q = num_keys - num_queries + + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1) + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= N_CTX + flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 + total_flops = 5 * flops_per_matmul + + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_() + K = torch.randn_like(Q).requires_grad_() + V = torch.randn_like(Q).requires_grad_() + sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() + dO = torch.randn_like(Q) + + O = attention(Q, K, V, sinks, window_size) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + dsinks, sinks.grad = sinks.grad.clone(), None + + O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + dsinks_ref, sinks.grad = sinks.grad.clone(), None + + # Checks + rtol, atol = { + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), + }[dtype] + assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" + + print("All checks passed for tilelang kernels.โœ…") + + # Only benchmark backward here + def torch_bwd(): + O_ref.backward(dO, retain_graph=True) + + def tl_bwd(): + O.backward(dO, retain_graph=True) + + latency = do_bench(torch_bwd, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(tl_bwd, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--h", type=int, default=64, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") + parser.add_argument("--d_head", type=int, default=128, help="Head dimension") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..f6754bd94acf6fe9ca440cc9058ed7080b5ed267 --- /dev/null +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -0,0 +1,315 @@ +# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd.py + +import torch +import tilelang +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import itertools +import argparse +from typing import Optional + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=500, rep=100) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined(start, end, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + start_q = num_keys - num_queries + + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") + return query, key, value, sinks + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.โœ…") + + latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..ecaf2ce33941587ceeefc49627424f58adeca276 --- /dev/null +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,322 @@ +# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl + +import torch +import tilelang +from tilelang.autotuner import autotune +from tilelang.profiler import do_bench +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import itertools +import argparse +from typing import Optional + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=500, rep=100) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size=None, # None for full attention + sm_scale=None, + block_M=128, + block_N=128, + num_stages=2, + threads=256, + dtype: T.dtype = T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + if window_size is not None: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype)) + else: + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # NOTE(wt): check_inf is necessary for sliding window attention. + for i in T.Parallel(block_M): + if window_size is not None: + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + Sinks: T.Tensor([heads], dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + T.annotate_layout( + { + Q_shared: make_swizzled_layout(Q_shared), + K_shared: make_swizzled_layout(K_shared), + V_shared: make_swizzled_layout(V_shared), + O_shared: make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[by] + + end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + + start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0 + + for k in T.Pipelined( + start, + end, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +# Following functions are adapted and optimized from +# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py +def ref_program( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + sliding_window: Optional[int] = None, + dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function'sinterface + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape + batch_size, num_keys, num_key_value_heads, head_dim = key.shape + start_q = num_keys - num_queries + + sm_scale: float = 1.0 / head_dim**0.5 + + sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float() + key = key.unsqueeze(3) + value = value.unsqueeze(3) + + pos_keys = torch.arange(num_keys, device=query.device) + pos_queries = torch.arange(num_queries, device=query.device) + start_q + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + + if sliding_window: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale + logits = logits + mask[None, None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks, logits_max) + sinks = torch.exp(sinks - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + + output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) + + output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype) + return output.transpose(1, 2).contiguous() + + +def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda") + key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda") + sinks = torch.randn([H], dtype=dtype, device="cuda") + return query, key, value, sinks + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: T.dtype = T.float16, + tune: bool = False, +): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() + if window_size is not None: + print("Using sliding window attention.") + assert window_size <= seq_q + flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation + else: + print("Using full attention.") + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 + total_flops = 2 * flops_per_matmul + + if tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype) + print(f"Best latency: {kernel.latency}") + print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}") + print(f"Best config: {kernel.config}") + else: + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}") + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype, + ) + + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + + torch.testing.assert_close( + kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2 + ) + print("All checks passed.โœ…") + + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query") + parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") + parser.add_argument("--tune", action="store_true", help="tune") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/test_example_attention_sink.py b/examples/attention_sink/test_example_attention_sink.py new file mode 100644 index 0000000000000000000000000000000000000000..57242c199c4b70345fc3fe29d3559da18e4ac990 --- /dev/null +++ b/examples/attention_sink/test_example_attention_sink.py @@ -0,0 +1,65 @@ +import tilelang.testing + +import example_mha_sink_fwd_bhsd +import example_mha_sink_fwd_bhsd_wgmma_pipelined +import example_gqa_sink_fwd_bhsd_wgmma_pipelined +import example_mha_sink_bwd_bhsd +import example_gqa_sink_bwd_bhsd + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_fwd_bhsd_full_attn(): + example_mha_sink_fwd_bhsd.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_fwd_bhsd_sliding_window(): + example_mha_sink_fwd_bhsd.main(window_size=128) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_full_attn(): + example_mha_sink_fwd_bhsd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + example_mha_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_full_attn(): + example_gqa_sink_fwd_bhsd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): + example_gqa_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128) + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_bwd_bhsd(): + example_mha_sink_bwd_bhsd.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_sink_bwd_bhsd_sliding_window(): + example_mha_sink_bwd_bhsd.main(window_size=128) + + +@tilelang.testing.requires_cuda +def test_example_gqa_sink_bwd_bhsd(): + example_gqa_sink_bwd_bhsd.main() + + +@tilelang.testing.requires_cuda +def test_example_gqa_sink_bwd_bhsd_sliding_window(): + example_gqa_sink_bwd_bhsd.main(window_size=128) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/bitnet-1.58b/.gitignore b/examples/bitnet-1.58b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6ea8874968d000cd47f52f55f32a92f0127532b3 --- /dev/null +++ b/examples/bitnet-1.58b/.gitignore @@ -0,0 +1 @@ +models/ \ No newline at end of file diff --git a/examples/bitnet-1.58b/README.md b/examples/bitnet-1.58b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2b587eab4cc6128965be2e4cac4a5d68db13a86c --- /dev/null +++ b/examples/bitnet-1.58b/README.md @@ -0,0 +1,97 @@ +--- +license: mit +--- + + +This is a Tilelang Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. + +## Make Checkpoints for vLLM + +We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which is used to make a checkpoint with fp16 uncompressed metaadta, the main difference with the original checkpoint is the `quant_config.json`, which allow vLLM to load the model and execute with a quant extension. + +```bash +# move to the integration directory +cd /root/to/BitBLAS/integration/BitNet +# make the checkpoint +./maint/generate_bitnet_model_native_format.sh +# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory +``` + +The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata, which can avoid the online dequantize sage for the profiling of vLLM, which lead to more efficient memory utilization. + +```bash +./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas +# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory +``` + +Finnaly, you can use the ckpt in vLLM with: + +```bash +cd vllm_workspace +# inference with the ckpt with fp16 uncompressed metadata +python3 inference_with_native_format.py +# inference with the ckpt with BitBLAS compressed metadata +python3 inference_with_bitblas_format.py +``` + +**Benchmark results of vLLM** + +| Model | Framework | BS16IN32OUT128 | BS1IN512OUT1024 | BS32IN32OUT128 | +|------------------------|--------------------------|----------------|-----------------|----------------| +| bitnet-3b-1.58bits | pytorch | 106.83 | 49.34 | 209.03 | +| bitnet-3b-1.58bits | pytorch-tilelang | 240.33 | 103.09 | 493.31 | +| bitnet-3b-1.58bits | vllm-tilelang | 379.25 | 117.43 | 752.55 | +| bitnet-3b-1.58bits | vllm-tilelang-cuda-graph | 2543.58 | 1621.08 | 2731.79 | + + +## BitBLAS Results + +### Performance + +**Note:** To reproduce the results of BitBLAS, Please checkout the `benchmark_inference_latency.py`. To reproduce the results of the original model, Please checkout the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo. + +| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas | +|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:| +| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 | +| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 | +| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 | + +### On-the-Fly GPU Memory Footprint + +We measured the GPU memory footprint through the `nvidia-smi` command. Please checkout `nvidia_measure_memory.sh` to get the real-time GPU memory usage. And then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage. + +| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** | +|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:| +| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB | +| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB | +| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB | + +## PPL and Zero-shot Accuracy + +The number is Reported from the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B), Please checkout the `eval_ppl.py`. + +PPL and zero-shot accuracy: +| Models | PPL| ARCe| ARCc| HS | BQ | OQ | PQ | WGe | Avg +|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 | +| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 | +| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 | +| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 +| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 | +| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 +| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 +| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 +| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 | + +The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors. + +## Citations + +```bibtex +@article{ma2024era, + title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits}, + author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu}, + journal={arXiv preprint arXiv:2402.17764}, + year={2024} +} +``` \ No newline at end of file diff --git a/examples/bitnet-1.58b/benchmark.sh b/examples/bitnet-1.58b/benchmark.sh new file mode 100755 index 0000000000000000000000000000000000000000..6a2550d45562387677cf169ae66744fcd6a8657e --- /dev/null +++ b/examples/bitnet-1.58b/benchmark.sh @@ -0,0 +1,11 @@ +python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log + +python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log + +python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 | tee b32_i32_o128.log + +python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b16_i32_o128_bitblas.log + +python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 --bitblas | tee b1_i512_o64_bitblas.log + +python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b32_i32_o128_bitblas.log diff --git a/examples/bitnet-1.58b/benchmark_generate.py b/examples/bitnet-1.58b/benchmark_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..d678b91a4e1c970e2209d2dfc0a102af4c3cf81b --- /dev/null +++ b/examples/bitnet-1.58b/benchmark_generate.py @@ -0,0 +1,114 @@ +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time +import argparse + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + + +def generate_text_batch(model, tokenizer, prompts, max_length=100): + # Encode the input prompts as a batch + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + + # Generate cos and sin values (commented out as not used in generation) + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # position_embeddings = model.embed_positions(position_ids) + # cos = position_embeddings[:, :, 0::2].cos() + # sin = position_embeddings[:, :, 1::2].sin() + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin) + end_time = time.time() + + # Decode the output ids to text + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] + + generation_time = end_time - start_time + num_tokens = sum(len(output_id) for output_id in output_ids) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_texts + + +def profile(model, input_data): + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +model_path = "1bitLLM/bitnet_b1_58-3B" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--bs", default=16, type=int) + parser.add_argument("--in_seq_len", default=32, type=int) + parser.add_argument("--out_seq_len", default=128, type=int) + parser.add_argument("--bitblas", action="store_true") + args = parser.parse_args() + bs = args.bs + in_seq_len = args.in_seq_len + out_seq_len = args.out_seq_len + is_bitblas = args.bitblas + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + if is_bitblas: + with torch.no_grad(): + model.quantize() + + tokenizer = BitnetTokenizer.from_pretrained(model_path) + prompt = "" + for _ in range(in_seq_len): + prompt += "Hello " + + prompts = [] + for _ in range(bs): + prompts.append(prompt) + max_length = out_seq_len + in_seq_len + print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/benchmark_inference_latency.py b/examples/bitnet-1.58b/benchmark_inference_latency.py new file mode 100644 index 0000000000000000000000000000000000000000..788fc5565d5d58b59ef11a11b33f357e911ba9bc --- /dev/null +++ b/examples/bitnet-1.58b/benchmark_inference_latency.py @@ -0,0 +1,57 @@ +import argparse +import torch + +from modeling_bitnet import BitnetForCausalLM + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) + + +def profile(model, input_data): + import time + + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +def main(): + model = BitnetForCausalLM.from_pretrained( + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", + low_cpu_mem_usage=True, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).half() + with torch.no_grad(): + model.quantize() + model = torch.compile(model) + + benchmark_sets = [(1024, 1), (1, 2048)] + for batch_size, seq_len in benchmark_sets: + input_id = torch.ones(batch_size, seq_len).long().cuda() + latency = profile(model, input_id) + print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}") + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/benchmark_model_10k_loops.py b/examples/bitnet-1.58b/benchmark_model_10k_loops.py new file mode 100644 index 0000000000000000000000000000000000000000..306c88428277b591bff935be701e5401a8faaf54 --- /dev/null +++ b/examples/bitnet-1.58b/benchmark_model_10k_loops.py @@ -0,0 +1,63 @@ +import argparse +import torch + +from modeling_bitnet import BitnetForCausalLM + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--batch_size", default=1, type=int) +parser.add_argument("--seq_len", default=1, type=int) + +args = parser.parse_args() + +seq_len = args.seq_len +batch_size = args.batch_size + + +def profile(model, input_data): + import time + + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +def main(): + model = BitnetForCausalLM.from_pretrained( + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", + low_cpu_mem_usage=True, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).half() + with torch.no_grad(): + model._post_process_weights() + + torch.cuda.empty_cache() + + input_id = torch.ones(batch_size, seq_len).long().cuda() + for _ in range(10000): + _ = model(input_id) + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/configuration_bitnet.py b/examples/bitnet-1.58b/configuration_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..63c499db36d96d50f567794bf80a60882e08114f --- /dev/null +++ b/examples/bitnet-1.58b/configuration_bitnet.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class BitnetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BitnetModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BitnetModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Bitnet 1 supports up to 2048 tokens, + Bitnet 2 up to 4096, CodeBitnet up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import BitnetModel, BitnetConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = BitnetConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = BitnetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + weight_bits=1, + input_bits=8, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.weight_bits = weight_bits + self.input_bits = input_bits + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/examples/bitnet-1.58b/eval_correctness.py b/examples/bitnet-1.58b/eval_correctness.py new file mode 100644 index 0000000000000000000000000000000000000000..11d47004b81edf517d442cb0eb2b70e6c583cce0 --- /dev/null +++ b/examples/bitnet-1.58b/eval_correctness.py @@ -0,0 +1,99 @@ +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time +import transformers + +print(f"transformers version is {transformers.__version__}") + +# version must be lower than or equal to 4.40 +assert transformers.__version__ <= "4.40.0" + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def profile(model, input_data): + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +model_path = "1bitLLM/bitnet_b1_58-3B" + + +def main(): + model = ( + BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + + tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) + input_id = tokenizer("Hello")["input_ids"] + input_id = torch.tensor(input_id).unsqueeze(0).cuda() + + print("original model generated text:") + print(generate_text(model, tokenizer, "Hello", max_length=100)) + + model.quantize() + print("quantized model generated text:") + print(generate_text(model, tokenizer, "Hello", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/eval_gpu_memory.py b/examples/bitnet-1.58b/eval_gpu_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..00c914cb31c919fc536d0705f59cacf29a30e287 --- /dev/null +++ b/examples/bitnet-1.58b/eval_gpu_memory.py @@ -0,0 +1,52 @@ +import argparse +import torch + +from modeling_bitnet import BitnetForCausalLM + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) + + +def profile(model, input_data): + import time + + import numpy as np + + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +def main(): + model = BitnetForCausalLM.from_pretrained( + "1bitLLM/bitnet_b1_58-3B", + device_map="auto", + low_cpu_mem_usage=True, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).half() + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024**3} GB") + with torch.no_grad(): + model._post_process_weights() + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024**3} GB") + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/eval_ppl.py b/examples/bitnet-1.58b/eval_ppl.py new file mode 100644 index 0000000000000000000000000000000000000000..97db2d0f5236f369a33f70ac1b07fe9a8c01df9d --- /dev/null +++ b/examples/bitnet-1.58b/eval_ppl.py @@ -0,0 +1,72 @@ +# pylint: disable=missing-docstring, invalid-name +"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py.""" + +import math +import argparse +import torch +import random + +from eval_utils import get_test_dataset +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer + +from tqdm import tqdm + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", default=0, type=int) +parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str) +parser.add_argument("--seqlen", default=2048, type=int) + + +def calulate_loss(model, input, loss_fct): + output = model(input, use_cache=False, output_hidden_states=False, output_attentions=False)[0] + shift_logits = output[:, :-1, :].contiguous() + shift_labels = input[:, 1:] + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + return loss + + +def main(args): + datasets = ["c4", "wikitext2"] + model = ( + BitnetForCausalLM.from_pretrained( + args.hf_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + with torch.no_grad(): + model._post_process_weights() + tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False) + loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda() + + ppl = [] + for dataset in datasets: + testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen) + acc_loss, count = 0.0, 0 + progress = tqdm(range(len(testdata))) + for ii in progress: + input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1) + loss = calulate_loss(model, input, loss_fct) + count += input.size(-1) - 1 + acc_loss += loss.item() + progress.set_description(f"avg_loss = {acc_loss / count / math.log(2)}") + + avg_loss = acc_loss / count / math.log(2) + ppl.append(2**avg_loss) + print("{} PPL: {}".format(dataset, ppl[-1])) + + print(ppl) + print("Avg PPL:", sum(ppl) / len(ppl)) + + +if __name__ == "__main__": + torch.set_grad_enabled(False) + args = parser.parse_args() + random.seed(args.seed) + torch.random.manual_seed(args.seed) + main(args) diff --git a/examples/bitnet-1.58b/eval_utils.py b/examples/bitnet-1.58b/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72480c392a7cfa40081546d2da19aa31463aab76 --- /dev/null +++ b/examples/bitnet-1.58b/eval_utils.py @@ -0,0 +1,135 @@ +# ruff: noqa +import torch + +import numpy as np +import torch.nn.functional as F + +from lm_eval.base import BaseLM +from datasets import load_dataset + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_test_dataset(dataset_name, tokenizer, seqlen=2048): + if dataset_name == "wikitext2": + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + testdata = "".join(testdata["text"]).split("\n") + elif dataset_name == "c4": + testdata = load_dataset("allenai/c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation")[ + "text" + ] + else: + raise NotImplementedError + + testdata = [item for item in testdata if item != ""] + tokenized_text = [tokenizer(item, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id] for item in testdata] + + data, doc = [], [tokenizer.bos_token_id] + for sen in tokenized_text: + if len(sen) > seqlen: + continue + if len(doc) + len(sen) > seqlen: + data.append(doc) + doc = [tokenizer.bos_token_id] + doc.extend(sen) + if len(doc) > 1 and len(doc) <= seqlen: + data.append(doc) + return data + + +class LMEvalAdaptor(BaseLM): + def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): + super().__init__() + + assert isinstance(batch_size, int) + + self.model_name = model_name + self.model = model + self.model.eval() + + self.tokenizer = tokenizer + + self.vocab_size = self.tokenizer.vocab_size + + self._batch_size = batch_size + + self._max_length = max_length + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length != -1: + return self._max_length + if hasattr(self.model.config, "n_ctx"): + return self.model.config.n_ctx + elif hasattr(self.model.config, "max_position_embeddings"): + return self.model.config.max_position_embeddings + elif hasattr(self.model.config, "n_positions"): + return self.model.config.n_positions + elif "bloom" in self.model_name: + return 2048 + elif "llama" in self.model_name: + return 2048 # TODO: did not check this + elif "mpt" in self.model_name: + return 2048 + elif "falcon" in self.model_name: + return 2048 + else: + print(self.model.config) + raise NotImplementedError + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self._batch_size + + @property + def device(self): + return "cuda" + + def tok_encode(self, string: str, add_special_tokens=True): + return self.tokenizer.encode(string, add_special_tokens=add_special_tokens) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def loglikelihood(self, requests): + new_reqs = [] + for context, continuation in requests: + context, continuation = context.strip(), continuation.strip() + if context == "": + # end of text as context + context_enc = [self.eot_token_id] + else: + context_enc = self.tok_encode(context, add_special_tokens=True) + + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs) + + def _model_call(self, inps): + """ + inps: a torch tensor of shape [batch, sequence] + the size of sequence may vary from call to call + + returns: a torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model + """ + with torch.no_grad(): + out = self.model(inps)[0] + return out + + def _model_generate(self, context, max_length, eos_token_id): + return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8b7b95cdb24f1bba466a8e776796b7ab025315 --- /dev/null +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -0,0 +1,262 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +import numpy as np + +from tilelang.transform import simplify_prim_func + +torch.manual_seed(42) + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + + +@simplify_prim_func +def bitnet_158_int8xint2_decode( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, + n_partition=4, + reduce_thread=32, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + storage_nbit = 8 + num_bits = 2 + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + C_shape = (M, N) + + num_elems_per_byte = 4 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + micro_size_k_compressed = micro_size_k // num_elems_per_byte + storage_dtype = T.int8 + block_K = reduce_thread * micro_size_k + + use_dp4a = True + dp4a_size = 4 + + @T.prim_func + def kernel( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel( + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), + ) as ( + bx, + by, + ): + A_local = T.alloc_local((micro_size_k,), in_dtype) + B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([micro_size_k], in_dtype) + accum_res = T.alloc_local((1,), accum_dtype) + reduced_accum_res = T.alloc_local((1,), accum_dtype) + + kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x") + ni = T.thread_binding(0, n_partition, thread="threadIdx.y") + + T.import_source(decode_i2s_to_i8s) + + T.clear(accum_res) + for ko in T.serial(T.ceildiv(K, block_K)): + for v in T.vectorized(micro_size_k): + A_local[v] = A[by, ko * block_K + kr * micro_size_k + v] + + for v in T.vectorized(micro_size_k_compressed): + B_quant_local[v] = B[ + bx * n_partition + ni, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, + ] + + T.call_extern( + "handle", + "decode_i2u_to_i8s", + T.address_of(B_quant_local[0]), + T.address_of(B_dequantize_local[0]), + ) + + if use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_dequantize_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki] * B_dequantize_local[ki] + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + ) + ) + if kr == 0: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return kernel + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) + + return int8_weight.view(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == T.int8 else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == T.int8: + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == T.float16: + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == T.float16: + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) + + +def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + program = bitnet_158_int8xint2_decode(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(program) + kernel = tilelang.compile(program) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + qw = general_compress(B.cpu().numpy(), source_bits=2, storage_dtype=np.int8) + qw = interleave_weight(qw, 2, target_dtype=in_dtype) + qw = torch.from_numpy(qw).to(device="cuda") + + kernel(A, qw, C) + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py new file mode 100644 index 0000000000000000000000000000000000000000..8c337398233f32905bd4dd929490287d38660126 --- /dev/null +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -0,0 +1,385 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tilelang.intrinsics.mma_layout import ( + make_mma_swizzle_layout as make_swizzle_layout, +) +import numpy as np + +from tilelang.intrinsics.mma_macro_generator import ( + INT4TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func + +torch.manual_seed(42) + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + + +@simplify_prim_func +def bitnet_158_int8xint2_prefill( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=64, +): + """ + Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved lowโ€‘precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C. + + The returned prim_func expects: + - A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8). + - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte). + - C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32). + + Details: + - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter. + - Tiling parameters: + - block_row_warps, block_col_warps: number of warps per block in row/col. + - warp_row_tiles, warp_col_tiles: tiles per warp. + - chunk: K-sized chunk per block (block_K). + - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32). + - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior. + - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values. + + Parameters: + M, N, K (int): Global matrix dimensions. + in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8. + out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32. + accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32). + fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used). + block_row_warps (int): Warps in block row dimension. + block_col_warps (int): Warps in block column dimension. + warp_row_tiles (int): Tiles per warp in row dimension. + warp_col_tiles (int): Tiles per warp in column dimension. + chunk (int): K-length per block (block_K). + + Returns: + T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution. + """ + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == T.int32: + micro_size_k = 32 + + num_elems_per_byte = 4 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + shared_scope = "shared.dyn" + storage_dtype = T.int8 + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + """ + GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T writing into C. + + This kernel: + - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory. + - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine. + - Uses Warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to compute accumulation across K in a pipelined fashion with configurable stages. + - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing. + + Parameters: + A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations. + B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel. + C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C). + + Side effects: + Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation. + """ + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i8s, + ) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) + + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_frag) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): + B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + thread_bindings * local_size_compressed + v + vi, vj = T.index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj] + + T.call_extern( + "handle", + "decode_i2u_to_i8s", + T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), + ) + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + thread_bindings * local_size + v + vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Perform STMatrix + mma_emitter.stmatrix( + C_frag, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) + + return int8_weight.view(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == T.int8 else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == T.int8: + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == T.float16: + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == T.float16: + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) + + +def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + program = bitnet_158_int8xint2_prefill(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(program) + kernel = tilelang.compile(program) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + qw = general_compress(B.cpu().numpy(), source_bits=2, storage_dtype=np.int8) + qw = interleave_weight(qw, 2, target_dtype=in_dtype) + qw = torch.from_numpy(qw).to(device="cuda") + + kernel(A, qw, C) + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d35df4b268e617f4d12e374b63dd51c7b3b071 --- /dev/null +++ b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -0,0 +1,220 @@ +import torch +import torch.backends +from bitblas import tvm as tvm +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from bitblas.base import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + if in_dtype == T.int8: + A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + # assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/load_from_quantized.py b/examples/bitnet-1.58b/load_from_quantized.py new file mode 100644 index 0000000000000000000000000000000000000000..8c775aa4c8e819ee3ac800fce4ebe0452fac54be --- /dev/null +++ b/examples/bitnet-1.58b/load_from_quantized.py @@ -0,0 +1,71 @@ +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +import os +from transformers import GenerationConfig +import time + +filepath = os.path.abspath(__file__) +dirpath = os.path.dirname(filepath) + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" +saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def main(): + # load quantized model + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) + tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) + # print("original model generated text:") + # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + input_ids = torch.ones((1, 1), dtype=torch.long).cuda() + # naive model inference + output = qmodel(input_ids) + print("original model output:", output) + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/maint/README.md b/examples/bitnet-1.58b/maint/README.md new file mode 100644 index 0000000000000000000000000000000000000000..63cc3e275f18b8bec8e96eabc49c1e812218aee3 --- /dev/null +++ b/examples/bitnet-1.58b/maint/README.md @@ -0,0 +1,91 @@ +--- +license: mit +--- + + +This is a BitBLAS Implementation for the reproduced 1.58bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit Quantized Inference Kernel with BitBLAS INT8xINT2 Kernel. We also evaluated the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`. + +## Latest News + +- 08/09/2024 โœจ: We provide a more efficient implementation for bitnet with vLLM, which should use special model checkpoints, to make the ckpt and study how to deploy, please checkout [Make Checkpoints for vLLM](#make-checkpoints-for-vllm). + +## Make Checkpoints for vLLM + +We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which is used to make a checkpoint with fp16 uncompressed metaadta, the main difference with the original checkpoint is the `quant_config.json`, which allow vLLM to load the model and execute with a quant extension. + +```bash +# move to the integration directory +cd /root/to/BitBLAS/integration/BitNet +# make the checkpoint +./maint/generate_bitnet_model_native_format.sh +# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory +``` + +The second script is `generate_bitnet_model_bitblas_format.sh`, which is used to make a checkpoint with BitBLAS compressed metadata, which can avoid the online dequantize sage for the profiling of vLLM, which lead to more efficient memory utilization. + +```bash +./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas +# the output ckpy will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory +``` + +Finnaly, you can use the ckpt in vLLM with: + +```bash +cd vllm_workspace +# inference with the ckpt with fp16 uncompressed metadata +python3 inference_with_native_format.py +# inference with the ckpt with BitBLAS compressed metadata +python3 inference_with_bitblas_format.py +``` + +## BitBLAS Results + +### Performance + +**Note:** To reproduce the results of BitBLAS, Please checkout the `benchmark_inference_latency.py`. To reproduce the results of the original model, Please checkout the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo. + +| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas | +|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:| +| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 | +| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 | +| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 | + +### On-the-Fly GPU Memory Footprint + +We measured the GPU memory footprint through the `nvidia-smi` command. Please checkout `nvidia_measure_memory.sh` to get the real-time GPU memory usage. And then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage. + +| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** | +|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:| +| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB | +| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB | +| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB | + +## PPL and Zero-shot Accuracy + +The number is Reported from the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B), Please checkout the `eval_ppl.py`. + +PPL and zero-shot accuracy: +| Models | PPL| ARCe| ARCc| HS | BQ | OQ | PQ | WGe | Avg +|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 | +| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 | +| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 | +| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 +| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 | +| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 +| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 +| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 +| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 | + +The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors. + +## Citations + +```bibtex +@article{ma2024era, + title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits}, + author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu}, + journal={arXiv preprint arXiv:2402.17764}, + year={2024} +} +``` \ No newline at end of file diff --git a/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..2604ef38770fa58fa80cf87709e0b205eae26ecd --- /dev/null +++ b/examples/bitnet-1.58b/maint/create_bitblas_ckpt.py @@ -0,0 +1,130 @@ +import argparse +import torch +import bitblas +from transformers.utils.hub import cached_file +import os +from transformers import GenerationConfig +import time +import json + +import sys + +sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../") +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer + +filepath = os.path.abspath(__file__) +dirpath = os.path.dirname(filepath) + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +parser = argparse.ArgumentParser() +parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B") +parser.add_argument("--saved_model_path", type=str, default=None) +args = parser.parse_args() + +model_name_or_path = args.model_name_or_path +saved_model_path = ( + os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path +) + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def main(): + model = ( + BitnetForCausalLM.from_pretrained( + model_name_or_path, + use_flash_attention_2=False, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) + + # print("original model generated text:") + # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + input_ids = torch.ones((1, 1), dtype=torch.long).cuda() + # naive model inference + output = model(input_ids) + print("original model output:", output) + + model.quantize(fuse_qkv=True, fuse_gateup=True) + print("original model generated text:") + print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + + model.save_pretrained(saved_model_path) + + # load quant config + quant_config_path = cached_file(model_name_or_path, "quantize_config.json") + with open(quant_config_path, "r") as f: + quant_config = json.load(f) + print("quant config:") + print(quant_config) + quant_config["checkpoint_format"] = "bitblas" + quant_config["fuse_qkv"] = True + quant_config["fuse_gateup"] = True + + # save quant config + quant_config_path = os.path.join(saved_model_path, "quantize_config.json") + with open(quant_config_path, "w") as f: + json.dump(quant_config, f) + print("quant config saved to:", quant_config_path) + + # copy benchmark filed into saved model path + file_list = [ + "configuration_bitnet.py", + "eval_utils.py", + "modeling_bitnet.py", + "tokenization_bitnet.py", + "utils_quant.py", + "README.md", + ] + for file in file_list: + file_path = cached_file(model_name_or_path, file) + os.system(f"cp {file_path} {saved_model_path}") + # load quantized model + qmodel = ( + BitnetForCausalLM.from_quantized( + saved_model_path, + ) + .cuda() + .half() + ) + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh b/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh new file mode 100755 index 0000000000000000000000000000000000000000..741c3a124a54bcf4206104b2034771de93a30aea --- /dev/null +++ b/examples/bitnet-1.58b/maint/generate_bitnet_model_bitblas_format.sh @@ -0,0 +1,31 @@ +# retrieve the native model input and saved model directory +MODEL_DIR=$1 +SAVED_MODEL_DIR=$2 + +# check if the model directory exists +if [ ! -d "$MODEL_DIR" ]; then + echo "Model directory does not exist!" + exit 1 +fi + +# if the saved model directory does not exist, create it +# if SAVED_MODEL_DIR is not provided, we do not pass it to the script +if [ -z "$SAVED_MODEL_DIR" ]; then + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR +else + if [ ! -d "$SAVED_MODEL_DIR" ]; then + mkdir -p $SAVED_MODEL_DIR + fi + python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR +fi + +# get the realpath of the saved model directory +SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR) + +# cp files +cp $MODEL_DIR/quantize_config.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.json $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer.model $SAVED_MODEL_DIR/ +cp $MODEL_DIR/tokenizer_config.json $SAVED_MODEL_DIR/ + +echo "Model has been converted and save to $SAVED_MODEL_DIR" diff --git a/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh b/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh new file mode 100755 index 0000000000000000000000000000000000000000..a2df0eb8cb2e057b751e572e1aa58c2532aece27 --- /dev/null +++ b/examples/bitnet-1.58b/maint/generate_bitnet_model_native_format.sh @@ -0,0 +1,25 @@ +# require git lfs +if ! command -v git-lfs &> /dev/null; then + echo "Please install git-lfs first by running 'sudo apt install git-lfs'" + exit 1 +fi + +mkdir -p models + +cd models + +# download the model +git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B ckpt_bitnet_b1_58-3B --depth 1 + +# copy quantized config into the model directory +cp ../maint/quantize_config.json ckpt_bitnet_b1_58-3B + +# copy README.md into the model directory +cp ../maint/README.md ckpt_bitnet_b1_58-3B + +# get the realpath of the model directory +MODEL_DIR=$(realpath ckpt_bitnet_b1_58-3B) + +cd .. + +echo "Model has been converted and save to $MODEL_DIR" diff --git a/examples/bitnet-1.58b/maint/quantize_config.json b/examples/bitnet-1.58b/maint/quantize_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e2b24123a125ebaf3c4b056e8e6546801fbac4dc --- /dev/null +++ b/examples/bitnet-1.58b/maint/quantize_config.json @@ -0,0 +1,10 @@ +{ + "bits": 2, + "desc_act": false, + "static_groups": false, + "sym": true, + "lm_head": false, + "model_name_or_path": "1bitLLM/bitnet_b1_58-3B", + "quant_method": "bitnet", + "checkpoint_format": "bitnet" +} \ No newline at end of file diff --git a/examples/bitnet-1.58b/maint/upload_models.sh b/examples/bitnet-1.58b/maint/upload_models.sh new file mode 100755 index 0000000000000000000000000000000000000000..b764b0da67a9b69d66a0e2a430356751de9df1e1 --- /dev/null +++ b/examples/bitnet-1.58b/maint/upload_models.sh @@ -0,0 +1,34 @@ +MODEL_DIR=$1 +REMOTE_DIR=$2 + +if [ ! -d "$MODEL_DIR" ]; then + echo "Model directory does not exist!" + exit 1 +fi + +cd $MODEL_DIR +if [ ! -d ".git" ]; then + rm -rf .git +fi + +git init + +git checkout -b main + +git lfs install + +git lfs track *.bin + +git lfs track *.safetensors + +git add . + +git commit -m "Initial commit" + +git remote add origin $REMOTE_DIR + +huggingface-cli lfs-enable-largefiles . + +git fetch origin + +git push -f --set-upstream origin main diff --git a/examples/bitnet-1.58b/modeling_bitnet.py b/examples/bitnet-1.58b/modeling_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1830995ee6177536089fe517646b290c18bb05f2 --- /dev/null +++ b/examples/bitnet-1.58b/modeling_bitnet.py @@ -0,0 +1,1686 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMA model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from configuration_bitnet import BitnetConfig +from utils_quant import BitLinear, BitLinearBitBLAS +from transformers.utils.hub import cached_file + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 + + +def find_layers(module, layers=None, name=""): + if not layers: + layers = [nn.Linear] + for layer in layers: + if isinstance(module, layer): + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + return res + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BitnetConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class BitnetRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BitnetRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(BitnetRMSNorm) + + +class BitnetRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class BitnetMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = BitLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.up_proj = BitLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.down_proj = BitLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) + + def forward(self, x): + x = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + x = self.ffn_layernorm(x) + x = self.down_proj(x) + return x + + +class BitnetMLPFuseGateUp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = BitLinear( + self.hidden_size, + self.intermediate_size * 2, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.down_proj = BitLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) + + @classmethod + def from_bit_mlp(cls, bit_mlp: BitnetMLP): + module = cls(bit_mlp.config) + # assign the weights + module.gate_up_proj.weight = nn.Parameter(torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.down_proj = bit_mlp.down_proj + module.ffn_layernorm = bit_mlp.ffn_layernorm + return module + + def forward(self, x): + gate_up = self.gate_up_proj(x) + gate, up = torch.chunk(gate_up, chunks=2, dim=-1) + x = self.act_fn(gate) * up + x = self.ffn_layernorm(x) + x = self.down_proj(x) + return x + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class BitnetAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) + + self.q_proj = BitLinear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.k_proj = BitLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.v_proj = BitLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.o_proj = BitLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self._init_rope() + self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BitnetRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise NotImplementedError + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BitnetAttentionQKVFused(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})." + ) + + self.qkv_proj = BitLinear( + self.hidden_size, + self.num_heads * self.head_dim + (self.num_key_value_heads * self.head_dim) * 2, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.o_proj = BitLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self._init_rope() + self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BitnetRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise NotImplementedError + + @classmethod + def from_bit_attention(cls, bit_attention: BitnetAttention): + module = cls(bit_attention.config, bit_attention.layer_idx) + # assign the weights + module.qkv_proj.weight = nn.Parameter( + torch.cat([bit_attention.q_proj.weight, bit_attention.k_proj.weight, bit_attention.v_proj.weight], dim=0) + ) + if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: + module.qkv_proj.bias = nn.Parameter( + torch.cat([bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias], dim=0) + ) + module.o_proj = bit_attention.o_proj + module.inner_attn_ln = bit_attention.inner_attn_ln + if bit_attention.config.rope_scaling is None: + module.rotary_emb = bit_attention.rotary_emb + return module + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = torch.split( + qkv_states, + [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], + dim=-1, + ) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BitnetFlashAttention2(BitnetAttention): + """ + Bitnet flash attention module. This module inherits from `BitnetAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (BitnetRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in BitnetFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +LLAMA_ATTENTION_CLASSES = { + "eager": BitnetAttention, + "flash_attention_2": BitnetFlashAttention2, +} + + +class BitnetDecoderLayer(nn.Module): + def __init__(self, config: BitnetConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = BitnetMLP(config) + self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`", + stacklevel=2, + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BitnetConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class BitnetPreTrainedModel(PreTrainedModel): + config_class = BitnetConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BitnetDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): + if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + for layer in self.model.layers: + device = layer.input_layernorm.weight.device + if hasattr(self.config, "_pre_quantization_dtype"): + dtype = self.config._pre_quantization_dtype + else: + dtype = layer.self_attn.o_proj.weight.dtype + layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) + + def _reset_cache(self): + for layer in self.model.layers: + layer.self_attn.past_key_value = None + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`BitnetTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`BitnetTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class BitnetModel(BitnetPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BitnetDecoderLayer`] + + Args: + config: BitnetConfig + """ + + def __init__(self, config: BitnetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if use_cache and not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. + if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = mask_slice + + return causal_mask + + +class BitnetForCausalLM(BitnetPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = BitnetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.quantized = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import LlamaTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Bitnet-2-7b-hf") + >>> tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Bitnet-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs + ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None + + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length: + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids") + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + else: + cache_position = cache_position[-input_length:] + + if has_static_cache: + past_key_values = None + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) + return reordered_past + + @staticmethod + def recursive_set(model, name, attr): + """ + set layers.25.mlp.up_proj to attr + """ + + names = name.split(".") + obj = model + for n in names[:-1]: + obj = getattr(obj, n) + setattr(obj, names[-1], attr) + + def quantize(self, fuse_qkv=True, fuse_gateup=True): + for name, module in self.model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + self.recursive_set(self.model, name, bitnet_attenion_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + self.recursive_set(self.model, name, bitnet_mlp_fused) + for name, module in self.model.named_modules(): + # if is bitnet layer + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + if name.endswith(".qkv_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=3) + elif name.endswith(".gate_up_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=2) + else: + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + print("Replacing module", name, "with a quantized version") + self.recursive_set(self.model, name, bitblas_linear) + self.quantized = True + + def _post_process_weights(self): + for name, module in self.model.named_modules(): + if hasattr(module, "post_process_weights"): + print("Post processing weights for module", name) + module.post_process_weights() + + def _replace_weight_param_with_qweight(self): + for name, module in self.model.named_modules(): + if hasattr(module, "replace_weight_param_with_qweight"): + print("Replacing weight param with qweight for module", name) + module.replace_weight_param_with_qweight() + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + # == step1: prepare configs and file names == # + config: BitnetConfig = BitnetConfig.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + **cached_file_kwargs, + ) + # load quantize config + quantize_file = cached_file(model_name_or_path, "quantize_config.json") + assert quantize_file is not None, "quantize config file not found" + import json + + # get quantize format + with open(quantize_file, "r") as f: + quant_config = json.load(f) + checkpoint_format = quant_config["checkpoint_format"] + assert checkpoint_format in ["bitblas"], "quantize format not supported" + fuse_qkv = quant_config.get("fuse_qkv", True) + fuse_gateup = quant_config.get("fuse_gateup", True) + + import accelerate + + if checkpoint_format == "bitblas": + model = cls(config) + for name, module in model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + model.recursive_set(model, name, bitnet_attenion_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + model.recursive_set(model, name, bitnet_mlp_fused) + for name, module in model.named_modules(): + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + print("Replacing module", name, "with a quantized version") + model.recursive_set(model, name, bitblas_linear) + accelerate.utils.modeling.load_checkpoint_in_model( + model, + checkpoint=model_name_or_path, + offload_state_dict=True, + offload_buffers=True, + ) + return model + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`BitnetForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class BitnetForSequenceClassification(BitnetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = BitnetModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Bitnet Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class BitnetForQuestionAnswering(BitnetPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Bitnet + def __init__(self, config): + super().__init__(config) + self.transformer = BitnetModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labeled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labeled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/bitnet-1.58b/nvidia_measure_memory.sh b/examples/bitnet-1.58b/nvidia_measure_memory.sh new file mode 100755 index 0000000000000000000000000000000000000000..e8998f3092bc4a7ea9d3539a7625169365133488 --- /dev/null +++ b/examples/bitnet-1.58b/nvidia_measure_memory.sh @@ -0,0 +1 @@ +nvidia-smi --query-gpu=memory.used --format=csv -lms 500 diff --git a/examples/bitnet-1.58b/requirements.txt b/examples/bitnet-1.58b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..67357781e0a2afd5bd550329ec5c756f09f4b6b6 --- /dev/null +++ b/examples/bitnet-1.58b/requirements.txt @@ -0,0 +1,3 @@ +lm_eval==0.3.0 +flash_attn +transformers==4.53.0 diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2adfd6dee10e6d0fba443e14c7b828e73b378554 --- /dev/null +++ b/examples/bitnet-1.58b/tokenization_bitnet.py @@ -0,0 +1,468 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for LLaMA.""" + +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "hf-internal-testing/llama-tokenizer": 2048, +} +SPIECE_UNDERLINE = "โ–" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class BitnetTokenizer(PreTrainedTokenizer): + """ + Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Bitnet should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. + + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + legacy=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behavior of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['โ–He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', 'โ–Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id + + def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + + The reference for this chat template is [this code + snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) + in the original repository. + """ + logger.warning_once( + "\nNo chat template is defined for this tokenizer - using the default template " + f"for the {self.__class__.__name__} class. If the default is not appropriate for " + "your model, please set `tokenizer.chat_template` to an appropriate template. " + "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n" + ) + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template diff --git a/examples/bitnet-1.58b/utils_quant.py b/examples/bitnet-1.58b/utils_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..5a50edb392ead6d55c9e34f19409cfb94848f13a --- /dev/null +++ b/examples/bitnet-1.58b/utils_quant.py @@ -0,0 +1,230 @@ +# pylint: disable=missing-docstring, invalid-name +"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py to work with BitBLAS.""" + +import torch +from torch import nn +from bitblas.cache import global_operator_cache, get_database_path +from bitblas import Matmul, MatmulConfig +from bitblas import auto_detect_nvidia_target +from logging import getLogger + +logger = getLogger(__name__) +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = get_database_path() + + +def weight_quant(weight, num_bits=1): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + +def activation_quant(x, num_bits=8): + dtype = x.dtype + x = x.float() + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) / s + return result.type(dtype) + + +class BitLinearBitBLAS(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + weight_bits=1, + input_bits=8, + **kwargs, + ): + super().__init__() + """ + RMSNorm is placed outside BitLinear + """ + self.in_features = in_features + self.out_features = out_features + self.weight_bits = weight_bits + self.input_bits = input_bits + matmul_config = MatmulConfig( + N=self.out_features, # N dimension + K=self.in_features, # K dimension + A_dtype="int8", # activation A dtype + W_dtype="int2", # weight W dtype + accum_dtype="int32", # accumulation dtype + out_dtype="float32", # output dtype + layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose + with_bias=False, # bias + # configs for weight only quantization + group_size=None, # setting for grouped quantization + with_scaling=False, # setting for scaling factor + with_zeros=False, # setting for zeros + zeros_mode=None, # setting for how to calculating zeros + ) + ENABLE_TUNING = True + self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) + + self.format = "bitnet" + self.Qp = 2 ** (self.input_bits - 1) - 1 + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info(f"Loaded {global_operator_cache.size()} operators from database.") + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + # should disable tuning for the first time because we may require loading bitblas operator from database. + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + print("BitBLAS Tuning done, appended operator to global_operator_cache.") + else: + print("BitBLAS Operator created.") + else: + print("BitBLAS Operator found in global_operator_cache.") + return bitblas_matmul + + def replace_weight_param_with_qweight(self): + if hasattr(self, "weight"): + del self.weight + quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape()) + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @classmethod + def from_bit_linear(cls, bitlinear, weight_group=1): + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) + bitblas_linear.register_buffer("qweight", qweight) + bitblas_linear.register_buffer("sw", sw) + if bitlinear.bias is not None: + bitblas_linear.register_buffer("bias", bitlinear.bias) + else: + bitblas_linear.bias = None + return bitblas_linear + + def create_bitblas_weights(self, weight, weight_group=1): + if weight_group: + hidden_size = weight.size(0) + group_size = hidden_size // weight_group + + sw_list = [] + qweight_list = [] + + for i in range(weight_group): + start_idx = i * group_size + end_idx = (i + 1) * group_size + + sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5) + sw_list.append(sw.repeat(group_size)) + + qweight = self.weight_quant(weight[start_idx:end_idx]).detach() + qweight_list.append(qweight) + + sw = torch.cat(sw_list, dim=0) + qweight = torch.cat(qweight_list, dim=0) + else: + sw = 1 / weight.abs().mean().clamp(min=1e-5) + qweight = self.weight_quant(weight).detach() + qweight = self.bitblas_matmul.transform_weight(qweight) + qweight = nn.Parameter(qweight, requires_grad=False) + return sw, qweight + + def post_process_weights(self): + sw = 1 / self.weight.abs().mean().clamp(min=1e-5) + self.sw = sw + quant_weight = self.weight_quant(self.weight).detach() + quant_weight = self.bitblas_matmul.transform_weight(quant_weight) + # remove self.weight and replace it with quant_weight + if hasattr(self, "weight"): + del self.weight + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @staticmethod + def weight_quant(weight): + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) + return result.type(torch.int8) + + @torch.compile + def activation_quant(self, x, num_bits=8): + x = x.float() + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) + return result.type(torch.int8), s + + @torch.compile + def post_quant_process(self, input, si, sw): + out = input / si + out = out / sw + out = out.half() + return out + + # for the correctness evaluation. + def native_forward(self, input): + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() + + out = nn.functional.linear(quant_input, quant_weight) + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + return out + + def forward_fp32_simulated(self, input): + quant_input, si = self.activation_quant(input, self.input_bits).detach() + quant_weight = self.weight_quant(self.weight).detach() + + fp32_simulated_input = quant_input.float() + fp32_simulated_weight = quant_weight.float() + fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight) + + sw = 1 / self.weight.abs().mean().clamp(min=1e-5) + # if / (si * sw) it will inf in some cases + out = fp32_simulated_out / si + out = out / sw + out = out.half() + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + return out + + def forward(self, input): + # return self.forward_fp32_simulated(input) + quant_input, si = self.activation_quant(input, self.input_bits) + fp32_out = self.bitblas_matmul(quant_input, self.qweight) + sw = self.sw + # if / (si * sw) it will inf in some cases + out = self.post_quant_process(fp32_out, si, sw) + + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + return out + + +# Naive BitLinear from HuggingFace +class BitLinear(nn.Linear): + def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): + super(BitLinear, self).__init__(*kargs, **kwargs) + """ + RMSNorm is placed outside BitLinear + """ + self.weight_bits = weight_bits + self.input_bits = input_bits + + def forward(self, input): + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach() + + out = nn.functional.linear(quant_input, quant_weight) + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) + + return out diff --git a/examples/bitnet-1.58b/vllm_workspace/conftest.py b/examples/bitnet-1.58b/vllm_workspace/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e2997ef67c5c22b26235d00000332dfe20910f --- /dev/null +++ b/examples/bitnet-1.58b/vllm_workspace/conftest.py @@ -0,0 +1,587 @@ +import contextlib +import gc +import os +import sys +from collections import UserList +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoTokenizer, + BatchEncoding, +) + +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset +from vllm.config import TokenizerPoolConfig +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel +from vllm.inputs import TextPrompt +from vllm.logger import init_logger +from vllm.sequence import SampleLogprobs +from vllm.utils import cuda_device_count_stateless, is_cpu + +logger = init_logger(__name__) + +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] +_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] + + +def _read_prompts(filename: str) -> List[str]: + with open(filename, "r") as f: + prompts = f.readlines() + return prompts + + +class _ImageAssetPrompts(TypedDict): + stop_sign: str + cherry_blossom: str + + +if sys.version_info < (3, 9): + # UserList cannot be subscripted + class _ImageAssetsBase(UserList): + pass + +else: + + class _ImageAssetsBase(UserList[ImageAsset]): + pass + + +class _ImageAssets(_ImageAssetsBase): + def __init__(self) -> None: + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) + + def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: + """ + Convenience method to define the prompt for each test image. + + The order of the returned prompts matches the order of the + assets when iterating through this object. + """ + return [prompts["stop_sign"], prompts["cherry_blossom"]] + + +IMAGE_ASSETS = _ImageAssets() +"""Singleton instance of :class:`_ImageAssets`.""" + + +def cleanup(): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + if not is_cpu(): + torch.cuda.empty_cache() + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. + """ + + if not request.node.get_closest_marker("skip_global_cleanup"): + return False + + +@pytest.fixture(autouse=True) +def cleanup_fixture(should_do_global_cleanup_after_test: bool): + yield + if should_do_global_cleanup_after_test: + cleanup() + + +@pytest.fixture +def example_prompts() -> List[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture +def example_long_prompts() -> List[str]: + prompts = [] + for filename in _LONG_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture(scope="session") +def image_assets() -> _ImageAssets: + return IMAGE_ASSETS + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, +} + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) + + +class HfRunner: + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + return input.to("cuda") + else: + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + is_embedding_model: bool = False, + is_vision_model: bool = False, + is_sparseml_model: bool = False, + ) -> None: + assert dtype in _STR_DTYPE_TO_TORCH_DTYPE + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + + if is_embedding_model: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = self.wrap_device( + SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype) + ) + else: + if is_vision_model: + auto_cls = AutoModelForVision2Seq + elif is_sparseml_model: + from sparseml.transformers import SparseAutoModelForCausalLM + + auto_cls = SparseAutoModelForCausalLM + else: + auto_cls = AutoModelForCausalLM + + model_kwargs = model_kwargs if model_kwargs is not None else {} + self.model = self.wrap_device( + auto_cls.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + **model_kwargs, + ) + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for model %s. Using tokenizer instead.", + model_name, + ) + self.processor = self.tokenizer + + def generate( + self, + prompts: List[str], + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images: + assert len(prompts) == len(images) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output_ids = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + **kwargs, + ) + output_str = self.processor.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output_ids = output_ids.cpu().tolist() + outputs.append((output_ids, output_str)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + **kwargs, + ) + + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + ) + for i in range(len(outputs)): + output_ids, output_str = outputs[i] + for j in range(len(output_ids)): + output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id] + outputs[i] = (output_ids, output_str) + return outputs + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[List[torch.Tensor]]: + all_logprobs: List[List[torch.Tensor]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + seq_logprobs: List[torch.Tensor] = [] + for hidden_states in output.hidden_states: + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if self.model.get_output_embeddings().bias is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + all_logprobs.append(seq_logprobs) + return all_logprobs + + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + **kwargs: Any, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for i, prompt in enumerate(prompts): + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + input_ids = inputs.input_ids + + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + + seq_logprobs: List[torch.Tensor] = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", None) is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner + + +class VllmRunner: + def __init__( + self, + model_name: str, + tokenizer_name: Optional[str] = None, + # Use smaller max model length, otherwise bigger model cannot run due + # to kv cache size limit. + max_model_len: int = 1024, + dtype: str = "half", + disable_log_stats: bool = True, + tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = False, + swap_space: int = 4, + enforce_eager: bool = False, + **kwargs, + ) -> None: + self.model = LLM( + model=model_name, + tokenizer=tokenizer_name, + trust_remote_code=True, + dtype=dtype, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for req_output in req_outputs: + prompt_str = req_output.prompt + prompt_ids = req_output.prompt_token_ids + req_sample_output_ids: List[List[int]] = [] + req_sample_output_strs: List[str] = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append(prompt_str + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) + return outputs + + def generate_w_logprobs( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + assert sampling_params.logprobs is not None + + if images is not None: + assert len(prompts) == len(images) + + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = {"image": image} + + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + + def generate_greedy( + self, + prompts: List[str], + max_tokens: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str]]: + greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs = self.generate(prompts, greedy_params, images=images) + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[List[Image.Image]] = None, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) + + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + beam_search_params = SamplingParams( + n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ) + outputs = self.generate(prompts, beam_search_params) + return outputs + + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +def get_tokenizer_pool_config(tokenizer_group_type): + if tokenizer_group_type is None: + return None + if tokenizer_group_type == "ray": + return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={}) + raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog + + +@pytest.fixture(scope="session") +def num_gpus_available(): + """Get number of GPUs without initializing the CUDA context + in current process.""" + + return cuda_device_count_stateless() diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18239cbc8fc00aaf65297a77fd5db0bf27e6ac --- /dev/null +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py @@ -0,0 +1,45 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. +""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) + +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas") +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitblas", + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024) + print("bitnet inference:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py new file mode 100644 index 0000000000000000000000000000000000000000..f631fb306772408b17d71c35a5ae8bc1084e10d9 --- /dev/null +++ b/examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py @@ -0,0 +1,47 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py`. +""" + +from conftest import VllmRunner +import os +import argparse + +# get the path of the current file +current_file_path = os.path.realpath(__file__) +current_dir = os.path.dirname(current_file_path) +ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B") + +parser = argparse.ArgumentParser(description="Inference with BitNet") +parser.add_argument( + "--ckpt_path", + type=str, + default=ckpt_path, + help="Path to the checkpoint", +) + +args = parser.parse_args() + +ckpt_path = args.ckpt_path + +with VllmRunner( + ckpt_path, + dtype="half", + quantization="bitnet_bitblas", + gpu_memory_utilization=0.5, + # set enforce_eager = False to enable cuda graph + # set enforce_eager = True to disable cuda graph + enforce_eager=False, +) as bitnet_model: + bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128) + print("bitnet inference output:") + print(bitbnet_outputs[0][0]) + print(bitbnet_outputs[0][1]) diff --git a/examples/bitnet-1.58b/vllm_workspace/utils.py b/examples/bitnet-1.58b/vllm_workspace/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e96b19e28ca9e21af070bdd187e4b026aca26bc7 --- /dev/null +++ b/examples/bitnet-1.58b/vllm_workspace/utils.py @@ -0,0 +1,45 @@ +from typing import Dict, List, Tuple + +TokensText = Tuple[List[int], str] + + +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str): + """ + Compare the two sequences generated by different models, + which should be equal. + """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): + output_ids_0, output_str_0 = outputs_0 + output_ids_1, output_str_1 = outputs_1 + + assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + + +TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]] + + +def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): + """ + Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + assert len(outputs_0_lst) == len(outputs_1_lst) + + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}" + + # Break out since sequences will now diverge. + break diff --git a/examples/blocksparse_attention/README.md b/examples/blocksparse_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..89f75b81de950a1139c78c73616d5689afac6b49 --- /dev/null +++ b/examples/blocksparse_attention/README.md @@ -0,0 +1,6 @@ +# Block-Sparse Flash-Attention + +Tilelang implementation of block-sparse flash-attention kernels. + +The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889). + diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..1794836342197de8c16bfa2eb515e872c94c663b --- /dev/null +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -0,0 +1,361 @@ +# ruff: noqa: E712 +import math +import torch + +import triton +import triton.language as tl +import torch.nn.functional as F + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + seqlen_k, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + # print + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK: + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + N_CTX, + PAST_LEN, + col_idx == k_block_end - 1, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. + raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + print("downsample_len", downsample_len) + + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + print("x_ds.shape", x_ds.shape) + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + # print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + + # Run Triton kernel + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # print("ref_output", ref_output) + # print("triton_output", triton_output) + + # Verify accuracy + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + +def test_topk_sparse_attention_qlt_kl(): + BATCH, N_HEADS = 2, 4 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) + + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + # softmax scale + sm_scale = 1.0 / (D_HEAD**0.5) + + downsample_factor = BLOCK + print("downsample_factor", downsample_factor) + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + print("downsample_len", downsample_len) + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + # Run Triton kernel. + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + past_len = K_LEN - Q_LEN + + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) + + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + + attn = attn.masked_fill(~final_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # Verify accuracy. + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" + + print("Pass topk sparse attention test with qlen < klen") + + +def main(): + test_topk_sparse_attention() + test_topk_sparse_attention_qlt_kl() + + +if __name__ == "__main__": + main() diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..934b0b25efaac9568dac2b398d274321a803b54a --- /dev/null +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -0,0 +1,221 @@ +import math +import torch + +import tilelang +import tilelang.language as T +import torch.nn.functional as F + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@tilelang.jit( + out_idx=[4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 1 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def blocksparse_flashattn( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return blocksparse_flashattn + + return kernel_func(block_M, block_N, num_stages, threads) + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + # Run tilelang kernel + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + + tilelang_output = kernel(q, k, v, block_mask) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + # Verify accuracy + torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2) + print("Pass topk sparse attention test with qlen == klen") + + +def main(): + test_topk_sparse_attention() + + +if __name__ == "__main__": + main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..77a29ebe284ef7df8265687bca1217166475739d --- /dev/null +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -0,0 +1,551 @@ +# ruff: noqa +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse +import time +import math + +from heuristic import num_splits_heuristic + + +def flashattn(batch, heads, heads_kv, dim, dim_v): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // heads_kv + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def kernel_func( + block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + ): + shape_q = [batch, heads, dim] + shape_k = [num_pages, page_block_size, heads_kv, dim] + shape_v = [num_pages, page_block_size, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_block_table = [batch, max_num_blocks_per_seq] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + assert block_N <= page_block_size and page_block_size % block_N == 0 + block_ratio = page_block_size // block_N + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var("bool") + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + logical_block_idx = block_indices[bid, cur_kv_head, start + k] + if logical_block_idx >= 0: + has_valid_block = True + block_table_idx = T.floordiv(logical_block_idx, block_ratio) + block_tile_idx = T.floormod(logical_block_idx, block_ratio) + physical_block_idx = block_table[bid, block_table_idx] + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else( + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] /= logsum[i] + + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + max_split = T.alloc_local([1], T.int32) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split[0] = glse[bz, by, k] + if lse_local_split[0] != 0: + max_split[0] = k + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split[0]: + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + if k <= max_split[0]: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + return kernel_func + + +class SparseFlashAttn(torch.nn.Module): + def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages): + super(SparseFlashAttn, self).__init__() + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.dim = dim + self.dim_v = dim_v + self.block_N = block_N + self.page_block_size = page_block_size + self.num_pages = num_pages + self.block_H = 64 + + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_N, + block_H=self.block_H, + page_block_size=page_block_size, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + num_pages=num_pages, + max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"), + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) + + props = torch.cuda.get_device_properties(torch.device("cuda:0")) + self.num_sm = props.multi_processor_count + + def forward(self, query, key, value, block_indices, cache_seqlens, block_table): + batch = self.batch + heads = self.heads + heads_kv = self.heads_kv + dim_v = self.dim_v + dim = self.dim + block_size = self.block_N + max_selected_blocks = block_indices.shape[-1] + + # Compute static scheduling parameters + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = self.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = self.kernel( + query, + key, + value, + block_indices, + cache_seqlens, + block_table, + glse, + output_partial, + ) + return output + + +def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size): + """ + Paged version of sparse attention reference implementation. + + Args: + query: [batch, heads, dim] + key_cache: [num_pages, page_block_size, heads_kv, dim] + value_cache: [num_pages, page_block_size, heads_kv, dim] + block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices + cache_seqlens: [batch] - actual sequence lengths + block_table: [batch, max_num_blocks_per_seq] - maps logical to physical blocks + page_block_size: size of each page block + block_size: size of attention blocks (block_N) + """ + batch, heads, dim = query.shape + heads_kv = key_cache.shape[2] + dim_v = value_cache.shape[3] + num_head_groups = heads // heads_kv + scale = dim**0.5 + + # Reconstruct the full key and value tensors from paged cache + max_cache_seqlen = max(cache_seqlens).item() + key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device) + value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device) + + # Reconstruct full tensors from paged cache using block_table + for b in range(batch): + seq_len = cache_seqlens[b].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + + for block_idx in range(num_blocks_needed): + physical_block_idx = block_table[b, block_idx].item() + + # Calculate the range of tokens for this block + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + + # Copy from paged cache to full tensors + key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1) + + # Reshape query for grouped attention + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + # Compute attention scores + scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + # Create sparse mask based on block_indices + sparse_mask = torch.zeros_like(scores) + + # Apply sparse mask based on selected blocks + for b in range(batch): + for h in range(heads_kv): + valid_indices = block_indices[b, h] # Extract indices for this batch and head + for idx in valid_indices: + if idx >= 0: # Valid block index + start_pos = idx * block_size + end_pos = min(start_pos + block_size, max_cache_seqlen) + sparse_mask[b, :, h, start_pos:end_pos] = 1 + + # Apply sparse mask + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + # Apply causal mask based on actual sequence lengths + range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + + # Compute attention weights + attention = F.softmax(scores / scale, dim=-1) + + # Apply attention to values + out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + + # Reshape output back to original format + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + + return out + + +def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table) + output = output.squeeze(1) + return output + + +def main(args): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( + args.batch, + args.heads, + args.heads_kv, + args.max_cache_seqlen, + args.dim, + args.dim_v, + ) + sparse_ratio = args.sparse_ratio + block_N = args.block_N + page_block_size = args.page_block_size + num_blocks = args.num_pages # Use num_pages from args + + # For dense case verification, set sparse_ratio to 0 to select all blocks + max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + + # Generate random inputs + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") + print("cache_seqlens: ", cache_seqlens) + + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + + # Create paged KV cache + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + + # Create block table and block indices for dense case (all blocks selected) + max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) + print("max_num_blocks_per_seq: ", max_num_blocks_per_seq) + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") + + # Fill block table and block indices and cache + + # Create a pool of available physical blocks + total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + available_blocks = list(range(total_blocks_needed)) + import random + + random.seed(42) # For reproducibility + random.shuffle(available_blocks) + + # Fill block table with random physical block indices + block_assignment = {} # Map (seq_idx, block_idx) -> physical_block_idx + block_idx_counter = 0 + + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + + # Assign random physical blocks for each sequence + for block_idx in range(num_blocks_needed): + physical_block_idx = available_blocks[block_idx_counter] + block_table[seq_idx, block_idx] = physical_block_idx + block_assignment[(seq_idx, block_idx)] = physical_block_idx + block_idx_counter += 1 + + print(f"Block table: {block_table}") + + # Fill K_cache and V_cache with data from original K and V tensors using random block assignment + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + + for block_idx in range(num_blocks_needed): + physical_block_idx = block_assignment[(seq_idx, block_idx)] + + # Calculate the range of tokens for this block + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + + # Copy K and V data to the paged cache + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :] + + # Fill block_indices for sparse attention + # For dense case (verification), we select all blocks in reverse order + # For sparse case, we select a subset of blocks based on sparse_ratio + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_tile = int(math.ceil(seq_len / block_N)) + + if sparse_ratio == 0.0: + # Dense case: select all blocks in reverse order + selected_blocks = min(num_tile, max_selected_blocks) + for head_idx in range(heads_kv): + for i in range(selected_blocks): + # Select blocks in reverse order (most recent first) + block_indices[seq_idx, head_idx, i] = num_tile - 1 - i + # Fill remaining slots with -1 (invalid) + for i in range(selected_blocks, max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + else: + # Fill block_indices for all KV heads + num_selected = int(num_tile * (1.0 - sparse_ratio)) + num_selected = max(1, min(num_selected, max_selected_blocks)) + all_blocks = list(range(num_tile)) + for head_idx in range(heads_kv): + selected_blocks = [] + # Always include the most recent blocks + recent_blocks = 1 + selected_blocks.append(num_tile - 1) + + # Randomly select some earlier blocks + if num_selected > recent_blocks: + remaining_blocks = [b for b in all_blocks if b not in selected_blocks] + if remaining_blocks: + import random + + random.seed(42) # For reproducibility + additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks))) + selected_blocks.extend(additional_blocks) + + # Sort selected blocks in reverse order (most recent first) + selected_blocks.sort(reverse=True) + + for i in range(len(selected_blocks)): + block_indices[seq_idx, head_idx, i] = selected_blocks[i] + # Fill remaining slots with -1 (invalid) + for i in range(len(selected_blocks), max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + + # Initialize sparse attention module + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) + output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) + + import flash_attn # noqa: F401 + + output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N) + + output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) + # Check correctness + if sparse_ratio == 0.0: + max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item() + mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item() + assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!" + else: + max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item() + mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item() + + print(f"Max difference: {max_diff:.6f}") + print(f"Mean difference: {mean_diff:.6f}") + + if max_diff < 1e-2: + print("โœ“ Verification PASSED: Results match within tolerance") + else: + print("โœ— Verification FAILED: Results differ significantly") + + # Performance measurement + for _ in range(10): # Warm-up + sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) + + torch.cuda.synchronize() + start_time = time.time() + for _ in range(100): # Run multiple times for averaging + sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table) + torch.cuda.synchronize() + end_time = time.time() + + kernel_time = (end_time - start_time) / 100 * 1000 # Convert to ms + print(f"Kernel execution time: {kernel_time:.2f} ms") + + # FA performance measurement + for _ in range(10): # Warm-up + ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) + + torch.cuda.synchronize() + start_time_fa = time.time() + for _ in range(100): # Run multiple times for averaging + ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table) + torch.cuda.synchronize() + end_time_fa = time.time() + kernel_time_fa = (end_time_fa - start_time_fa) / 100 * 1000 # Convert to ms + print(f"FA kernel execution time: {kernel_time_fa:.2f} ms") + + print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio") + parser.add_argument("--block_N", type=int, default=64, help="block_N") + parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages") + parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages") + args = parser.parse_args() + main(args) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py new file mode 100644 index 0000000000000000000000000000000000000000..257f41543c3fc2f9d4e044d4ef9a4283edf01142 --- /dev/null +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -0,0 +1,435 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from einops import rearrange, einsum +import argparse +import time +import math +from heuristic import num_splits_heuristic + + +def flashattn(batch, heads, heads_kv, dim, dim_v): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // heads_kv + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var("bool") + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + + for k in T.Pipelined(loop_range, num_stages=num_stages): + i_s = block_indices[bid, cur_kv_head, start + k] + if i_s >= 0: + has_valid_block = True + T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] /= logsum[i] + + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + max_split = T.alloc_local([1], T.int32) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split[0] = glse[bz, by, k] + if lse_local_split[0] != 0: + max_split[0] = k + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split[0]: + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + if k <= max_split[0]: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) + flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + return kernel_func + + +class SparseFlashAttn(torch.nn.Module): + def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): + super(SparseFlashAttn, self).__init__() + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.dim = dim + self.dim_v = dim_v + self.block_size = block_size + + self.block_H = 64 + + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=self.block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) + + props = torch.cuda.get_device_properties(torch.device("cuda:0")) + self.num_sm = props.multi_processor_count + + def forward(self, query, key, value, block_indices, cache_seqlens): + batch = self.batch + heads = self.heads + heads_kv = self.heads_kv + dim_v = self.dim_v + dim = self.dim + block_size = self.block_size + max_selected_blocks = block_indices.shape[-1] + + # Compute static scheduling parameters + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = self.num_sm + + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) + return output + + +def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size): + """ + Args: + query: [batch, heads, dim] + key: [batch, max_cache_seqlen, heads_kv, dim] + value: [batch, max_cache_seqlen, heads_kv, dim_v] + block_indices: [batch, heads_kv, max_selected_blocks], indices of selected blocks, -1 for padding + cache_seqlens: [batch], sequence lengths of the kvcache + max_cache_seqlen: maximum sequence length of kvcache + block_size: block size + Returns: + output: [batch, heads, dim_v] + + """ + + batch, heads, dim = query.shape + heads_kv = key.shape[2] + dim_v = value.shape[-1] + max_selected_blocks = block_indices.shape[-1] + block_H = 64 + + actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32) + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + + # get num_split + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size + # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 132 + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + max_selected_blocks=T.dynamic("max_selected_blocks"), + ) + + output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) + return output + + +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values based on block_indices + for b in range(batch): + for h in range(heads_kv): + valid_indices = block_indices[b, h] # Extract indices for this batch and head + for idx in valid_indices: + if idx >= 0: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def debug(name, expect, actual, atol=1e-3, rtol=1e-3): + all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) + print(name + " all_close={}".format(all_close)) + if not all_close: + diff = (expect - actual).abs() + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) + max_indices = torch.nonzero(diff == diff.max().item()) + first_index = tuple(max_indices[0].tolist()) + print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") + + +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + # # Ensure at least one element equals cache_seqlen + # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index + # # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + + print("cache_seqlens: ", cache_seqlens) + + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_indices with -1 (for padding blocks) + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) + # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + if max_valid_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + # Sort indices within each batch-group for consistency + block_indices, _ = block_indices.sort(dim=-1, descending=True) + # print("block_indices: ", block_indices) + actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] + print("actual_num_blocks: ", actual_num_blocks) + # print(block_indices.shape, actual_num_blocks.shape) + + max_num_blocks = torch.max(max_valid_num_blocks).item() + print("max_num_blocks: ", max_num_blocks) + + # parity reference + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) + debug("output", ref, out, atol=1e-3, rtol=1e-3) + + import flash_attn # noqa: F401 + + ## latency reference + for _ in range(10): + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + torch.cuda.synchronize() + print("dense time: ", (time.time() - start) / 100 * 1000) + + for _ in range(10): + # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) + out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) + out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) + torch.cuda.synchronize() + print("sparse time: ", (time.time() - start) / 100 * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..2957f8c970986e4f5f48673a7026677a10dc2b17 --- /dev/null +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -0,0 +1,420 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + +import time +import math +from heuristic import num_splits_heuristic + + +def flashattn(batch, heads, heads_kv, dim, dim_v): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // heads_kv + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + ) + def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_mask = [batch, heads_kv, num_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var("bool") + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[bid, hid, start + k]: + has_valid_block = True + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else( + (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + return kernel_func + + +class SparseFlashAttn(torch.nn.Module): + def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): + super(SparseFlashAttn, self).__init__() + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.dim = dim + self.dim_v = dim_v + self.block_size = block_size + + self.block_H = 64 + + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=self.block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + num_blocks=T.dynamic("num_blocks"), + ) + + props = torch.cuda.get_device_properties(torch.device("cuda:0")) + self.num_sm = props.multi_processor_count + + def forward(self, query, key, value, block_mask, cache_seqlens): + batch = self.batch + heads = self.heads + heads_kv = self.heads_kv + dim_v = self.dim_v + dim = self.dim + block_size = self.block_size + block_H = self.block_H + max_cache_seqlen = key.shape[1] + # get num_split + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + # num_sm = 132 + num_sm = self.num_sm + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + # print("num_split: ", num_split) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + return output + + +def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size): + """ + Args: + query: [batch, heads, dim] + key: [batch, max_cache_seqlen, heads_kv, dim] + value: [batch, max_cache_seqlen, heads_kv, dim_v] + block_mask: [batch, heads_kv, num_blocks], mask for valid blocks + cache_seqlens: [batch], sequence lengths of the kvcache + block_size: block size + Returns: + output: [batch, heads, dim_v] + + """ + + batch, heads, dim = query.shape + heads_kv = key.shape[2] + dim_v = value.shape[-1] + block_H = 64 + + actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32) + actual_num_blocks = actual_num_blocks[ + :, 0 + ] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks + max_selected_blocks = actual_num_blocks.max().item() + # get num_split + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size + # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 132 + num_split = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + block_N=block_size, + block_H=block_H, + num_split=T.dynamic("num_split"), + num_stages=2, + threads=128, + max_cache_seqlen=T.dynamic("max_cache_seqlen"), + num_blocks=T.dynamic("num_blocks"), + ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + # print(kernel.get_kernel_source()) + + output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + + return output + + +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values + for b in range(batch): + for h in range(heads_kv): + for idx in range(num_blocks): + if block_mask[b, h, idx]: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def debug(name, expect, actual, atol=1e-3, rtol=1e-3): + all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) + print(name + " all_close={}".format(all_close)) + if not all_close: + # print(expect[3, 28]) + # print(actual[3, 28]) + diff = (expect - actual).abs() + print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) + max_indices = torch.nonzero(diff == diff.max().item()) + first_index = tuple(max_indices[0].tolist()) + print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") + + +def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # Ensure at least one element equals cache_seqlen + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + + print("cache_seqlens: ", cache_seqlens) + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + print("valid_num_blocks: ", valid_num_blocks) + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_mask with false (for padding blocks) + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch + if valid_num_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + # print("block_mask: ", block_mask) + + # parity reference + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) + model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + out = model(Q, K, V, block_mask, cache_seqlens) + debug("output", ref, out, atol=1e-3, rtol=1e-3) + + import flash_attn # noqa: F401 + + ## latency reference + for _ in range(10): + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + torch.cuda.synchronize() + print("dense time: ", (time.time() - start) / 100 * 1000) + + for _ in range(10): + # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) + out = model(Q, K, V, block_mask, cache_seqlens) + + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) + out = model(Q, K, V, block_mask, cache_seqlens) + torch.cuda.synchronize() + print("sparse time: ", (time.time() - start) / 100 * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py new file mode 100644 index 0000000000000000000000000000000000000000..b61d52fa092f4d8cd115905d71cde59a99ca88dc --- /dev/null +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -0,0 +1,433 @@ +# ruff: noqa +import torch +import triton +import triton.language as tl +import argparse +from einops import rearrange, einsum +import torch.nn.functional as F + +import math +import time +from heuristic import num_splits_heuristic + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], +) +@triton.jit +def _split_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + cache_seqlens_ptr, + o_partial_ptr, + lse_partial_ptr, + mask_ptr, + sm_scale, + num_splits, + gqa_group_size, + max_selected_blocks, + stride_q_b, + stride_q_h, + stride_q_d, + stride_k_b, + stride_k_s, + stride_k_h, + stride_k_d, + stride_v_b, + stride_v_s, + stride_v_h, + stride_v_d, + stride_o_b, + stride_o_h, + stride_o_split, + stride_o_d, + stride_lse_b, + stride_lse_h, + stride_lse_split, + stride_mask_b, + stride_mask_h, + stride_mask_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx_kv = tl.program_id(1) + split_idx = tl.program_id(2) + + head_idx_q = head_idx_kv * gqa_group_size + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx) + num_blocks = max_selected_blocks + blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32) + remaining_blocks = num_blocks % num_splits + if split_idx < remaining_blocks: + loop_range = blocks_per_split + 1 + else: + loop_range = blocks_per_split + + q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d + mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h + + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) + start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) + for i in range(loop_range): + block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s) + if block_idx >= 0: + start_n = block_idx * BLOCK_N + k_ptr = k_cache_ptr + start_n * stride_k_s + v_ptr = v_cache_ptr + start_n * stride_v_s + + k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0) + v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0) + + qk = tl.dot(q, k) + qk = qk * sm_scale + qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + m_i = m_ij + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(o_partial_ptr.dtype.element_ty) + + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) + + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) + tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], +) +@triton.jit +def _merge_kernel( + o_partial_ptr, + lse_partial_ptr, + o_ptr, + lse_partial_stride_b, + lse_partial_stride_h, + lse_partial_stride_split, + o_partial_stride_b, + o_partial_stride_h, + o_partial_stride_split, + o_partial_stride_d, + o_stride_b, + o_stride_h, + o_stride_d, + BLOCK_D: tl.constexpr, + num_splits: tl.constexpr, + num_splits_pow2: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + offs_splits = tl.arange(0, num_splits_pow2) + offs_d = tl.arange(0, BLOCK_D) + + lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) + + lse_max = tl.max(lse) + + o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h + o_partial = tl.load( + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) + sumexp_normalized_splitk = tl.exp(lse - lse_max) + sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) + numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) + acc = numerator_normalized / sumexp_normalized + acc = acc.to(o_ptr.dtype.element_ty) + o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h + tl.store(o_ptr + offs_d * o_stride_d, acc) + + +def block_sparse_flash_decode_gqa_indice_triton( + q, + k_cache, + v_cache, + cache_seqlens, + max_cache_seqlen, + max_selected_blocks, + block_indices, + block_size, + sm_scale=None, +): + batch, heads, dim = q.shape + + if sm_scale is None: + sm_scale = 1 / math.sqrt(dim) + + _, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape + assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch" + group_size = heads // heads_kv + + block_H = 16 + + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 64 + # num_sm = self.num_sm + num_splits = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) + + num_splits_pow2 = triton.next_power_of_2(num_splits) + + o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype) + lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32) + + BLOCK_D = dim + BLOCK_H = group_size if group_size > 16 else 16 + grid = (batch, heads_kv, num_splits) + _split_kernel[grid]( + q, + k_cache, + v_cache, + cache_seqlens, + o_partial, + lse_partial, + block_indices, + sm_scale, + num_splits, + group_size, + max_selected_blocks, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + block_indices.stride(0), + block_indices.stride(1), + block_indices.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=block_size, + BLOCK_D=BLOCK_D, + ) + + output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype) + grid = (batch, heads) + _merge_kernel[grid]( + o_partial, + lse_partial, + output, + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_D=dim_v, + num_splits=num_splits, + num_splits_pow2=num_splits_pow2, + ) + + return output + + +def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + dim_v = value.shape[-1] + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values based on block_indices + for b in range(batch): + for h in range(heads_kv): + valid_indices = block_indices[b, h] # Extract indices for this batch and head + for idx in valid_indices: + if idx >= 0: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, cache_seqlens): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + qk_flops = 2 * batch * heads * max_cache_seqlen * dim + pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v + total_flops = qk_flops + pv_flops + + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + print("max_selected_blocks: ", max_selected_blocks) + dtype = torch.float16 + block_H = 64 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') + # Ensure at least one element equals cache_seqlen + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + + print("cache_seqlens: ", cache_seqlens) + + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_indices with -1 (for padding blocks) + block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + if max_valid_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, : len(valid_indices)] = valid_indices + + # Sort indices within each batch-group for consistency + block_indices, _ = block_indices.sort(dim=-1, descending=True) + # print("block_indices: ", block_indices) + actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] + print("actual_num_blocks: ", actual_num_blocks) + # print(block_indices.shape, actual_num_blocks.shape) + + max_num_blocks = torch.max(max_valid_num_blocks).item() + print("max_num_blocks: ", max_num_blocks) + + ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) + + triton_out = block_sparse_flash_decode_gqa_indice_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + max_selected_blocks, + block_indices, + block_size, + ) + + print("max difference: ", torch.max(torch.abs(ref - triton_out))) + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + print("Passed the ref test!") + + # Measure performance + torch.cuda.synchronize() + start = time.time() + for _ in range(1000): + block_sparse_flash_decode_gqa_indice_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + max_selected_blocks, + block_indices, + block_size, + ) + torch.cuda.synchronize() + end = time.time() + elapsed_time = end - start + avg_time = elapsed_time / 1000 + avg_flops = total_flops / avg_time + print(f"Average time: {avg_time:.6f} seconds") + + # Measure performance of reference implementation + import flash_attn # noqa: F401 + + start = time.time() + for _ in range(1000): + ref_program_fa(Q, K, V, cache_seqlens) + torch.cuda.synchronize() + end = time.time() + elapsed_time_ref = end - start + avg_time_ref = elapsed_time_ref / 1000 + avg_flops_ref = total_flops / avg_time_ref + print(f"Average time of ref: {avg_time_ref:.6f} seconds") + + print(f"Speedup: {avg_time_ref / avg_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..c05b3777952fddc834cc46377a823a4c14e0e999 --- /dev/null +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -0,0 +1,419 @@ +import torch +import triton +import triton.language as tl +import argparse +from einops import rearrange, einsum +import torch.nn.functional as F + +import math +import time +from heuristic import num_splits_heuristic + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_H", "BLOCK_N", "BLOCK_D"], +) +@triton.jit +def _split_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + cache_seqlens_ptr, + o_partial_ptr, + lse_partial_ptr, + mask_ptr, + sm_scale, + num_splits, + gqa_group_size, + stride_q_b, + stride_q_h, + stride_q_d, + stride_k_b, + stride_k_s, + stride_k_h, + stride_k_d, + stride_v_b, + stride_v_s, + stride_v_h, + stride_v_d, + stride_o_b, + stride_o_h, + stride_o_split, + stride_o_d, + stride_lse_b, + stride_lse_h, + stride_lse_split, + stride_mask_b, + stride_mask_h, + stride_mask_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx_kv = tl.program_id(1) + split_idx = tl.program_id(2) + + head_idx_q = head_idx_kv * gqa_group_size + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx) + num_blocks = (cache_seqlens + BLOCK_N - 1) // BLOCK_N + blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32) + remaining_blocks = num_blocks % num_splits + if split_idx < remaining_blocks: + loop_range = blocks_per_split + 1 + else: + loop_range = blocks_per_split + + q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h + k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d + v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d + mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h + + q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size) + start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks) + for block_idx in range(loop_range): + start_n = (start + block_idx) * BLOCK_N + mask_val = tl.load(mask_ptr + (start + block_idx) * stride_mask_s) + if mask_val == 1: + k_ptr = k_cache_ptr + start_n * stride_k_s + v_ptr = v_cache_ptr + start_n * stride_v_s + + k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0) + v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0) + + qk = tl.dot(q, k) + qk = qk * sm_scale + qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + m_i = m_ij + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(o_partial_ptr.dtype.element_ty) + + lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split + tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size) + + o_partial_ptr += ( + batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d + ) + tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size) + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]], + key=["BLOCK_D"], +) +@triton.jit +def _merge_kernel( + o_partial_ptr, + lse_partial_ptr, + o_ptr, + lse_partial_stride_b, + lse_partial_stride_h, + lse_partial_stride_split, + o_partial_stride_b, + o_partial_stride_h, + o_partial_stride_split, + o_partial_stride_d, + o_stride_b, + o_stride_h, + o_stride_d, + BLOCK_D: tl.constexpr, + num_splits: tl.constexpr, + num_splits_pow2: tl.constexpr, +): + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + offs_splits = tl.arange(0, num_splits_pow2) + offs_d = tl.arange(0, BLOCK_D) + + lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h + lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf")) + + lse_max = tl.max(lse) + + o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h + o_partial = tl.load( + o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d, + mask=offs_splits[:, None] < num_splits, + ) + sumexp_normalized_splitk = tl.exp(lse - lse_max) + sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) + numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0) + acc = numerator_normalized / sumexp_normalized + acc = acc.to(o_ptr.dtype.element_ty) + o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h + tl.store(o_ptr + offs_d * o_stride_d, acc) + + +def block_sparse_flash_decode_gqa_mask_triton( + q, + k_cache, + v_cache, + cache_seqlens, + max_cache_seqlen, + block_mask, + block_size, + sm_scale=None, +): + batch, heads, dim = q.shape + + if sm_scale is None: + sm_scale = 1 / math.sqrt(dim) + + _, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape + assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch" + group_size = heads // heads_kv + + block_H = 16 + + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = 64 + # num_sm = self.num_sm + num_splits = num_splits_heuristic( + total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 + ) + + # print("num_splits:", num_splits, "num_blocks:", num_n_blocks) + + num_splits_pow2 = triton.next_power_of_2(num_splits) + + o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype) + lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32) + + BLOCK_D = dim + BLOCK_H = group_size if group_size > 16 else 16 + grid = (batch, heads_kv, num_splits) + _split_kernel[grid]( + q, + k_cache, + v_cache, + cache_seqlens, + o_partial, + lse_partial, + block_mask, + sm_scale, + num_splits, + group_size, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + block_mask.stride(0), + block_mask.stride(1), + block_mask.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=block_size, + BLOCK_D=BLOCK_D, + ) + + output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype) + grid = (batch, heads) + _merge_kernel[grid]( + o_partial, + lse_partial, + output, + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_D=dim_v, + num_splits=num_splits, + num_splits_pow2=num_splits_pow2, + ) + + return output + + +def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size): + batch, heads, dim = query.shape + heads_kv = key.shape[2] + + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + sparse_mask = torch.zeros_like(scores) + # Assign mask values + for b in range(batch): + for h in range(heads_kv): + for idx in range(num_blocks): + if block_mask[b, h, idx]: + sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1 + + scores = scores.masked_fill(sparse_mask == 0, float("-inf")) + + range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1) + pad_mask = range_len >= cache_seqlens_expanded + pad_mask = pad_mask[:, None, None, :] + scores = scores.masked_fill(pad_mask, float("-inf")) + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def ref_program_fa(query, key, value, cache_seqlens): + # latency reference + # from flash_attn_interface import flash_attn_with_kvcache # fa3 + from flash_attn import flash_attn_with_kvcache # fa2 + + query = query.unsqueeze(1) + output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens) + output = output.squeeze(1) + return output + + +def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + block_size = block_size + sparse_ratio = sparse_ratio + qk_flops = 2 * batch * heads * max_cache_seqlen * dim + pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v + total_flops = qk_flops + pv_flops + + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") + # Ensure at least one element equals cache_seqlen + random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index + cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + print("valid_num_blocks: ", valid_num_blocks) + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + print("max_valid_num_blocks: ", max_valid_num_blocks) + # Initialize block_mask with false (for padding blocks) + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda") + + # Assign valid indices while ensuring no duplicates within each batch-group + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch + valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch + if valid_num_block > 0: # Ensure there's at least one valid block + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] + block_mask[b, h, perm] = True + + ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size) + + triton_out = block_sparse_flash_decode_gqa_mask_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + block_mask, + block_size, + ) + + # print("max difference: ", torch.max(torch.abs(ref - triton_out))) + assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation" + print("Passed the ref test!") + + # Measure performance + torch.cuda.synchronize() + start = time.time() + for _ in range(1000): + block_sparse_flash_decode_gqa_mask_triton( + Q, + K, + V, + cache_seqlens, + max_cache_seqlen, + block_mask, + block_size, + ) + torch.cuda.synchronize() + end = time.time() + elapsed_time = end - start + avg_time = elapsed_time / 1000 + avg_flops = total_flops / avg_time + print(f"Average time: {avg_time:.6f} seconds") + print(f"Average flops: {avg_flops:.2f} GFLOPS") + + import flash_attn # noqa: F401 + + start = time.time() + for _ in range(1000): + ref_program_fa(Q, K, V, cache_seqlens) + + torch.cuda.synchronize() + end = time.time() + elapsed_time_ref = end - start + avg_time_ref = elapsed_time_ref / 1000 + avg_flops_ref = total_flops / avg_time_ref + print(f"Average time of ref: {avg_time_ref:.6f} seconds") + print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS") + + print(f"Speedup: {avg_time_ref / avg_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv") + parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--dim_v", type=int, default=128, help="dim_v") + parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio") + parser.add_argument("--block_size", type=int, default=32, help="block_size") + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/heuristic.py b/examples/blocksparse_attention/heuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6fc528196e3f111924b7d16b34d0c9af8c3800 --- /dev/null +++ b/examples/blocksparse_attention/heuristic.py @@ -0,0 +1,54 @@ +import math + + +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits): + """ + Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. + + Parameters: + - total_mblocks (int): Total number of m_blocks. + - num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU. + - num_n_blocks (int): Number of n_blocks. + - num_m_blocks (int): Number of m_blocks. + - size_one_kv_head (int): Size of one KV head in bytes. + - is_causal_or_local (bool): Indicates whether the operation is causal or local. + - max_splits (int): Maximum number of allowed splits. + + Returns: + - int: The optimal number of splits. + """ + # If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply. + if total_mblocks >= 0.8 * num_SMs: + size_l2 = 50 * 1024 * 1024 # L2 cache size assumption (50MB) + # Only split if each KV head is too large for L2 and there are enough m_blocks + if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local: + return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits) + else: + return 1 + + # If num_n_blocks is too small, we don't split + if num_n_blocks <= 4: + return 1 + + # Limit max_splits to a reasonable range + max_splits = min(max_splits, num_SMs, num_n_blocks) + + max_efficiency = 0.0 + efficiency = [] + + # Compute efficiency for different splits + for num_splits in range(1, max_splits + 1): + n_waves = (total_mblocks * num_splits) / num_SMs + eff = n_waves / math.ceil(n_waves) + # Track max efficiency + if eff > max_efficiency: + max_efficiency = eff + + efficiency.append(eff) + + # Find the smallest number of splits that achieves at least 85% of max efficiency + for num_splits in range(1, max_splits + 1): + if efficiency[num_splits - 1] >= 0.85 * max_efficiency: + return num_splits + + return 1 diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..dd33f46c4ef9705350bc2cc8894cb715d4444346 --- /dev/null +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -0,0 +1,39 @@ +import tilelang.testing +import block_sparse_attn_triton +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask +import example_triton_sparse_gqa_decode_varlen_indice +import example_triton_sparse_gqa_decode_varlen_mask + + +def test_block_sparse_attn_triton(): + block_sparse_attn_triton.main() + + +def test_example_tilelang_block_sparse_attn(): + example_tilelang_block_sparse_attn.main() + + +def test_example_tilelang_sparse_gqa_decode_varlen_indice(): + example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048) + + +def test_example_tilelang_sparse_gqa_decode_varlen_mask(): + example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048) + + +def test_example_triton_sparse_gqa_decode_varlen_indice(): + example_triton_sparse_gqa_decode_varlen_indice.main( + batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) + + +def test_example_triton_sparse_gqa_decode_varlen_mask(): + example_triton_sparse_gqa_decode_varlen_mask.main( + batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a34e45de7c31f4594f76127cea577306c7554e --- /dev/null +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -0,0 +1,179 @@ +import argparse +import itertools +import tilelang +import tilelang.language as T +from tilelang.engine.param import KernelParam +from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType +import torch +from typing import List + +DEFAULT_BLOCK_M = 128 +DEFAULT_BLOCK_N = 128 +DEFAULT_BLOCK_K = 32 +DEFAULT_NUM_STAGES = 2 +DEFAULT_THREAD_NUM = 128 +DEFAULT_ENABLE_RASTERIZATION = True + +parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark") +parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") +parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") +parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") +parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") +parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") + +args, _ = parser.parse_known_args() +M, N, K = args.m, args.n, args.k +sparsity = args.sparsity +use_autotune = args.use_autotune +default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto) + +print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}") +print(f"Target Block Sparsity: {sparsity}") +print(f"Using Autotuner: {use_autotune}\n") + + +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization)) + + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], + } + for c in _configs + ] + + +def ref_program(A, B, BlockMask, block_M, block_N, block_K): + ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) + for i in range(M // block_M): + for j in range(N // block_N): + accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) + for k in range(K // block_K): + if BlockMask[i, j, k]: + accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[ + k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N + ].to(torch.float32) + ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16) + return ref_c + + +def supply_program(params: List[KernelParam]): + input_tensors = [] + + for p in params: + # Check if the kernel parameter is BlockMask tensor. + # Here, BlockMask is uniquely identified by having 3 dimensions. + if len(p.shape) != 3: + # For non-BlockMask tensors, use the default tensor generation logic. + input_tensors.append(default_tensor_supply(p)) + else: + # For BlockMask tensor, randomly set elements to True based on desired + # sparsity level. + block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device()) + block_mask[:, :, :] = torch.rand(p.shape) > sparsity + input_tensors.append(block_mask) + + return input_tensors + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def blocksparse_matmul( + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): + block_mask_shape = (M // block_M, N // block_N, K // block_K) + + @T.prim_func + def block_sparse_matmul( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + BlockMask: T.Tensor(block_mask_shape, "bool"), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if BlockMask[by, bx, k]: + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return block_sparse_matmul + + +def main(): + # Initialize input matrices A and B on the GPU with half precision + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + if args.use_autotune: + # Run the autotuner to find the best kernel configuration and performance + # get_best_config is expected to return an object containing the compiled kernel, + # the best configuration found, latency, and reference latency. + kernel = blocksparse_matmul(M, N, K) + + best_config = kernel.config + best_latency = kernel.latency + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"] + + print(f"Best Config: {best_config}") + print(f"Sparsity Ratio: {sparsity}") + print(f"Best Kernel Latency: {best_latency:.6f} ms") + else: + kernel = blocksparse_matmul( + M, + N, + K, + block_M=DEFAULT_BLOCK_M, + block_N=DEFAULT_BLOCK_N, + block_K=DEFAULT_BLOCK_K, + num_stages=DEFAULT_NUM_STAGES, + thread_num=DEFAULT_THREAD_NUM, + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + ) + block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K + print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") + # Create block mask with desired sparsity + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + # Run the compiled kernel (either tuned or default) with the inputs + c = kernel(a, b, block_mask) + + # Compute the reference result using the naive PyTorch implementation + ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K) + + try: + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("โœ… Results are close! Verification successful.") + except AssertionError as e: + print("โŒ Verification FAILED: Results differ significantly.") + print(e) + + +if __name__ == "__main__": + main() diff --git a/examples/blocksparse_gemm/test_example_blocksparse_gemm.py b/examples/blocksparse_gemm/test_example_blocksparse_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b39f5e882b454392d9d7a380b923320e9cbbea --- /dev/null +++ b/examples/blocksparse_gemm/test_example_blocksparse_gemm.py @@ -0,0 +1,10 @@ +import tilelang.testing +import example_blocksparse_gemm + + +def test_example_blocksparse_gemm(): + example_blocksparse_gemm.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..6bde50c512ad038424e99235edf7ca44abb2d853 --- /dev/null +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -0,0 +1,209 @@ +import torch +import tilelang +import tilelang.language as T +from typing import Tuple +from tilelang.utils.tensor import torch_assert_close + +# support bfloat16, float, float16 +dtype = T.bfloat16 +accum_dtype = T.float32 + + +@tilelang.jit(out_idx=[2, 3]) +def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): + group_size = 128 + fp8_min = -448.0 + fp8_max = 448.0 + + @T.prim_func + def group_per_split_token_cast( + X: T.Tensor((M, N), dtype), + batch_sizes: T.Tensor((BG,), T.int32), + X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn), + X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), + ): + with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): + row = bx + row_g_id = by + bg = bz + y_local = T.alloc_fragment((blk_m, group_size), accum_dtype) + y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) + y_s_local = T.alloc_fragment((blk_m,), accum_dtype) + y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + row_offset = T.alloc_fragment((1,), T.int32) + + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) + + row_offset[0] = 0 + for i in T.serial(bg): + row_offset[0] += batch_sizes[i] + + T.copy( + X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], + y_local, + ) + T.reduce_absmax(y_local, y_amax_local, dim=1) + for i in T.Parallel(blk_m): + y_amax_local[i] = T.max(y_amax_local[i], 1e-4) + y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0) + for i, j in T.Parallel(blk_m, group_size): + y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) + T.copy(y_q_local, y_q_local_fp8) + for i, j in T.Parallel(blk_m, group_size): + y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0) + for i in T.Parallel(blk_m): + X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i] + T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) + + return group_per_split_token_cast + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. + + Returns: + The result of the ceiling division. + """ + return (x + y - 1) // y + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return ceil_div(x, alignment) * alignment + + +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along the M axis + (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. + """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dim() in (2, 3) + remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # this function don't support cpu tensor + assert x.dim() == 2 + m, n = x.shape + new_n = ceil_div(n, 128) * 128 + x_padded = torch.nn.functional.pad(x, (0, new_n - n)) + x_view = x_padded.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() + return x_fp8, (x_amax / 448.0).view(m, -1) + + +def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # assert x.shape[0] == batch_sizes.sum() + M_max = ceil_div(batch_sizes.max(), 128) * 128 + split_x = torch.split(x, batch_sizes.tolist(), dim=0) + padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x] + num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1] + x_fp8 = ( + torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float), + ) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i]) + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8 + + +def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == T.float: + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + elif dtype == T.float16: + x = torch.randn(M, N, device="cuda", dtype=torch.float16) + elif dtype == T.bfloat16: + x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) + M_max = int(ceil_div(batch_sizes.max(), 128) * 128) + + print("batch_sizes:", batch_sizes) + print("M_max:", M_max) + + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) + print(kernel.get_kernel_source()) + # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + x_fp8, x_amax = kernel(x, batch_sizes) + x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + print("All checks pass.") + + from tilelang.profiler import do_bench + + def run_tilelang(): + x_fp8_tilelang_, x_amax_tilelang_ = kernel(x, batch_sizes) + return x_fp8_tilelang_, x_amax_tilelang_ + + def run_torch(): + x_fp8_torch_, x_amax_torch_ = ref_program(x, batch_sizes) + return x_fp8_torch_, x_amax_torch_ + + latency = do_bench(run_tilelang) + print("Tile-lang: {:.2f} ms".format(latency)) + + latency = do_bench(run_torch) + print("Torch: {:.2f} ms".format(latency)) + + +if __name__ == "__main__": + main() diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..aa6d14884039d228e2d2082662e952571622db20 --- /dev/null +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -0,0 +1,113 @@ +import torch +import tilelang +import tilelang.language as T +from typing import Tuple +from tilelang.utils.tensor import torch_assert_close + + +@tilelang.jit(out_idx=[1, 2]) +def per_token_cast_to_fp8(M, N, blk_m): + dtype = T.float + group_size = 128 + fp8_min = -448.0 + fp8_max = 448.0 + + @T.prim_func + def per_token_cast( + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + ): + with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): + row = bx + row_g_id = by + y_local = T.alloc_fragment((blk_m, group_size), dtype) + y_amax_local = T.alloc_fragment((blk_m,), dtype) + y_s_local = T.alloc_fragment((blk_m,), dtype) + y_q_local = T.alloc_fragment((blk_m, group_size), dtype) + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + + T.annotate_layout( + { + y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32), + } + ) + + T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local) + T.reduce_absmax(y_local, y_amax_local, dim=1) + for i in T.Parallel(blk_m): + y_amax_local[i] = T.max(y_amax_local[i], 1e-4) + y_s_local[i] = y_amax_local[i] / fp8_max + for i, j in T.Parallel(blk_m, group_size): + y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max) + T.copy(y_q_local, y_q_local_fp8) + for i in T.Parallel(blk_m): + X_amax[row * blk_m + i, row_g_id] = y_s_local[i] + T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size]) + + return per_token_cast + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. + + Returns: + The result of the ceiling division. + """ + return (x + y - 1) // y + + +def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # this function don't support cpu tensor + assert x.dim() == 2 + m, n = x.shape + new_n = ceil_div(n, 128) * 128 + x_padded = torch.nn.functional.pad(x, (0, new_n - n)) + x_view = x_padded.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous() + return x_fp8, (x_amax / 448.0).view(m, -1) + + +def main(M=8192, N=8192, blk_m=8): + kernel = per_token_cast_to_fp8(M, N, blk_m) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + + x_fp8, x_amax = kernel(x) + x_fp8_ref, x_amax_ref = ref_program(x) + + print("x_fp8:", x_fp8, x_fp8.shape) + print("x_amax:", x_amax, x_amax.shape) + print("x_fp8_ref:", x_fp8_ref, x_fp8_ref.shape) + print("x_amax_ref:", x_amax_ref, x_amax_ref.shape) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + print("All checks pass.") + + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + latency = profiler.do_bench() + print("Tile-lang: {:.2f} ms".format(latency)) + + from tilelang.profiler import do_bench + from example_triton_cast_to_fp8 import per_token_group_quant_fp8 + + def run_triton(): + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + return x_fp8_triton_, x_amax_triton_ + + x_fp8_triton, x_amax_triton = run_triton() + latency = do_bench(run_triton) + print("Triton: {:.2f} ms".format(latency)) + + +if __name__ == "__main__": + main() diff --git a/examples/cast/example_triton_cast_to_fp8.py b/examples/cast/example_triton_cast_to_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..1859433f10b6f6bd438846473b5661718c34fe4f --- /dev/null +++ b/examples/cast/example_triton_cast_to_fp8.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}" + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s diff --git a/examples/cast/test_example_cast.py b/examples/cast/test_example_cast.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b10a7979cf6506ec93c21bc8e9d3ddec2cc214 --- /dev/null +++ b/examples/cast/test_example_cast.py @@ -0,0 +1,15 @@ +import tilelang.testing +import example_group_per_split_token_cast_to_fp8 +import example_per_token_cast_to_fp8 + + +def test_example_group_per_split_token_cast_to_fp8(): + example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896]) + + +def test_example_per_token_cast_to_fp8(): + example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/conftest.py b/examples/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..4010e0d83ae84c641151d6dd56dbf40ee42e301f --- /dev/null +++ b/examples/conftest.py @@ -0,0 +1,41 @@ +import os +import random +import pytest + +os.environ["PYTHONHASHSEED"] = "0" + +random.seed(0) + +try: + import torch +except ImportError: + pass +else: + torch.manual_seed(0) + +try: + import numpy as np +except ImportError: + pass +else: + np.random.seed(0) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """Ensure that at least one test is collected. Error out if all tests are skipped.""" + known_types = { + "failed", + "passed", + "skipped", + "deselected", + "xfailed", + "xpassed", + "warnings", + "error", + } + if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0: + terminalreporter.write_sep( + "!", + (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"), + ) + pytest.exit("No tests were collected.", returncode=5) diff --git a/examples/convolution/README.md b/examples/convolution/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8ddca8a6aeb369e20c9e18d463b06f755ccb9221 --- /dev/null +++ b/examples/convolution/README.md @@ -0,0 +1 @@ +# Convolution diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd3972fb081ec57e6c569f1d42e28db0caa55be --- /dev/null +++ b/examples/convolution/example_convolution.py @@ -0,0 +1,111 @@ +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + + +def check_hopper(): + if not torch.cuda.is_available(): + return None + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def ref_program(stride, padding, dilation): + def main(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + return main + + +@tilelang.jit(out_idx=[2]) +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = T.float16 + accum_dtype = T.float32 + is_hopper = check_hopper() + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + + out_c = kernel(a, b) + ref_c = ref_program(S, P, D)(a, b) + torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + + +if __name__ == "__main__": + main() diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..59588ac4fbd40db9e1c3d45ab0ff105af28ce004 --- /dev/null +++ b/examples/convolution/example_convolution_autotune.py @@ -0,0 +1,177 @@ +import torch +import argparse +import itertools +import tilelang +import tilelang.language as T + + +def check_hopper(): + if not torch.cuda.is_available(): + return None + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + + +def ref_program(stride, padding, dilation): + def main(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + + return main + + +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def get_heuristic_config() -> dict: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version in {80}: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} + elif sm_version in {90}: + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} + else: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[2]) +def convolution( + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 +): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = T.float16 + accum_dtype = T.float32 + is_hopper = check_hopper() + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + if is_hopper: + T.annotate_layout( + { + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + } + ) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W) + data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + if is_hopper: + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + else: + T.copy(out_local, out_flat[by * block_M, bx * block_N]) + + return main + + +def main( + n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True, +): + N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p + ref_prog = ref_program(S, P, D) + + if use_autotune: + kernel = convolution(N, C, H, W, F, K, S, D, P) + else: + config = get_heuristic_config() + kernel = convolution(N, C, H, W, F, K, S, D, P, **config) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench() + ref_latency = profiler.do_bench(ref_prog) + profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2) + print(f"TileLang latency: {tilelang_latency}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--n", type=int, default=128, help="n") + parser.add_argument("--c", type=int, default=128, help="c") + parser.add_argument("--h", type=int, default=64, help="h") + parser.add_argument("--w", type=int, default=64, help="w") + parser.add_argument("--f", type=int, default=128, help="f") + parser.add_argument("--k", type=int, default=3, help="k") + parser.add_argument("--s", type=int, default=1, help="s") + parser.add_argument("--d", type=int, default=1, help="d") + parser.add_argument("--p", type=int, default=1, help="p") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space") + args = parser.parse_args() + main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller) diff --git a/examples/convolution/test_example_convolution.py b/examples/convolution/test_example_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4c06fb0044e0a0d531bf9abeffe484d3d48acfa1 --- /dev/null +++ b/examples/convolution/test_example_convolution.py @@ -0,0 +1,21 @@ +import tilelang.testing + +import example_convolution +import example_convolution_autotune + + +# TODO(@cy): TMA with convolution must be fixed in future. +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_convolution(): + example_convolution.main([]) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_convolution_autotune(): + example_convolution_autotune.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py new file mode 100644 index 0000000000000000000000000000000000000000..18467a811898d20813b8ed1ac6b9838fd5efe59d --- /dev/null +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -0,0 +1,186 @@ +from typing import Tuple + +import torch +import tilelang.testing +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(42) + + +@tilelang.jit +def tl_gemm( + M, + N, + K, + block_N, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float8_e4m3fn, + ], "Currently only float8_e4m3 is supported" + assert out_dtype in [ + T.bfloat16, + T.float32, + ], "Currently only float16 and float32 are supported" + + group_size = 128 + block_M = 128 + block_K = 128 + + A_shape = (M, K) + Scales_A_shape = (M, T.ceildiv(K, group_size)) + B_shape = (N, K) + Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size)) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = (block_M, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + scales_a: T.Tensor(Scales_A_shape, T.float32), + scales_b: T.Tensor(Scales_B_shape, T.float32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + Scale_C_shared = T.alloc_shared((block_M), T.float32) + C_local = T.alloc_fragment(C_shared_shape, accum_dtype) + C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def ceildiv(a, b): + return (a + b - 1) // b + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): + # A_scale: (M, K//128) ==> (M//128, K//128, 128) + # B_scale: (N//128, K//128) ==> (N//128, K//128, 128) + # A_fp8: (M, K) + # B_fp8: (N, K) + # out_dtype: float16 or float32 + # return C: (M, N) + M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1] + A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1) + B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128) + C = torch.zeros(M, N, device="cuda", dtype=out_dtype) + c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32) + for i in range(ceildiv(M, 128)): + for j in range(ceildiv(N, 128)): + c_acc.zero_() + for k in range(ceildiv(K, 128)): + c = torch._scaled_mm( + A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], + B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, + scale_a=A_scales[i, k].view(128, 1).contiguous(), + scale_b=B_scales[j, k].view(1, 128).contiguous(), + out_dtype=torch.bfloat16, + ) + c_acc += c.to(torch.float32) + C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype) + return C + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype): + kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + A = torch.randn(M, K).to(torch.bfloat16).cuda() + B = torch.randn(N, K).to(torch.bfloat16).cuda() + A_fp8, A_scale = per_token_cast_to_fp8(A.clone()) + B_fp8, B_scale = per_block_cast_to_fp8(B.clone()) + + C = torch.zeros(M, N, device="cuda", dtype=out_dtype) + + kernel(A_fp8, B_fp8, C, A_scale, B_scale) + # Get Reference Result + ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype) + diff = calc_diff(C, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + profiler = kernel.get_profiler() + latency = profiler.do_bench(warmup=25) + # Ensure that the latency is not None + assert latency is not None + print(f"latency: {latency} ms") + tflops = 2 * M * N * K / latency / 1e9 + print(f"tflops: {tflops}") + + +def main(): + assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32) + + +if __name__ == "__main__": + for dtype in [T.float8_e4m3fn]: + for out_dtype in [T.bfloat16, T.float32]: + for block_N in [16, 32, 64, 128]: + assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32) diff --git a/examples/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py new file mode 100644 index 0000000000000000000000000000000000000000..c3dac38af95a572a07e55bcc74e2fe8e2d750d6a --- /dev/null +++ b/examples/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py @@ -0,0 +1,13 @@ +import tilelang.testing + +from example_deepgemm_fp8_2xAcc import main + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_deepgemm_fp8_2xAcc(): + main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/deepseek_mla/README.md b/examples/deepseek_mla/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e64b1c37d002559ea8313706340b48532a0c0b61 --- /dev/null +++ b/examples/deepseek_mla/README.md @@ -0,0 +1,140 @@ +# ๐Ÿš€ How to write high-performance kernel with TileLang: take MLA as an example + +TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang. + +## Introduction to MLA + +DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance. + +## Benchmark Results + +We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below. + +
+ + bs64_float16 + +
Figure 1๏ผšPerformance under batch size=64
+
+ +
+ + bs128_float16 + +
Figure 2๏ผšPerformance under batch size=128
+
+ +As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton. +Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this. + +## Implementation + +First, let's review the core computation logic of traditional FlashAttention: + +```python +# acc_s: [block_M, block_N] +# scores_max: [block_M] +# scores_scale: [block_M] +# acc_o: [block_M, dim] + +for i in range(loop_range): + acc_s = Q @ K[i] + scores_max_prev = scores_max + scores_max = max(acc_s, dim=1) + scores_scale = exp(scores_max_prev - scores_max) + acc_o *= scores_scale + acc_s = exp(acc_s - scores_max) + acc_o = acc_s @ V[i] + ... +``` + +Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency. + +Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance. + +This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling. + +Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input. + +Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory. + +### Layout Inference + +While the above process may seem complex, but don't worry - TileLang will handle all these intricacies for you. + +Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (e.g., split the result matrix in vertical dimension). `T.copy` represents buffer-to-buffer copying operations. + +
+ + QK Layout + +
Figure 3๏ผšBuffer shapes in Q @ K
+
+ +
+ + PV Layout + +
Figure 4๏ผšBuffer shapes in acc_s @ V
+
+ +The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA. + +For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`. + +It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance. + +### Threadblock Swizzling + +Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions. + +In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code: + +```python +T.use_swizzle(panel_size: int, order: str = "row") +``` + +Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col". + + +### Shared Memory Swizzling + +In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance. + +One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency. + +Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code: + +```python +T.annotate_layout({ + S_shared: TileLang.layout.make_swizzled_layout(S_shared), +}) +``` + +Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout. + + +### Warp-Specialization + +The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects. + +In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation. + + +### Pipeline + + +Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation: + +```python +T.pipelined(range: int, stage: int) +``` + +Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases. + + +### Split-KV + +We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results. + +In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter. \ No newline at end of file diff --git a/examples/deepseek_mla/amd/README.md b/examples/deepseek_mla/amd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cc0fb576dce7e77e60d3cf5eedda2732dc3f546a --- /dev/null +++ b/examples/deepseek_mla/amd/README.md @@ -0,0 +1,52 @@ +# ๐Ÿš€ High-Performance FlashMLA Implementation Using TileLang on AMD MI300X Accelerators + +Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms. + +## Architectural Considerations and Optimization Strategies + +Key implementation differences between Hopper and MI300X architectures include: + +1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Access (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, resulting in identical source code manifestations. + +2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes: + - Reducing software pipeline stages + - Register-based caching of Q matrices instead of shared memory utilization: + ```python + # Original shared memory allocation + Q_shared = T.alloc_shared([block_H, dim], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + + # Optimized register allocation + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + ``` + +3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64. + +4. **Memory Bank Conflict Swizzling**: MI300x has different memory bank conflict rules compared to NVIDIA, so we need to use different swizzling strategies. This is also automatically handled by TileLang, resulting in no visible differences in the code. + +## Performance Evaluation + +We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate: + +
+ + AMD FlashMLA Performance Comparison + +
Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)
+
+ +Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) (from 0.73x to 1.21x) in most test cases, while significantly outperforming Triton (up to 6.5x faster)implementations. This performance is achieved through a concise 70-line Python implementation! + +## Future Optimization Opportunities + +1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction. + +2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to: + - Reduce shared memory pressure + - Improve compute-to-memory access ratios + - Enhance parallelism through dimension-wise task distribution + +## Acknowledgment + +We would like to express our sincere gratitude to the AMD ROCm and Composable Kernel team for their outstanding contributions. We have learned a great deal from the ROCm software stack. diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..a9035793b9f305ee330185db329e615710488635 --- /dev/null +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +def get_configs(): + import itertools + + BLOCK_N = [16, 32, 64, 128] + BLOCK_H = [16, 32, 64, 128] + num_split = [1, 2, 4, 8, 16, 32] + threads = [128, 256] + + _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads)) + + return [ + { + "block_N": c[0], + "block_H": c[1], + "num_split": c[2], + "threads": c[3], + } + for c in _configs + ] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by): + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=0): + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # T.copy(acc_s, S_shared) + T.copy(acc_s, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz): + Q_local = T.alloc_fragment([block_H, dim], dtype) + Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=0): + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + parser.add_argument("--autotune", action="store_true", help="auto tune") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + enable_autotune = args.autotune + + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 32 + BLOCK_H = 64 + num_split = 4 + threads = 128 + + if enable_autotune: + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim) + else: + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + input_tensors = profiler._get_inputs() + tilelang_output = kernel(*input_tensors) + ref_output = ref_program(*input_tensors) + print(f"Tilelang output: {tilelang_output}") + print(f"Ref output: {ref_output}") + torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..18c0a5f86d7625af022832d36f58b123c0feb0f8 --- /dev/null +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py @@ -0,0 +1,512 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + +import tilelang +from tilelang.profiler import do_bench + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + num_stages=1, # 2 will oom in amd + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "flash_mla_triton": run_flash_mla_triton, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flash_mla_triton"]: + # flash_mla_triton doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}" + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + "flash_mla_triton", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="torch") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..861e841c4ec8b68851cd4bfdbfdce0fede87960f --- /dev/null +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py @@ -0,0 +1,509 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + num_stages=1, # 2 will oom in amd + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "flash_mla_triton": run_flash_mla_triton, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flash_mla_triton"]: + # flash_mla_triton doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}" + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + "flash_mla_triton", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [64, 128] + for seqlen in [1024, 2048, 4096, 8192, 16384] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="flash_mla_triton") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..544b5e1285c173e1521f049e1de9521baa53afee --- /dev/null +++ b/examples/deepseek_mla/benchmark_mla.py @@ -0,0 +1,628 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + +import tilelang +from tilelang.profiler import do_bench +from example_mla_decode_paged import mla_decode_tilelang + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@torch.inference_mode() +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + from flash_mla import flash_mla_with_kvcache, get_mla_metadata + + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + ) + + out_flash, lse_flash = flash_mla() + t = triton.testing.do_bench(flash_mla) + return out_flash, lse_flash, t + + +@torch.inference_mode() +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + # pip install flashinfer-python + import flashinfer + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + kv_indptr = [0] + kv_indices = [] + for i in range(b): + seq_len = cache_seqlens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_table[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + for seq_len in cache_seqlens[1:]: + kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) + + q_indptr = torch.arange(0, b + 1).int() * s_q + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + cache_seqlens, + h_q, + dv, + d - dv, + block_size, + causal, + 1 / math.sqrt(d), + q.dtype, + blocked_k.dtype, + ) + + def flashinfer(): + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True) + return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) + + out_flash, lse_flash = flashinfer() + t = triton.testing.do_bench(flashinfer) + return out_flash, lse_flash, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + num_warps=4, + num_stages=2, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), + o, + block_table, + cache_seqlens, + attn_logits, + num_kv_splits, + 1 / math.sqrt(d), + block_size, + ) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +@torch.inference_mode() +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = 64 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) + + def flash_mla_tilelang(): + out = kernel( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "tilelang": run_flash_mla_tilelang, + "flash_mla": run_flash_mla, + "flashinfer": run_flashinfer, + "flash_mla_triton": run_flash_mla_triton, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + # flashinfer has a different lse return value + # flash_mla_triton and flash_mla_tilelang doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + "tilelang", + "flash_mla", + "flashinfer", + "flash_mla_triton", +] + +shape_configs = [ + { + "b": batch, + "s_q": 1, + "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": head, + "h_kv": 1, + "d": 512 + 64, + "dv": 512, + "causal": True, + "dtype": torch.float16, + } + for batch in [128] + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] + for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="tilelang") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a( + target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) + elif args.compare: + perfa, prefb = compare_ab( + args.baseline, + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n" + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n" + ) + elif args.one: + perf = compare_a( + args.target, + shape["b"], + shape["s_q"], + shape["cache_seqlens"], + shape["h_q"], + shape["h_kv"], + shape["d"], + shape["dv"], + shape["causal"], + shape["dtype"], + ) + fout.write( + f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n" + ) diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..0d141b4b39500f7860a3308158e5bc5c30d4fb5d --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode.py @@ -0,0 +1,301 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): + scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = hid // (kv_group_num // block_H) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=2): + T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = hid // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = (seqlen_kv // num_split) * bz + k * block_N + kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, hid, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split[0] = glse[bz, hid, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..23001bde8a1fc7846d45f406d9e0db4d95252ece --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -0,0 +1,378 @@ +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse +from tilelang.profiler import do_bench +import math + + +@tilelang.jit( + out_idx=[8], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None): + if softmax_scale is None: + softmax_scale = (dv + dpe) ** -0.5 + scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = h_q // h_kv + VALID_BLOCK_H = min(block_H, kv_group_num) + assert h_kv == 1, "h_kv must be 1" + assert block_size >= block_N and block_size % block_N == 0, "block_size must be larger than block_N and a multiple of block_N" + + @T.macro + def flash_mla_kernel( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + CACHE_SEQLENS: T.Tensor([batch], T.int32), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) + for kr in T.Pipelined(loop_range, num_stages=2): + k = loop_range - 1 - kr + kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + if kr == 0: + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_mla_split_kv_kernel( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + CACHE_SEQLENS: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) + blocks_per_split = T.floordiv(total_blocks, num_split) + remaining_blocks = T.floormod(total_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0) + start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N + + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + with T.Kernel(h_q, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dv], dtype) + o_accum_local = T.alloc_fragment([dv], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dv): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dv): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dv): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + # q: [b, s_q, h_q, d] + # block_table: [b, max_seqlen_pad // block_size] + # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] + # cache_seqlens: [b] + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out.to(dtype), lse.to(dtype) + + out_torch, _ = ref_mla() + return out_torch + + +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = min(64, h_q // h_kv) + softmax_scale = d**-0.5 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + def flash_mla_tilelang(): + out = profiler.func( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) + print("All close") + return out_flash, t + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=128, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") + args = parser.parse_args() + b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv + + device = "cuda" + dtype = torch.float16 + + s_q = 1 # for decode, s_q = 1 + block_size = 64 + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) + dpe = d - dv + causal = True + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 + + total_flops = s_q * total_seqlens * h_q * d * 2 + + q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) + out_flash, latency = run_tilelang_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a1300a239b57de9ac82c4f1705bea20920e873 --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -0,0 +1,209 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from tilelang.carver.arch import driver +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + sm_num = driver.get_num_sms() + + @T.prim_func + def main_split_persistent( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(sm_num, threads=256) as (block_id): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + # O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + T.use_swizzle(10) + + total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split + waves = T.ceildiv(total_tiles, sm_num) + for w in T.serial(waves): + tile_id = sm_num * w + block_id + bid = tile_id // ((heads // min(block_H, kv_group_num)) * num_split) + hid = tile_id // num_split % (heads // min(block_H, kv_group_num)) + sid = tile_id % num_split + cur_kv_head = hid // (kv_group_num // block_H) + + if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split: + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = (seqlen_kv // num_split) * sid + k * block_N + kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N + T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) + T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid]) + # T.copy(acc_o, O_shared) + T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :]) + + T.sync_grid() + waves = T.ceildiv(heads * batch, sm_num) + for w in T.serial(waves): + tile_id = sm_num * w + block_id + hid = tile_id // batch + bid = tile_id % batch + if bid < batch and hid < heads: + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bid, hid, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bid, hid, k, i] + lse_local_split[0] = glse[bid, hid, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bid, hid, i] = o_accum_local[i] + + return main_split_persistent + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = 64 + num_split = 2 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + main() diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py new file mode 100644 index 0000000000000000000000000000000000000000..8e317fa00183e39581f460ff6159efa4aefb7d52 --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -0,0 +1,606 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, + compile_flags=[ + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + ], +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): + sm_scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid): + Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) + Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) + Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared_0_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_0_r = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_r = T.alloc_shared([block_N, dim // 2], dtype) + K_tail_shared_0 = T.alloc_shared([block_N, pe_dim], dtype) + K_tail_shared_1 = T.alloc_shared([block_N, pe_dim], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + acc_o_l = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_o_r = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + sumexp = T.alloc_fragment([block_H], accum_dtype) + sum_exp_shared = T.alloc_shared([block_H], accum_dtype) + sumexp_i = T.alloc_fragment([block_H], accum_dtype) + alpha_shared = T.alloc_shared([block_H], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([block_H], accum_dtype) + m_i = T.alloc_fragment([block_H], accum_dtype) + m_i_prev = T.alloc_fragment([block_H], accum_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + cur_kv_head = hid // (kv_group_num // block_H) + NI = T.ceildiv((seqlen_kv // num_split), block_N) + + tx = T.get_thread_binding() + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, out=m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(block_H): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(block_H): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim]) + + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + @T.macro + def flash_attn_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz): + Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype) + Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype) + Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype) + KV_shared_0_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_0_r = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_l = T.alloc_shared([block_N, dim // 2], dtype) + KV_shared_1_r = T.alloc_shared([block_N, dim // 2], dtype) + K_tail_shared_0 = T.alloc_shared([block_N, pe_dim], dtype) + K_tail_shared_1 = T.alloc_shared([block_N, pe_dim], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + acc_o_l = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_o_r = T.alloc_fragment([block_H, dim // 2], accum_dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + sumexp = T.alloc_fragment([block_H], accum_dtype) + sum_exp_shared = T.alloc_shared([block_H], accum_dtype) + sumexp_i = T.alloc_fragment([block_H], accum_dtype) + alpha_shared = T.alloc_shared([block_H], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([block_H], accum_dtype) + m_i = T.alloc_fragment([block_H], accum_dtype) + m_i_prev = T.alloc_fragment([block_H], accum_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + cur_kv_head = hid // (kv_group_num // block_H) + NI = T.ceildiv((seqlen_kv // num_split), block_N) + + tx = T.get_thread_binding() + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared) + + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + T.clear(acc_s) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(block_H): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(block_H): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(block_H, block_N): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(block_H): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(block_H): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(block_H): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2]) + T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(block_H, dim // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim]) + + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ + bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads, batch, threads=128) as (hid, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, hid, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, hid, k, i] + lse_local_split[0] = glse[bz, hid, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, hid, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim) ** -0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=132, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..fa39fa498f552c9409dfbd313a85a985e710c054 --- /dev/null +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -0,0 +1,150 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + q_dtype = T.float8_e4m3fn + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by): + Q_shared = T.alloc_shared([block_H, dim], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + qKV_shared = T.alloc_shared([block_N, dim], q_dtype) + KV_shared = T.alloc_shared([block_N, dim], dtype) + K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.disable_warp_group_reg_alloc() + loop_range = T.ceildiv(seqlen_kv, block_N) + for k in T.Pipelined(loop_range, num_stages=2): + T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared) + T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) + T.copy(qKV_shared, KV_shared) + + T.clear(acc_s) + T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :]) + + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = 64 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") diff --git a/examples/deepseek_mla/figures/bs128_float16.png b/examples/deepseek_mla/figures/bs128_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..3cf24c84b82532bf422efee26afe61b4ae0e1948 Binary files /dev/null and b/examples/deepseek_mla/figures/bs128_float16.png differ diff --git a/examples/deepseek_mla/figures/bs64_float16.png b/examples/deepseek_mla/figures/bs64_float16.png new file mode 100644 index 0000000000000000000000000000000000000000..15807c3d2e57f5a2848b792d0fe746db31be455d Binary files /dev/null and b/examples/deepseek_mla/figures/bs64_float16.png differ diff --git a/examples/deepseek_mla/figures/flashmla-amd.png b/examples/deepseek_mla/figures/flashmla-amd.png new file mode 100644 index 0000000000000000000000000000000000000000..75470bb30184b866402124fe1917eb7591623a7e Binary files /dev/null and b/examples/deepseek_mla/figures/flashmla-amd.png differ diff --git a/examples/deepseek_mla/figures/pv_layout.jpg b/examples/deepseek_mla/figures/pv_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..79b0c8cf301d9c04eef050c893156c71549ce03d Binary files /dev/null and b/examples/deepseek_mla/figures/pv_layout.jpg differ diff --git a/examples/deepseek_mla/figures/qk_layout.jpg b/examples/deepseek_mla/figures/qk_layout.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d5bd923d0d8ab1fe5edece222f31777ccd0d746 Binary files /dev/null and b/examples/deepseek_mla/figures/qk_layout.jpg differ diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..a269ea57aed102b83596d4c7a896322fb105fbfb --- /dev/null +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -0,0 +1,12 @@ +import tilelang.testing +import example_mla_decode + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mla_decode(): + example_mla_decode.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/deepseek_mla/torch_refs.py b/examples/deepseek_mla/torch_refs.py new file mode 100644 index 0000000000000000000000000000000000000000..aae6c7cd2b619afee90f39058cfd9a4a6a71e49e --- /dev/null +++ b/examples/deepseek_mla/torch_refs.py @@ -0,0 +1,81 @@ +import torch + +num_split = 1 + + +def flash_split_ref(Q, Q_pe, KV, K_pe): + dim = Q.shape[-1] + pe_dim = Q_pe.shape[-1] + batch = Q.size(0) + nheads = Q.size(1) + block_N = 64 + seqlen_kv = KV.size(1) + + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) + + Q_ = Q * scale + Q_pe_ = Q_pe * scale + KV_ = KV.expand(-1, -1, nheads, -1) + K_pe_ = K_pe.expand(-1, -1, nheads, -1) + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum( + "bhd,bkhd->bhk", + Q_, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] + acc_s += torch.einsum( + "bhd,bkhd->bhk", + Q_pe_, + K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] + scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] + acc_o *= scores_scale[:, :, None] + acc_s = torch.exp2(acc_s - scores_max[:, :, None]) + acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] + acc_o += torch.einsum( + "bhk,bkhd->bhd", + acc_s_cast, + KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, None] + logsum = torch.log2(logsum) + scores_max + gacc_o[ks, :, :, :] = acc_o + glogsum[ks, :, :] = logsum + + return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3) + + +def reduce_ref(Q, Q_pe, KV, K_pe, glse, Output_partial): + o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks] + scale = torch.exp2(lse - lse_logsum) + o += Output_partial[:, :, ks, :] * scale[:, :, None] + return o.to(torch.float16) diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..dadb4b4cb916ebc6f8ebce9120071be10a55e5e4 --- /dev/null +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -0,0 +1,954 @@ +# ruff: noqa + +import torch +import time +import argparse +import tilelang +from tilelang import language as T +import tilelang.testing +from typing import Optional, Union +from einops import rearrange, repeat +import triton +import triton.language as tl +from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_fwd, contiguous + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o_slc: torch.Tensor, + o_swa: Optional[torch.Tensor], + lse_slc: torch.Tensor, + lse_swa: Optional[torch.Tensor], + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None`. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + S = block_indices.shape[-1] + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + if isinstance(block_counts, torch.Tensor): + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o_slc = torch.zeros_like(v) + o_swa = torch.zeros_like(v) if window_size > 0 else None + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[i] + else: + s_b = block_counts + else: + T = cu_seqlens[i + 1] - cu_seqlens[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] + else: + s_b = block_counts + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [HQ] + g_slc_i = g_slc_b[i_q] + # [HQ] + g_swa_i = g_swa_b[i_q] + # [S*BS, HQ] + i_i = i_b[i_q] + # [HQ] + if isinstance(block_counts, torch.Tensor): + s_i = s_b[i_q] + else: + s_i = s_b + # [S*BS, HQ, -1] + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) + if not varlen: + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + else: + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + if window_size > 0: + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) + if not varlen: + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + else: + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + + if head_first: + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") + + return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) + + +def get_configs(): + import itertools + + iter_params = dict( + block_T=[128, 256, 512], + num_stages=[0, 1, 2, 4, 5], + threads=[32, 64, 128, 256, 512], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def tilelang_sparse_attention( + batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 +): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(block_T, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + + @T.prim_func + def tilelang_sparse_attention( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_shared([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)}) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return tilelang_sparse_attention + + +def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size): + """Generate random block indices for the benchmark.""" + block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda") + + for b in range(batch): + for t in range(seq_len): + for h in range(heads): + i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks] + block_indices[b, t, h, : len(i_i)] = i_i + + return block_indices.sort(-1)[0] + + +def benchmark_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): + """Benchmark the TileLang Sparse Attention implementation.""" + + # Set random seed for reproducibility + tilelang.testing.set_random_seed(0) + torch.random.manual_seed(0) + + # Compile the NSA kernel + kernel = tilelang_sparse_attention( + batch=batch_size, + heads=head_query, + seq_len=seq_len, + dim=dim, + is_causal=True, + block_size=block_size, + groups=head_query // heads, + selected_blocks=selected_blocks, + scale=scale, + ) + + profiler = kernel.get_profiler() + + profiler_latency = profiler.do_bench() + print(f"Profiler latency: {profiler_latency} ms") + + # Create input tensors + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + + # Generate block indices + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32) + + # Warmup + for _ in range(warmup): + kernel(Q, K, V, block_indices, out) + + # Synchronize before timing + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(iterations): + kernel(Q, K, V, block_indices, out) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate metrics + elapsed_time = end_time - start_time + avg_time = elapsed_time / iterations * 1000 # ms + + # Calculate FLOPs (approximate for NSA) + # Each token attends to selected_blocks * block_size tokens + # Each attention calculation involves 2*dim FLOPs for QK + # And another 2*dim FLOPs for attention * V + flops_per_token = 4 * dim * selected_blocks * block_size + total_flops = batch_size * seq_len * head_query * flops_per_token + flops_per_sec = total_flops / (elapsed_time / iterations) + tflops = flops_per_sec / 1e12 + + # Validate result against reference if requested + if validate: + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2) + if is_valid: + print("Validation: PASSED") + else: + print("Validation: FAILED") + print(f"Max difference: {(ref - out).abs().max().item()}") + + # Return benchmark results + return { + "avg_time_ms": avg_time, + "tflops": tflops, + "batch_size": batch_size, + "seq_len": seq_len, + "heads": heads, + "head_query": head_query, + "dim": dim, + "selected_blocks": selected_blocks, + "block_size": block_size, + } + + +def benchmark_triton_nsa( + batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False +): + """Benchmark the Triton-based TileLang Sparse Attention implementation.""" + + # Set random seed for reproducibility + tilelang.testing.set_random_seed(0) + torch.random.manual_seed(0) + + # Create input tensors + Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda") + g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda") + + # Generate block indices + block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size) + block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda") + o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda") + lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda") + + # Warmup + for _ in range(warmup): + out = parallel_nsa_fwd( + q=Q, + k=K, + v=V, + o_slc=o_slc, + o_swa=None, + lse_slc=lse_slc, + lse_swa=None, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=0, + scale=scale, + ) + + # Synchronize before timing + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(iterations): + out = parallel_nsa_fwd( + q=Q, + k=K, + v=V, + o_slc=o_slc, + o_swa=None, + lse_slc=lse_slc, + lse_swa=None, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=0, + scale=scale, + ) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate metrics + elapsed_time = end_time - start_time + avg_time = elapsed_time / iterations * 1000 # ms + + # Calculate FLOPs (approximate for NSA) + flops_per_token = 4 * dim * selected_blocks * block_size + total_flops = batch_size * seq_len * head_query * flops_per_token + flops_per_sec = total_flops / (elapsed_time / iterations) + tflops = flops_per_sec / 1e12 + + # Validate result against reference if requested + if validate: + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2) + if is_valid: + print("Validation: PASSED") + else: + print("Validation: FAILED") + print(f"Max difference: {(ref - out).abs().max().item()}") + + # Return benchmark results + return { + "avg_time_ms": avg_time, + "tflops": tflops, + "batch_size": batch_size, + "seq_len": seq_len, + "heads": heads, + "head_query": head_query, + "dim": dim, + "selected_blocks": selected_blocks, + "block_size": block_size, + } + + +def run_benchmark_suite(impl="all"): + """Run a suite of benchmarks with different configurations.""" + + # Define configurations to benchmark + configs = [ + # Small model config - Note: head_query must be a multiple of heads*16 for Triton + {"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32}, + # Medium model config + {"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64}, + # Large model config + {"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128}, + ] + + results = [] + for config in configs: + print(f"Running benchmark with config: {config}") + + if impl in ["all", "tilelang"]: + print("Benchmarking TileLang implementation:") + result = benchmark_nsa( + batch_size=config["batch_size"], + seq_len=config["seq_len"], + heads=config["heads"], + head_query=config["head_query"], + dim=config["dim"], + selected_blocks=config["selected_blocks"], + block_size=config["block_size"], + dtype=torch.float16, + scale=0.1, + validate=False, + ) + results.append({"impl": "tilelang", **result}) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") + + if impl in ["all", "triton"]: + print("Benchmarking Triton implementation:") + result = benchmark_triton_nsa( + batch_size=config["batch_size"], + seq_len=config["seq_len"], + heads=config["heads"], + head_query=config["head_query"], + dim=config["dim"], + selected_blocks=config["selected_blocks"], + block_size=config["block_size"], + dtype=torch.float16, + scale=0.1, + validate=False, + ) + results.append({"impl": "triton", **result}) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") + + if impl in ["all"]: + # Print comparison if both implementations were run + tilelang_result = next( + r + for r in results + if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) + triton_result = next( + r + for r in results + if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"] + ) + speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"] + print(f"Speedup (Triton vs TileLang): {speedup:.2f}x") + + print("-" * 50) + + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark TileLang Sparse Attention") + parser.add_argument("--batch", type=int, default=32, help="Batch size") + parser.add_argument("--seq_len", type=int, default=1024, help="Sequence length") + parser.add_argument("--heads", type=int, default=1, help="Number of heads") + parser.add_argument("--head_query", type=int, default=16, help="Number of query heads") + parser.add_argument("--dim", type=int, default=128, help="Head dimension") + parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") + parser.add_argument("--block_size", type=int, default=32, help="Block size") + parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)") + parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") + parser.add_argument("--iterations", type=int, default=100, help="Number of iterations") + parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") + parser.add_argument("--validate", action="store_true", help="Validate against reference") + parser.add_argument("--suite", action="store_true", help="Run benchmark suite") + parser.add_argument( + "--impl", + type=str, + default="all", + choices=["tilelang", "triton", "all"], + help="Implementation to benchmark (tilelang, triton, or all)", + ) + + args = parser.parse_args() + + # For Triton impl, ensure head_query is a multiple of heads*16 + if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0: + # Adjust head_query to nearest valid value + args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16) + print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation") + + if args.suite: + run_benchmark_suite(impl=args.impl) + else: + dtype = torch.float16 if args.dtype == T.float16 else torch.float32 + + if args.impl in ["tilelang", "all"]: + print("Benchmarking TileLang implementation:") + result = benchmark_nsa( + batch_size=args.batch, + seq_len=args.seq_len, + heads=args.heads, + head_query=args.head_query, + dim=args.dim, + selected_blocks=args.selected_blocks, + block_size=args.block_size, + dtype=dtype, + scale=args.scale, + warmup=args.warmup, + iterations=args.iterations, + validate=args.validate, + ) + print("\nBenchmark Results (TileLang):") + print( + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") + + if args.impl in ["triton", "all"]: + print("Benchmarking Triton implementation:") + result = benchmark_triton_nsa( + batch_size=args.batch, + seq_len=args.seq_len, + heads=args.heads, + head_query=args.head_query, + dim=args.dim, + selected_blocks=args.selected_blocks, + block_size=args.block_size, + dtype=dtype, + scale=args.scale, + warmup=args.warmup, + iterations=args.iterations, + validate=args.validate, + ) + print("\nBenchmark Results (Triton):") + print( + f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " + + f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " + + f"block_size={args.block_size}" + ) + print(f"Average time: {result['avg_time_ms']:.2f} ms") + print(f"Performance: {result['tflops']:.2f} TFLOPs") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..41f1dd86b99833d56b3164c865f55d6ae315311e --- /dev/null +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -0,0 +1,865 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange +import tilelang + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def tilelang_kernel_fwd( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, +): + from tilelang import language as T + + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + o_slc_shape = [batch, seq_len, heads, dim] + lse_slc_shape = [batch, seq_len, heads] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 0 + threads = 32 + + @T.prim_func + def native_sparse_attention( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for k in T.Parallel(G): + scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale) + for k, j in T.Parallel(G, BS): + acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for k in T.Parallel(G): + logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k] + T.copy(acc_s, acc_s_cast) + + # Rescale + for k, j in T.Parallel(G, BV): + acc_o[k, j] *= scores_scale[k] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy( + O_shared, + O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV], + ) + for i in T.Parallel(G): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G]) + + return native_sparse_attention + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def tilelang_kernel_bwd_dkv( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + dtype=T.float16, + accum_dtype=T.float32, +): + if scale is None: + sm_scale = (1.0 / dim) ** 0.5 + else: + sm_scale = scale + + scale = sm_scale * 1.44269504 + + from tilelang import language as T + + B = batch + BS = block_size + G = groups + V = dim + K = dim + BK = tilelang.next_power_of_2(K) + BV = min(128, tilelang.next_power_of_2(dim)) + NS = tilelang.cdiv(seq_len, BS) + NV = tilelang.cdiv(V, BV) + + heads_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + k_shape = [batch, seq_len, heads_kv, dim] + v_shape = [batch, seq_len, heads_kv, dim] + lse_slc_shape = [batch, seq_len, heads] + delta_slc_shape = [batch, seq_len, heads] + o_shape = [batch, heads, seq_len, dim] + do_slc_shape = [batch, seq_len, heads, dim] + dk_shape = [NV, batch, seq_len, heads_kv, dim] + dv_shape = [batch, seq_len, heads_kv, dim] + + block_mask_shape = [batch, seq_len, heads_kv, NS] + num_threads = 32 + print("NV", NV, "NS", NS, "B", B, "H", H) + + @T.prim_func + def flash_bwd_dkv( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), + ): + with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + Q_shared = T.alloc_shared([G, BK], dtype) + qkT = T.alloc_fragment([BS, G], accum_dtype) + qkT_cast = T.alloc_fragment([BS, G], dtype) + dsT = T.alloc_fragment([BS, G], accum_dtype) + dsT_cast = T.alloc_fragment([BS, G], dtype) + lse_shared = T.alloc_shared([G], accum_dtype) + delta = T.alloc_shared([G], accum_dtype) + + do = T.alloc_shared([G, BV], dtype) + dv = T.alloc_fragment([BS, BV], accum_dtype) + dk = T.alloc_fragment([BS, BK], accum_dtype) + dq = T.alloc_fragment([BS, G], accum_dtype) + + dv_shared = T.alloc_shared([BS, BV], dtype) + dk_shared = T.alloc_shared([BS, BK], dtype) + + i_b, i_h = i_bh // H, i_bh % H + + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) + + # [BS, BK] + T.clear(dk) + # [BS, BV] + T.clear(dv) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + loop_st = i_s * BS + loop_ed = seq_len + for i in T.Pipelined( + start=loop_st, + stop=loop_ed, + num_stages=0, + ): + b_m_slc = BlockMask[i_b, i, i_h, i_s] + if b_m_slc != 0: + # [G, BK] + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) + T.clear(qkT) + # [BS, BK] @ [G, BK] -> [BS, G] + T.gemm( + K_shared, + Q_shared, + qkT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + # [G] + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) + + # [G, BV] + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) + T.clear(dsT) + # [BS, BV] @ [G, BV] -> [BS, G] + T.gemm( + V_shared, + do, + dsT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(qkT, qkT_cast) + # [BS, G] @ [G, BV] -> [BS, BV] + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + # [G] + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) + for i, j in T.Parallel(BS, G): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + + # [BS, G] @ [G, BK] -> [BS, BK] + T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) + + return flash_bwd_dkv + + +def make_dq_layout(dQ): + from tilelang import language as T + + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout( + dQ.shape, + lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2], + ) + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def tilelang_kernel_bwd_dqkv( + batch, + heads, + seq_len, + dim, + is_causal, + scale=None, + block_size=64, + groups=1, + selected_blocks=16, + dtype=T.float16, + accum_dtype=T.float32, +): + if scale is None: + sm_scale = (1.0 / dim) ** 0.5 + else: + sm_scale = scale + + scale = sm_scale * 1.44269504 + + from tilelang import language as T + + B = batch + BS = block_size + G = groups + V = dim + K = dim + BK = tilelang.next_power_of_2(K) + BV = min(128, tilelang.next_power_of_2(dim)) + NS = tilelang.cdiv(seq_len, BS) + NV = tilelang.cdiv(V, BV) + + heads_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + k_shape = [batch, seq_len, heads_kv, dim] + v_shape = [batch, seq_len, heads_kv, dim] + lse_slc_shape = [batch, seq_len, heads] + delta_slc_shape = [batch, seq_len, heads] + o_shape = [batch, heads, seq_len, dim] + do_slc_shape = [batch, seq_len, heads, dim] + dq_shape = [NV, batch, seq_len, heads, dim] + dk_shape = [NV, batch, seq_len, heads_kv, dim] + dv_shape = [batch, seq_len, heads_kv, dim] + + block_mask_shape = [batch, seq_len, heads_kv, NS] + num_threads = 32 + + @T.prim_func + def flash_bwd_dqkv( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(k_shape, dtype), + V: T.Tensor(v_shape, dtype), + LSE_slc: T.Tensor(lse_slc_shape, accum_dtype), + Delta_slc: T.Tensor(delta_slc_shape, accum_dtype), + DO_slc: T.Tensor(do_slc_shape, dtype), + DQ: T.Tensor(dq_shape, dtype), + DK: T.Tensor(dk_shape, dtype), + DV: T.Tensor(dv_shape, dtype), + BlockMask: T.Tensor(block_mask_shape, T.int32), + ): + with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): + K_shared = T.alloc_shared([BS, BK], dtype) + dsT_shared = T.alloc_shared([BS, G], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + Q_shared = T.alloc_shared([G, BK], dtype) + qkT = T.alloc_fragment([BS, G], accum_dtype) + qkT_cast = T.alloc_fragment([BS, G], dtype) + dsT = T.alloc_fragment([BS, G], accum_dtype) + dsT_cast = T.alloc_fragment([BS, G], dtype) + lse_shared = T.alloc_shared([G], accum_dtype) + delta = T.alloc_shared([G], accum_dtype) + + do = T.alloc_shared([G, BV], dtype) + dv = T.alloc_fragment([BS, BV], accum_dtype) + dk = T.alloc_fragment([BS, BK], accum_dtype) + dq = T.alloc_fragment([G, BK], accum_dtype) + + dv_shared = T.alloc_shared([BS, BV], dtype) + dk_shared = T.alloc_shared([BS, BK], dtype) + + i_b, i_h = i_bh // H, i_bh % H + + T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared) + T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared) + + # [BS, BK] + T.clear(dk) + # [BS, BV] + T.clear(dv) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + loop_st = i_s * BS + loop_ed = seq_len + for i in T.Pipelined( + start=loop_st, + stop=loop_ed, + num_stages=0, + ): + b_m_slc = BlockMask[i_b, i, i_h, i_s] + if b_m_slc != 0: + # [G, BK] + T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared) + T.clear(qkT) + # [BS, BK] @ [G, BK] -> [BS, G] + T.gemm( + K_shared, + Q_shared, + qkT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + # [G] + T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j]) + + for _i, _j in T.Parallel(BS, G): + qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0) + + # [G, BV] + T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do) + T.clear(dsT) + # [BS, BV] @ [G, BV] -> [BS, G] + T.gemm( + V_shared, + do, + dsT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(qkT, qkT_cast) + # [BS, G] @ [G, BV] -> [BS, BV] + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + # [G] + T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta) + for _i, _j in T.Parallel(BS, G): + dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale + + # [BS, G] @ [G, BK] -> [BS, BK] + T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + # [BS, G] * [BS, BK] -> [G, BK] + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for _i, _j in T.Parallel(G, BK): + T.atomic_add(DQ[i_v, i_b, i, i_h * G + _i, _j], dq[_i, _j]) + + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV]) + T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK]) + + return flash_bwd_dqkv + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def tilelang_kernel_preprocess( + batch, + heads, + seq_len, + dim, + dtype=T.float16, + accum_dtype=T.float32, + blk=32, +): + from tilelang import language as T + + shape = [batch, seq_len, heads, dim] + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx]) + + return flash_bwd_prep + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def tilelang_kernel_block_mask( + batch, + heads, + seq_len, + selected_blocks, + block_size, + dtype=T.int32, +): + from tilelang import language as T + + block_indices_shape = [batch, seq_len, heads, selected_blocks] + block_counts_shape = [batch, seq_len, heads] + S = selected_blocks + BS = block_size + NS = tilelang.cdiv(seq_len, BS) + + block_mask_shape = [batch, seq_len, heads, NS] + USE_BLOCK_COUNTS = block_counts is not None + + @T.prim_func + def flash_bwd_block_mask( + BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore + BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore + BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore + ): + with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz): + i_t, i_b, i_hs = bx, by, bz + i_h, i_s = i_hs // S, i_hs % S + b_i = BlockIndices[i_b, i_t, i_h, i_s] + if USE_BLOCK_COUNTS: + b_m = b_i * BS <= i_t and i_s < BlockCounts[i_b, i_t, i_h].astype(i_s.dtype) + BlockMask[i_b, i_t, i_h, i_s] = b_m + else: + b_m = b_i * BS <= i_t + BlockMask[i_b, i_t, i_h, i_s] = b_m + + return flash_bwd_block_mask + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o_slc: torch.Tensor, + lse_slc: torch.Tensor, + do_slc: torch.Tensor, + o_swa: torch.Tensor, + lse_swa: torch.Tensor, + do_swa: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + window_size: int = 0, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + assert window_size == 0, "Window size is not supported yet" + delta_slc = tilelang_kernel_preprocess(B, HQ, T, K)(o_slc, do_slc) + + dq = torch.zeros(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool) + + fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv( + batch=B, + heads=HQ, + seq_len=T, + dim=K, + is_causal=True, + block_size=BS, + groups=G, + selected_blocks=S, + scale=scale, + ) + fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32)) + + dq = dq.sum(0) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + block_indices, + block_counts, + block_size, + window_size, + scale, + offsets, + ): + ctx.dtype = q.dtype + assert offsets is None, "Offsets are not supported yet" + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + B, SEQLEN, HQ, D = q.shape + H = k.shape[2] + G = HQ // H + S = block_indices.shape[-1] + V = v.shape[-1] + kernel = tilelang_kernel_fwd( + batch=B, + heads=HQ, + seq_len=SEQLEN, + dim=D, + is_causal=True, + scale=scale, + block_size=block_size, + groups=G, + selected_blocks=S, + ) + o_slc = torch.empty(B, SEQLEN, HQ, D, dtype=v.dtype, device=q.device) + lse_slc = torch.empty(B, SEQLEN, HQ, dtype=torch.float, device=q.device) + kernel(q, k, v, block_indices.to(torch.int32), o_slc, lse_slc) + + ctx.save_for_backward(q, k, v, o_slc, lse_slc) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), lse_slc.to(torch.float) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do_slc, do_swa): + q, k, v, o_slc, lse_slc = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=None, + lse_slc=lse_slc, + lse_swa=None, + do_slc=do_slc, + do_swa=do_swa, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + window_size=ctx.window_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices, + ) + return ( + dq.to(q), + dk.to(k), + dv.to(v), + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, SEQLEN, HQ, K]` if `head_first=False` else `[B, HQ, SEQLEN, K]`. + k (torch.Tensor): + keys of shape `[B, SEQLEN, H, K]` if `head_first=False` else `[B, H, SEQLEN, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, SEQLEN, H, V]` if `head_first=False` else `[B, H, SEQLEN, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`. + g_swa (torch.Tensor): + Gate score for sliding attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, SEQLEN, H, S]` if `head_first=False` else `[B, H, SEQLEN, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, SEQLEN, H]` if `head_first=True` else `[B, SEQLEN, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + B, T, H, HQ, D, S, block_size, dtype = 1, 32, 1, 16, 32, 1, 32, torch.float16 + torch.random.manual_seed(0) + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + ref.backward(do) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_size=block_size, + block_counts=block_counts, + ) + tri.backward(do) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + # assert_close(" o", ref, tri, 0.004) + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eea58049b388ceca54c7f1883d1d5a4ab755a6 --- /dev/null +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -0,0 +1,176 @@ +# ruff: noqa +import torch +from reference import naive_nsa_simple_inference +import tilelang +from tilelang import language as T +import tilelang.testing + +tilelang.testing.set_random_seed(42) + + +# TODO(lei): workaround, as threads is not divisible by warp group size, +# auto warp specialization may have some bugs. +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def native_sparse_attention( + batch, + heads, + seq_len, # Length of K/V sequences (context window size) + dim, # Embedding dimension per head + scale=None, + block_size=64, # Tile size for attention computation + groups=1, # Grouped query attention (GQA) groups + selected_blocks=16, # Number of blocks to select per attention head +): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + # Modified shapes for inference (q has seq_len=1)a + q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1 + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 0 + threads = 32 + + @T.prim_func + def native_sparse_attention( + Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim] + K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim] + V: T.Tensor(kv_shape, dtype), # Same shape as K + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices + Output: T.Tensor(q_shape, dtype), # Output attention tensor + ): + with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): + # Shared memory allocations for tile storage + Q_shared = T.alloc_shared([G, BK], dtype) # Current query block + K_shared = T.alloc_shared([BS, BK], dtype) # Current key block + V_shared = T.alloc_shared([BS, BV], dtype) # Current value block + O_shared = T.alloc_shared([G, BV], dtype) # Output accumulator + + # Attention computation buffers + acc_s = T.alloc_fragment([G, BS], accum_dtype) # QK^T scores + acc_s_cast = T.alloc_fragment([G, BS], dtype) # Casted scores for softmax + acc_o = T.alloc_fragment([G, BV], accum_dtype) # Output accumulator + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_v, i_bh = by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + # Copy Q for the single position + T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0 + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # Main attention computation loop over selected blocks + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset + if i_s >= 0: # Skip invalid/padding blocks + # Load current key block to shared memory + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + # Compute QK^T attention scores + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Online softmax with numerical stability + # 1. Compute max for scaling + # 2. Compute exponentials and sum + # 3. Maintain running logsum for normalization + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Accumulate attention-weighted values + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Final normalization and output + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] # Normalize by logsum + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0 + + return native_sparse_attention + + +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 + groups = HQ // H + SEQ_LEN_Q = 1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + ) + + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda") + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN_Q): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda") + + out = kernel(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa_simple_inference( + q=Q, + k=K, + v=V, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..ad36b10402429cdf24b87016c084b2790a0ba0eb --- /dev/null +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -0,0 +1,175 @@ +# ruff: noqa +import torch +from reference import naive_nsa +import tilelang +from tilelang import language as T +import tilelang.testing + +tilelang.testing.set_random_seed(0) + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + else: + scale = scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + block_indices_shape = [batch, seq_len, head_kv, selected_blocks] + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 2 + threads = 32 + + @T.prim_func + def native_sparse_attention( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_t, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + NS = S + T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[i_b, i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return native_sparse_attention + + +def main(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + ) + print(kernel.get_kernel_source()) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda") + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + out = kernel(Q, K, V, block_indices.to(torch.int32)) + + ref = naive_nsa( + q=Q, + k=K, + v=V, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + ) + + print("out", out) + print("ref", ref) + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..b52ebe42e210823de24107b9990cc111d5c8f1b3 --- /dev/null +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -0,0 +1,380 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import tilelang +from tilelang import language as T +import tilelang.testing + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from reference import naive_nsa +from einops import rearrange + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): + if scale is None: + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [c_seq_len, heads, dim] + kv_shape = [c_seq_len, head_kv, dim] + o_slc_shape = [c_seq_len, heads, dim] + o_swa_shape = [c_seq_len, heads, dim] + lse_slc_shape = [c_seq_len, heads] + lse_swa_shape = [c_seq_len, heads] + block_indices_shape = [c_seq_len, head_kv, selected_blocks] + block_counts_shape = [c_seq_len, head_kv] + offsets_shape = [batch + 1] + token_indices_shape = [c_seq_len, 2] + block_indices_dtype = T.int32 + block_counts_dtype = T.int32 + offsets_dtype = T.int32 + token_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 + block_S = block_size + block_T = min(128, tilelang.math.next_power_of_2(dim)) + + NK = tilelang.cdiv(dim, block_T) + NV = tilelang.cdiv(dim, block_T) + assert NK == 1, "The key dimension can not be larger than 256" + + S = selected_blocks + G = groups + BS = block_S + BK = BV = block_T + num_stages = 0 + threads = 32 + + @T.prim_func + def native_sparse_attention_varlen( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + O_slc: T.Tensor(o_slc_shape, dtype), + BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), + BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype), + Offsets: T.Tensor(offsets_shape, offsets_dtype), + TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype), + ): + with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([G, BK], dtype) + K_shared = T.alloc_shared([BS, BK], dtype) + V_shared = T.alloc_shared([BS, BV], dtype) + O_shared = T.alloc_shared([G, BV], dtype) + + acc_s = T.alloc_fragment([G, BS], accum_dtype) + acc_s_cast = T.alloc_fragment([G, BS], dtype) + acc_o = T.alloc_fragment([G, BV], accum_dtype) + scores_max = T.alloc_fragment([G], accum_dtype) + scores_max_prev = T.alloc_fragment([G], accum_dtype) + scores_scale = T.alloc_fragment([G], accum_dtype) + scores_sum = T.alloc_fragment([G], accum_dtype) + logsum = T.alloc_fragment([G], accum_dtype) + + i_c, i_v, i_bh = bx, by, bz + i_b, i_h = i_bh // head_kv, i_bh % head_kv + + i_n, i_t = TokenIndices[i_c, 0], TokenIndices[i_c, 1] + + bos = Offsets[i_n] + eos = Offsets[i_n + 1] + current_seq_len = eos - bos + + NS = BlockCounts[i_t, i_h] + T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for i in T.Pipelined(NS, num_stages=num_stages): + i_s = BlockIndices[bos + i_t, i_h, i] * BS + if i_s <= i_t and i_s >= 0: + # [BS, BK] + # Lei: may have some padding issues + # we should learn from mha varlen templates to handle this + T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared) + + if is_causal: + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=True) + for i in T.Parallel(G): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(G, BS): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(G): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(G, BV): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(G, BV): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) + + return native_sparse_attention_varlen + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + + batch = len(offsets) - 1 + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + + kernel = native_sparse_attention_varlen( + batch=batch, + heads=HQ, + c_seq_len=C_SEQ_LEN, + dim=K, + is_causal=True, + block_size=block_size, + groups=G, + selected_blocks=S, + ) + + o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device) + kernel( + q.view(C_SEQ_LEN, HQ, D), + k.view(C_SEQ_LEN, H, D), + v.view(C_SEQ_LEN, H, D), + o_slc.view(C_SEQ_LEN, HQ, V), + block_indices.to(torch.int32).view(C_SEQ_LEN, H, S), + block_counts.to(torch.int32).view(C_SEQ_LEN, H), + offsets.to(torch.int32), + token_indices.to(torch.int32), + ) + return o_slc + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + return o_slc.to(q.dtype) + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + assert False, "Window size is not supported yet" + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 + torch.manual_seed(42) + # randomly split the sequence into N segments + offsets = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.long), + torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]], + torch.tensor([C_SEQ_LEN], dtype=torch.long), + ], + 0, + ) + .cuda() + .sort()[0] + ) + + # seq-first required for inputs with variable lengths + perm_q = torch.randperm(C_SEQ_LEN, device="cuda") + perm_k = torch.randperm(C_SEQ_LEN, device="cuda") + perm_v = torch.randperm(C_SEQ_LEN, device="cuda") + q = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, HQ, D) + .clone() + .requires_grad_(True) + ) + k = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + v = ( + torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_v] + .view(1, C_SEQ_LEN, 1, 1) + .expand(1, C_SEQ_LEN, H, D) + .clone() + .requires_grad_(True) + ) + g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda") + + token_indices = prepare_token_indices(offsets).tolist() + block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda") + for i in range(C_SEQ_LEN): + _, t = token_indices[i] + for h in range(H): + i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S] + block_indices[0, i, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + print("tri", tri) + print("ref", ref) + + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) diff --git a/examples/deepseek_nsa/example_triton_nsa_bwd.py b/examples/deepseek_nsa/example_triton_nsa_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..af05bfa701654e3ec2dd53ffb2c0b50c61514801 --- /dev/null +++ b/examples/deepseek_nsa/example_triton_nsa_bwd.py @@ -0,0 +1,1008 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton +import triton.language as tl + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + # if USE_BLOCK_COUNTS: + # NS = tl.load(block_counts + (bos + i_t) * H + i_h) + # else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do_slc, do_swa): + q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + do_slc=do_slc, + do_swa=do_swa, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + window_size=ctx.window_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices, + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None + lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse_slc, + lse_swa, + delta_slc, + delta_swa, + do_slc, + do_swa, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): + i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BS, BK], dtype=tl.float32) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BS, BV], dtype=tl.float32) + + for i in range(i_s * BS, T): + b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s) + if b_m_slc: + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1)) + # [G] + b_lse_slc = tl.load(p_lse_slc) + b_delta_slc = tl.load(p_delta_slc) + # [BS, G] + b_s_slc = tl.dot(b_k, tl.trans(b_q)) + b_p_slc = tl.exp(b_s_slc - b_lse_slc[None, :]) + b_p_slc = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p_slc, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p_slc.to(b_do_slc.dtype), b_do_slc) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp_slc = tl.dot(b_v, tl.trans(b_do_slc)) + # [BS, G] + b_ds_slc = b_p_slc * (b_dp_slc - b_delta_slc[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds_slc.to(b_q.dtype), b_q) + + if WS > 0: + o_s = i_s * BS + tl.arange(0, BS) + if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1): + p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1)) + # [G] + b_lse_swa = tl.load(p_lse_swa) + b_delta_swa = tl.load(p_delta_swa) + # [BS, G] + b_s_swa = tl.dot(b_k, tl.trans(b_q)) + b_p_swa = tl.exp(b_s_swa - b_lse_swa[None, :]) + b_p_swa = tl.where((i >= o_s and (i - WS) < o_s)[:, None], b_p_swa, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p_swa.to(b_do_swa.dtype), b_do_swa) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp_swa = tl.dot(b_v, tl.trans(b_do_swa)) + # [BS, G] + b_ds_swa = b_p_swa * (b_dp_swa - b_delta_swa[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds_swa.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor)}) +@triton.jit +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_s = i_hs // S, i_hs % S + + b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s) + if USE_BLOCK_COUNTS: + b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h) + else: + b_m = b_i * BS <= i_t + + if b_i < NS and b_i >= 0: + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit(do_not_specialize=["T"]) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse_slc, + delta_slc, + do_slc, + lse_swa, + delta_swa, + do_swa, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += (bos + i_t) * HQ * K + do_slc += (bos + i_t) * HQ * V + lse_slc += (bos + i_t) * HQ + delta_slc += (bos + i_t) * HQ + if WS > 0: + do_swa += (bos + i_t) * HQ * V + lse_swa += (bos + i_t) * HQ + delta_swa += (bos + i_t) * HQ + dq += (i_v * B * T + bos + i_t) * HQ * K + block_indices += (bos + i_t) * H * S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + + p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do_slc = tl.make_block_ptr(do_slc, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + i_h * G + tl.arange(0, G) + p_delta_slc = delta_slc + i_h * G + tl.arange(0, G) + + # [G, BV] + b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1)) + # [G] + b_lse_slc = tl.load(p_lse_slc) + b_delta_slc = tl.load(p_delta_slc) + + # [G, BK] + b_dq_slc = tl.zeros([G, BK], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BV, BS] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_p_slc = tl.exp(b_s_slc - b_lse_slc[:, None]) + b_p_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_slc, 0) + + # [G, BV] @ [BV, BS] -> [G, BS] + b_dp_slc = tl.dot(b_do_slc, b_v_slc) + b_ds_slc = b_p_slc * (b_dp_slc.to(tl.float32) - b_delta_slc[:, None]) + # [G, BS] @ [BS, BK] -> [G, BK] + b_dq_slc += tl.dot(b_ds_slc.to(b_k_slc.dtype), tl.trans(b_k_slc)) + b_dq_slc *= scale + + if WS > 0: + p_do_swa = tl.make_block_ptr(do_swa, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + i_h * G + tl.arange(0, G) + p_delta_swa = delta_swa + i_h * G + tl.arange(0, G) + + # [G, BV] + b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1)) + # [G] + b_lse_swa = tl.load(p_lse_swa) + b_delta_swa = tl.load(p_delta_swa) + + # [G, BK] + b_dq_swa = tl.zeros([G, BK], dtype=tl.float32) + for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): + p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_swa = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1)) + # [BV, BS] + b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) + + # [G, BS] + b_s_swa = tl.dot(b_q, b_k_swa) + b_p_swa = tl.exp(b_s_swa - b_lse_swa[:, None]) + b_p_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_swa, 0) + + # [G, BV] @ [BV, BS] -> [G, BS] + b_dp_swa = tl.dot(b_do_swa, b_v_swa) + b_ds_swa = b_p_swa * (b_dp_swa.to(tl.float32) - b_delta_swa[:, None]) + # [G, BS] @ [BS, BK] -> [G, BK] + b_dq_swa += tl.dot(b_ds_swa.to(b_k_swa.dtype), tl.trans(b_k_swa)) + b_dq_swa *= scale + + if WS == 0: + tl.store(p_dq, b_dq_slc.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + else: + tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + if WS > 0: + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_swa = tl.zeros([G, BV], dtype=tl.float32) + + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_swa = tl.zeros([G], dtype=tl.float32) + for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): + p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1)) + # [BS, BV] + b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) + # [G, BS] + b_s_swa = tl.dot(b_q, b_k_swa) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) + + # [G] + b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa + b_r_swa = tl.exp(b_mp_swa - b_m_swa) + # [G, BS] + b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None]) + # [G] + b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1) + # [G, BV] + b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa) + + b_mp_swa = b_m_swa + b_o_swa = b_o_swa / b_acc_swa[:, None] + b_m_swa += tl.log(b_acc_swa) + tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty)) + + +@triton.jit +def parallel_nsa_bwd_kernel_preprocess(o, do, delta, B: tl.constexpr, V: tl.constexpr): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +def parallel_nsa_block_mask( + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + offsets: torch.LongTensor, + block_size: int, +): + B, T, H, S = block_indices.shape + BS = block_size + if offsets is not None: + NS = triton.cdiv(prepare_lens(offsets).max().item(), BS) + else: + NS = triton.cdiv(T, BS) + block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) + + parallel_nsa_kernel_mask[(T, B, H * S)]( + block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS + ) + return block_mask + + +def parallel_nsa_bwd_preprocess(o: torch.Tensor, do: torch.Tensor): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o_slc: torch.Tensor, + lse_slc: torch.Tensor, + do_slc: torch.Tensor, + o_swa: torch.Tensor, + lse_swa: torch.Tensor, + do_swa: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + window_size: int = 0, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + delta_slc = parallel_nsa_bwd_preprocess(o_slc, do_slc) + delta_swa = parallel_nsa_bwd_preprocess(o_swa, do_swa) if window_size > 0 else None + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse_slc=lse_slc, + delta_slc=delta_slc, + do_slc=do_slc, + lse_swa=lse_swa, + delta_swa=delta_swa, + do_swa=do_swa, + dq=dq, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + dq = dq.sum(0) + + if offsets is not None: + chunk_indices = prepare_chunk_indices(offsets, BS) + NS = len(chunk_indices) + else: + chunk_indices = None + NS = triton.cdiv(T, BS) + + # [B, T, H, M] + block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size) + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NS, B * H) + parallel_nsa_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse_slc=lse_slc, + lse_swa=lse_swa, + delta_slc=delta_slc, + delta_swa=delta_swa, + do_slc=do_slc, + do_swa=do_swa, + dk=dk, + dv=dv, + block_mask=block_mask, + offsets=offsets, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + M=block_mask.shape[-1], + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do_slc, do_swa): + q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + do_slc=do_slc, + do_swa=do_swa, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + window_size=ctx.window_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices, + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 + torch.random.manual_seed(0) + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + ref.backward(do) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_size=block_size, + block_counts=block_counts, + ) + print("tri", tri) + print("ref", ref) + tri.backward(do) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None + + # assert_close(" o", ref, tri, 0.004) + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2) diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd.py b/examples/deepseek_nsa/example_triton_nsa_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ab28daaf931ebc7565130343f8e4c15a1570d2 --- /dev/null +++ b/examples/deepseek_nsa/example_triton_nsa_fwd.py @@ -0,0 +1,357 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton +import triton.language as tl + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + # if USE_BLOCK_COUNTS: + # NS = tl.load(block_counts + (bos + i_t) * H + i_h) + # else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None + lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16 + torch.random.manual_seed(0) + q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) + k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True) + g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda") + + block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda") + for b in range(B): + for t in range(T): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + ) + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_size=block_size, + block_counts=block_counts, + ) + + print("tri", tri) + print("ref", ref) + + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) diff --git a/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4eb6d7ba6119a0ebf16700d65b55b1fd1a237b --- /dev/null +++ b/examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py @@ -0,0 +1,392 @@ +# ruff: noqa +import torch +from typing import Optional, Union +from packaging.version import parse + +import torch +import triton +import triton.language as tl + +import fla + +if parse(fla.__version__) < parse("0.2.1"): + from fla.ops.common.utils import prepare_token_indices +else: + from fla.ops.utils import prepare_token_indices +from fla.utils import autocast_custom_fwd, contiguous +from reference import naive_nsa +from einops import rearrange + + +@triton.heuristics( + { + "USE_OFFSETS": lambda args: args["offsets"] is not None, + "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor), + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["BS", "BK", "BV"], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o_slc, + o_swa, + lse_slc, + lse_swa, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + WS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H * S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_slc = tl.zeros([G, BV], dtype=tl.float32) + + b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_slc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1)) + # [BS, BV] + b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1)) + # [G, BS] + b_s_slc = tl.dot(b_q, b_k_slc) + b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf")) + + # [G] + b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc + b_r_slc = tl.exp(b_mp_slc - b_m_slc) + # [G, BS] + b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None]) + # [G] + b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1) + # [G, BV] + b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc) + + b_mp_slc = b_m_slc + b_o_slc = b_o_slc / b_acc_slc[:, None] + b_m_slc += tl.log(b_acc_slc) + tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty)) + + if WS > 0: + p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o_swa = tl.zeros([G, BV], dtype=tl.float32) + + b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32) + b_acc_swa = tl.zeros([G], dtype=tl.float32) + for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS): + p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1)) + p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1)) + # [BS, BV] + b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1)) + # [G, BS] + b_s_swa = tl.dot(b_q, b_k_swa) + b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf")) + + # [G] + b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa + b_r_swa = tl.exp(b_mp_swa - b_m_swa) + # [G, BS] + b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None]) + # [G] + b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1) + # [G, BV] + b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa) + + b_mp_swa = b_m_swa + b_o_swa = b_o_swa / b_acc_swa[:, None] + b_m_swa += tl.log(b_acc_swa) + tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty)) + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + window_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + WS = window_size + if torch.cuda.get_device_capability()[0] >= 9: + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None + lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o_slc=o_slc, + o_swa=o_swa, + lse_slc=lse_slc, + lse_swa=lse_swa, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + WS=WS, + BK=BK, + BV=BV, + ) + return o_slc, lse_slc, o_swa, lse_swa + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + window_size=window_size, + scale=scale, + offsets=offsets, + token_indices=token_indices, + ) + ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.window_size = window_size + ctx.scale = scale + return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None` + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + if isinstance(block_counts, int): + block_indices = block_indices[:, :, :, :block_counts] + block_counts = None + + o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens) + if window_size > 0: + o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1)) + else: + o = o_slc * g_slc.unsqueeze(-1) + if head_first: + o = rearrange(o, "b t h d -> b h t d") + return o + + +if __name__ == "__main__": + N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16 + torch.manual_seed(42) + # randomly split the sequence into N segments + offsets = ( + torch.cat( + [torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)], + 0, + ) + .cuda() + .sort()[0] + ) + # offsets.shape is [N+1] + # seq-first required for inputs with variable lengths + perm_q = torch.randperm(T, device="cuda") + perm_k = torch.randperm(T, device="cuda") + perm_v = torch.randperm(T, device="cuda") + q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True) + k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True) + g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda") + + token_indices = prepare_token_indices(offsets).tolist() + block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda") + for i in range(T): + _, t = token_indices[i] + for h in range(H): + i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] + block_indices[0, i, h, : len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda") + + ref = naive_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + tri = parallel_nsa( + q=q, + k=k, + v=v, + g_slc=g_slc, + g_swa=g_swa, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + cu_seqlens=offsets, + ) + + print("tri", tri) + print("ref", ref) + + torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) diff --git a/examples/deepseek_nsa/reference.py b/examples/deepseek_nsa/reference.py new file mode 100644 index 0000000000000000000000000000000000000000..58083108eb30e871fba15b60a9f36bacee9c3949 --- /dev/null +++ b/examples/deepseek_nsa/reference.py @@ -0,0 +1,305 @@ +# ruff: noqa +from typing import Optional + +import torch +from typing import Union +from einops import rearrange, repeat + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Optional[Union[torch.LongTensor, int]] = None, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (Union[torch.LongTensor, int]): + Number of selected blocks for each token. + If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`, + each token can select the same number of blocks. + If not provided, it will default to `S`, Default: `None`. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices)) + g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa)) + if isinstance(block_counts, torch.Tensor): + block_counts = rearrange(block_counts, "b h t -> b t h") + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + S = block_indices.shape[-1] + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + if isinstance(block_counts, torch.Tensor): + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o_slc = torch.zeros_like(v) + o_swa = torch.zeros_like(v) if window_size > 0 else None + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i] + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[i] + else: + s_b = block_counts + else: + T = cu_seqlens[i + 1] - cu_seqlens[i] + q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map( + lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices) + ) + if isinstance(block_counts, torch.Tensor): + s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]] + else: + s_b = block_counts + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [HQ] + g_slc_i = g_slc_b[i_q] + # [HQ] + g_swa_i = g_swa_b[i_q] + # [S*BS, HQ] + i_i = i_b[i_q] + # [HQ] + if isinstance(block_counts, torch.Tensor): + s_i = s_b[i_q] + else: + s_i = s_b + # [S*BS, HQ, -1] + k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn_slc = ( + torch.einsum("h d, n h d -> n h", q_i, k_i_slc) + .masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf")) + .softmax(0) + ) + if not varlen: + o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + else: + o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1) + if window_size > 0: + k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b)) + attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0) + if not varlen: + o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + else: + o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1) + + if head_first: + o_slc = rearrange(o_slc, "b t h d -> b h t d") + o_swa = rearrange(o_swa, "b t h d -> b h t d") + + return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype) + + +def naive_nsa_simple( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: torch.LongTensor, + block_size: int = 64, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (torch.LongTensor): + Block counts of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + block_size (int): + Selected block size. Default: 64. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + scale = k.shape[-1] ** -0.5 + + dtype = q.dtype + HQ = q.shape[2] + H = k.shape[2] + D = k.shape[-1] + G = HQ // H + BS = block_size + S = block_indices.shape[-1] + SELECTED_BLOCKS_SIZE = S * BS + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + o = torch.zeros_like(v) + B, T = q.shape[:2] + + for i in range(B): + q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i] + # [T, HQ, S, BS] -> [T, HQ, S*BS] + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, HQ, S*BS] -> [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [S*BS, HQ] -> represents selected blocks for each query token + i_i = i_b[i_q] + # [HQ] -> represents the number of selected blocks for each query token + s_i = s_b[i_q] + + k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype) + v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype) + + for h in range(HQ): + for t in range(SELECTED_BLOCKS_SIZE): + selected_block_index = i_i[t, h] + k_i[t, h] = k_b[selected_block_index, h, :] + v_i[t, h] = v_b[selected_block_index, h, :] + + # [S*BS, HQ] + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf")) + attn = torch.softmax(attn, dim=0) + o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i) + + return o.to(dtype) + + +def naive_nsa_simple_inference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: torch.LongTensor, + block_size: int = 64, +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper. + block_counts (torch.LongTensor): + Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`. + block_size (int): + Selected block size. Default: 64. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + scale = k.shape[-1] ** -0.5 + + dtype = q.dtype + HQ = q.shape[2] + H = k.shape[2] + D = k.shape[-1] + G = HQ // H + BS = block_size + S = block_indices.shape[-1] + SELECTED_BLOCKS_SIZE = S * BS + k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices)) + block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G) + c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device) + q, k, v = map(lambda x: x.float(), (q, k, v)) + o = torch.zeros_like(q) + B, T = q.shape[:2] + + for i in range(B): + q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i] + # [T, HQ, S, BS] -> [T, HQ, S*BS] + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, HQ, S*BS] -> [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + + # [HQ, D] + q_i = q_b[0] * scale + # [S*BS, HQ] -> represents selected blocks for each query token + i_i = i_b[0] + # [HQ] -> represents the number of selected blocks for each query token + s_i = s_b[0] + + k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype) + v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype) + + for h in range(HQ): + for t in range(SELECTED_BLOCKS_SIZE): + selected_block_index = i_i[t, h] + k_i[t, h] = k_b[selected_block_index, h, :] + v_i[t, h] = v_b[selected_block_index, h, :] + + # [S*BS, HQ] + attn = torch.einsum("h d, n h d -> n h", q_i, k_i) + attn = attn.masked_fill((c >= s_i), float("-inf")) + attn = torch.softmax(attn, dim=0) + o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i) + + return o.to(dtype) diff --git a/examples/deepseek_nsa/requirements.txt b/examples/deepseek_nsa/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..777c2ad4c81bbf9c00a4fca8361c7dd9dfb39d0e --- /dev/null +++ b/examples/deepseek_nsa/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e \ No newline at end of file diff --git a/examples/deepseek_nsa/test_example_tilelang_nsa.py b/examples/deepseek_nsa/test_example_tilelang_nsa.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc6f98e962c167bf3f2783910c7dda9bd624373 --- /dev/null +++ b/examples/deepseek_nsa/test_example_tilelang_nsa.py @@ -0,0 +1,17 @@ +# ruff: noqa +import tilelang.testing + +from example_tilelang_nsa_fwd import main as main_fwd +from example_tilelang_nsa_decode import main as main_fwd_decode + + +def test_example_tilelang_nsa_fwd(): + main_fwd() + + +def test_example_tilelang_nsa_fwd_decode(): + main_fwd_decode() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/deepseek_v32/README.md b/examples/deepseek_v32/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8457745b0e9ea4aea5e5f06c5417859cb51a41bc --- /dev/null +++ b/examples/deepseek_v32/README.md @@ -0,0 +1,223 @@ +## Directory Structure + +``` +deepseek_v32/ +โ”œโ”€โ”€ README.md # This file +โ”œโ”€โ”€ figures/ # Figures and diagrams +โ”œโ”€โ”€ inference/ # Inference implementation folder +โ”œโ”€โ”€ fp8_lighting_indexer.py # FP8 lighting indexer +โ”œโ”€โ”€ sparse_mla_bwd.py # Sparse MLA backward implementation +โ”œโ”€โ”€ sparse_mla_fwd.py # Sparse MLA forward implementation +โ”œโ”€โ”€ sparse_mla_fwd_pipelined.py # Pipelined implementation of sparse MLA forward pass +โ”œโ”€โ”€ topk_selector.py # Top-k selector implementation +``` + +## File Descriptions + +### Architecture Overview + +![DeepSeek V3.2 Architecture](./figures/v32_arch.png) + +The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations: + +1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision +2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation +3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward and backward passes + +### Lightning Indexer + +Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector. + +The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices: + +```python +T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, +) +``` + +After the matmul, we apply ReLU and aggregate across heads with learned weights: + +```python +for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = ( + T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i] + ) * index_k_scale_fragment[bn_i] + +T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) +``` + +The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions: + +```python +for bq_i in T.serial(block_Q): + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) +for bq_i in T.serial(block_Q): + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) +``` + +The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training. + +### Top-k Selector + +The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process. + +The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence: + +```python +for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + input_idx = s*BLOCK_SIZE+tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + inval_int16 = convert_to_uint16(input[bx, input_idx]) + T.atomic_add(s_histogram[inval_int16], 1) +``` + +The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. After building a histogram and doing a cumulative sum, we find the threshold bin: + +```python +if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx +``` + +Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing: + +```python +if l_bin_id32 > l_threshold_bin_id: + pos = T.atomic_add(s_histogram[l_bin_id32+1], 1, return_prev=True) + index[bx, pos] = input_idx +elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + pos = T.atomic_add(s_num_input[0], 1, return_prev=True) + s_input_idx[0, pos] = input_idx +``` + +Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence. + +### Sparse MLA Forward + +The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens. + +Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions: + +```python +# Dense MLA: iterate over full sequence +loop_range = T.ceildiv(seqlen_kv, block_N) +for k in T.Pipelined(loop_range, num_stages=2): + T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared) + # ... compute attention over this block +``` + +Sparse MLA only loads KV positions selected by the top-k selector: + +```python +# Sparse MLA: iterate over selected indices only +for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] + # ... compute attention over selected tokens +``` + +This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid: + +```python +for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i +``` + +Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA. + +### Sparse MLA Forward (Pipelined) + +The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines. + +The key difference is splitting the warp groups into specialized roles: + +```python +if tx < 128: + # Consumer 0: computes left half of output (D//2 dimensions) + # Handles QK matmul, softmax, and PV for left half + +elif tx >= 128 and tx < 256: + # Consumer 1: computes right half of output (D//2 dimensions) + # Only does PV matmul for right half + +elif tx >= 256: + # Producer: loads KV data from global memory + # Uses async copy with barriers to feed consumers +``` + +The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed: + +```python +# Producer alternates between two buffers +for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + # ... load KV into buffer 0 + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + # ... load KV into buffer 1 + T.cp_async_barrier_noinc(bar_k_1_ready[0]) +``` + +Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul. + +### Sparse MLA Backward + +The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity. + +The backward pass consists of three main stages: + +**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient): + +```python +for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] +T.reduce_sum(acc, delta, 1) +``` + +**2. Main Backward Computation**: Computes gradients through sparse attention: + +```python +# Sparse MLA backward: iterate over selected indices only +for i_i in T.Pipelined(NI, num_stages=num_stages): + # Load KV data for selected indices + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i] + + # Recompute attention scores for backward + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + # Apply softmax gradient: dP = P * (dP_raw - Delta) + for h_i, bi_i in T.Parallel(padded_H, BI): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale +``` + +The key gradient computations are: +- **dQ = dP @ K** (query gradients) +- **dK = dP^T @ Q** (key gradients) +- **dV = P^T @ dO** (value gradients) + +**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation: + +```python +# Atomically update dKV at selected indices +for bi_i, d_i in T.Parallel(BI // split_store, D // 4): + T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4]) +``` + +**Performance**: The sparse MLA backward achieves excellent performance: +- **H800 SXM**: ~100 TFlops +- **H200 SXM**: ~115 TFlops + +The implementation efficiently handles the irregular memory access patterns inherent in sparse attention while maintaining high compute utilization through careful memory management and atomic update strategies. Note that this is a relatively naive implementation that requires further optimization. diff --git a/examples/deepseek_v32/figures/v32_arch.png b/examples/deepseek_v32/figures/v32_arch.png new file mode 100644 index 0000000000000000000000000000000000000000..50f3a847b509868c7b04af20e1edb81b54bc6bb6 Binary files /dev/null and b/examples/deepseek_v32/figures/v32_arch.png differ diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..01ad0a73469b7b7feff0a58f7918d3d144ec3c19 --- /dev/null +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -0,0 +1,284 @@ +# ruff: noqa +import itertools +import tilelang +from tilelang import language as T +import torch +from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8 + + +def display_error_message(msg): + print(f"\033[31mWARNING: {msg}\033[0m") + + +def compute_correlation(a, b, label="tensor"): + a, b = a.data.double(), b.data.double() + norm_sum = (a * a + b * b).sum() + if norm_sum == 0: + display_error_message(f"{label} all zero") + return 1 + correlation = 2 * (a * b).sum() / norm_sum + return correlation + + +def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True): + a_finite = torch.isfinite(a) + b_finite = torch.isfinite(b) + if not torch.all(a_finite == b_finite): + display_error_message(f"{tensor_name} Error: isfinite mask mismatch") + if should_raise: + assert False + if not torch.isclose( + a.masked_fill(a_finite, 0), + b.masked_fill(b_finite, 0), + rtol=0, + atol=0, + equal_nan=True, + ).all(): + display_error_message(f"{tensor_name} Error: nonfinite value mismatch") + if should_raise: + assert False + a = a.masked_fill(~a_finite, 0) + b = b.masked_fill(~b_finite, 0) + correlation = compute_correlation(a, b, tensor_name) + difference = 1.0 - correlation + if not (0 <= difference <= tolerance): + display_error_message(f"{tensor_name} Error: {difference}") + if should_raise: + assert False + return difference + + +def get_configs(): + iter_params = dict( + block_N=[32, 64, 128], + num_stages=[0, 1, 2], + threads=[128, 256], + block_Q=[1, 2, 4], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +class SupplyProg: + def __init__(self): + self.tensors_dict = {} + + def get_key(self, shape, dtype) -> str: + return f"{shape}-{dtype}" + + def supply_prog(self, params): + shapes = [p.shape for p in params] + dtypes = [p.dtype for p in params] + tensor_list = [] + for shape, dtype in zip(shapes, dtypes): + key = self.get_key(shape, dtype) + if key not in self.tensors_dict: + self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda") + tensor_list.append(self.tensors_dict[key]) + else: + tensor_list.append(self.tensors_dict[key]) + return tensor_list + + +supply_prog = SupplyProg() + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def mqa_attn_return_logits( + heads, + index_dim, + block_N=256, + num_stages=3, + threads=512, + block_Q=None, +): + if block_Q is None: + block_Q = 128 // heads + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 + + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + index_q_shape = [seq_len * heads, index_dim] + index_k_shape = [seq_len_kv, index_dim] + index_k_scale_shape = [seq_len_kv] + logits_shape = [seq_len, seq_len_kv] + + @T.prim_func + def mqa_attn_return_logits_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], dtype) + index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_local([1], index_dtype) + cu_k_e_max = T.alloc_local([1], index_dtype) + + cu_k_s_min[0] = 2147483647 + cu_k_e_max[0] = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) + for bq_i in T.serial(block_Q): + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): + T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[ + bn_i + ] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] + + return mqa_attn_return_logits_kernel + + +@tilelang.jit +def clean_logits_( + threads: int = 512, + block_K: int = 4096, +): + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + dtype = T.float + indices_dtype = T.int32 + + @T.prim_func + def clean_logits_kernel( + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + ): + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = T.alloc_local([1], indices_dtype) + cu_k_e = T.alloc_local([1], indices_dtype) + cu_k_s[0] = CuSeqLenKS[bx] + cu_k_e[0] = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx < cu_k_s[0] or idx >= cu_k_e[0]: + Logits[bx, idx] = -T.infinity(dtype) + + return clean_logits_kernel + + +def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): + seq_len, heads, index_dim = q.shape + seq_len_kv = kv.shape[0] + + clean_logits_kernel = clean_logits_() + + mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim) + logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) + mqa_attn_return_logits_kernel( + q.view(seq_len * heads, index_dim), + kv, + kv_scales, + logits, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) + if clean_logits: + clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits + + +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): + k = kv + q = q.float() + k = k.float() + + seq_len_kv = kv.shape[0] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + cost = mask.sum() + return logits, cost + + +def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + # initial random seed to make the performance reproducible + torch.manual_seed(0) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) + + ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + + logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + + print(f"diff: {diff}") + + from tilelang.profiler import do_bench + + def logits_fn(): + return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + logits_fn() + + print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) + + logits_ms = do_bench(logits_fn, warmup=100, rep=100) + logits_flops = 2 * cost_ref * H * D + logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12 + print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}") + print(f"cost_ref: {cost_ref}") + + +if __name__ == "__main__": + test_fp8_lighting_indexer() diff --git a/examples/deepseek_v32/inference/README.md b/examples/deepseek_v32/inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fe4cc21bba684273345706317f49012f4bb96d71 --- /dev/null +++ b/examples/deepseek_v32/inference/README.md @@ -0,0 +1,14 @@ +# DeepSeek V3.2 + +First convert huggingface model weights to the the format required by our inference demo. Set `MP` to match your available GPU count: +```bash +cd inference +export EXPERTS=256 +python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP} +``` + +Launch the interactive chat interface and start exploring DeepSeek's capabilities: +```bash +export CONFIG=config_671B_v3.2.json +torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive +``` \ No newline at end of file diff --git a/examples/deepseek_v32/inference/config_671B_v3.2.json b/examples/deepseek_v32/inference/config_671B_v3.2.json new file mode 100644 index 0000000000000000000000000000000000000000..be88f1cca20c7dc78d8459c4c8456c197cba0b5a --- /dev/null +++ b/examples/deepseek_v32/inference/config_671B_v3.2.json @@ -0,0 +1,26 @@ +{ + "vocab_size": 129280, + "dim": 7168, + "inter_dim": 18432, + "moe_inter_dim": 2048, + "n_layers": 61, + "n_dense_layers": 3, + "n_heads": 128, + "n_routed_experts": 256, + "n_shared_experts": 1, + "n_activated_experts": 8, + "n_expert_groups": 8, + "n_limited_groups": 4, + "route_scale": 2.5, + "score_func": "sigmoid", + "q_lora_rank": 1536, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "dtype": "fp8", + "scale_fmt": "ue8m0", + "index_n_heads": 64, + "index_head_dim": 128, + "index_topk": 2048 +} \ No newline at end of file diff --git a/examples/deepseek_v32/inference/convert.py b/examples/deepseek_v32/inference/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..df7943918f80557af7a0485b7d3591d070ffcbab --- /dev/null +++ b/examples/deepseek_v32/inference/convert.py @@ -0,0 +1,100 @@ +import os +import shutil +from argparse import ArgumentParser +from glob import glob +from tqdm import tqdm, trange + +import torch +from safetensors.torch import safe_open, save_file + +mapping = { + "embed_tokens": ("embed", 0), + "input_layernorm": ("attn_norm", None), + "post_attention_layernorm": ("ffn_norm", None), + "q_proj": ("wq", 0), + "q_a_proj": ("wq_a", None), + "q_a_layernorm": ("q_norm", None), + "q_b_proj": ("wq_b", 0), + "kv_a_proj_with_mqa": ("wkv_a", None), + "kv_a_layernorm": ("kv_norm", None), + "kv_b_proj": ("wkv_b", 0), + "o_proj": ("wo", 1), + "gate": ("gate", None), + "gate_proj": ("w1", 0), + "down_proj": ("w2", 1), + "up_proj": ("w3", 0), + "norm": ("norm", None), + "lm_head": ("head", 0), + "scale": ("scale", None), + "wq_b": ("wq_b", None), + "wk": ("wk", None), + "k_norm": ("k_norm", None), + "weights_proj": ("weights_proj", None), +} + + +def main(hf_ckpt_path, save_path, n_experts, mp): + """ + Converts and saves model checkpoint files into a specified format. + + Args: + hf_ckpt_path (str): Path to the directory containing the input checkpoint files. + save_path (str): Path to the directory where the converted checkpoint files will be saved. + n_experts (int): Total number of experts in the model. + mp (int): Model parallelism factor. + + Returns: + None + """ + torch.set_num_threads(8) + n_local_experts = n_experts // mp + state_dicts = [{} for _ in range(mp)] + + for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + with safe_open(file_path, framework="pt", device="cpu") as f: + for name in f.keys(): + if "model.layers.61" in name: + continue + param: torch.Tensor = f.get_tensor(name) + if name.startswith("model."): + name = name[len("model."):] + name = name.replace("self_attn", "attn") + name = name.replace("mlp", "ffn") + name = name.replace("weight_scale_inv", "scale") + name = name.replace("e_score_correction_bias", "bias") + key = name.split(".")[-2] + assert key in mapping, f"Key {key} not found in mapping" + new_key, dim = mapping[key] + name = name.replace(key, new_key) + for i in range(mp): + new_param = param + if "experts" in name and "shared_experts" not in name: + idx = int(name.split(".")[-3]) + if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: + continue + elif dim is not None: + assert param.size( + dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" + shard_size = param.size(dim) // mp + new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() + state_dicts[i][name] = new_param + + os.makedirs(save_path, exist_ok=True) + + for i in trange(mp): + save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) + + for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): + new_file_path = os.path.join(save_path, os.path.basename(file_path)) + shutil.copyfile(file_path, new_file_path) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--hf-ckpt-path", type=str, required=True) + parser.add_argument("--save-path", type=str, required=True) + parser.add_argument("--n-experts", type=int, required=True) + parser.add_argument("--model-parallel", type=int, required=True) + args = parser.parse_args() + assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism" + main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) diff --git a/examples/deepseek_v32/inference/generate.py b/examples/deepseek_v32/inference/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..fda1e80968dc610574f576886e2b33c6ee24e56f --- /dev/null +++ b/examples/deepseek_v32/inference/generate.py @@ -0,0 +1,197 @@ +import os +import json +from argparse import ArgumentParser +from typing import List + +import torch +import torch.distributed as dist +from transformers import AutoTokenizer +from safetensors.torch import load_model + +from model import Transformer, ModelArgs + + +def sample(logits, temperature: float = 1.0): + """ + Samples a token from the logits using temperature scaling. + + Args: + logits (torch.Tensor): The logits tensor for token predictions. + temperature (float, optional): Temperature for scaling logits. Defaults to 1.0. + + Returns: + torch.Tensor: The sampled token. + """ + logits = logits / max(temperature, 1e-5) + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) + + +@torch.inference_mode() +def generate(model: Transformer, + prompt_tokens: List[List[int]], + max_new_tokens: int, + eos_id: int, + temperature: float = 1.0) -> List[List[int]]: + """ + Generates new tokens based on the given prompt tokens using the specified model. + + Args: + model (Transformer): The transformer model used for token generation. + prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence. + max_new_tokens (int): The maximum number of new tokens to generate. + eos_id (int): The end-of-sequence token ID. + temperature (float, optional): The temperature value for sampling. Defaults to 1.0. + + Returns: + List[List[int]]: A list of lists containing the generated tokens for each sequence. + """ + prompt_lens = [len(t) for t in prompt_tokens] + assert max( + prompt_lens + ) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})" + total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) + tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") + for i, t in enumerate(prompt_tokens): + tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + prev_pos = 0 + finished = torch.tensor([False] * len(prompt_tokens), device="cuda") + prompt_mask = tokens != -1 + for cur_pos in range(min(prompt_lens), total_len): + logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + next_token = sample(logits, temperature) + else: + next_token = logits.argmax(dim=-1) + next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, cur_pos] = next_token + finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) + prev_pos = cur_pos + if finished.all(): + break + completion_tokens = [] + for i, toks in enumerate(tokens.tolist()): + toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens] + if eos_id in toks: + toks = toks[:toks.index(eos_id)] + completion_tokens.append(toks) + return completion_tokens + + +def main( + ckpt_path: str, + config: str, + input_file: str = "", + interactive: bool = True, + max_new_tokens: int = 100, + temperature: float = 1.0, +) -> None: + """ + Main function to load the model and perform interactive or batch text generation. + + Args: + ckpt_path (str): Path to the model checkpoint directory. + config (str): Path to the model configuration file. + input_file (str, optional): Path to a file containing input prompts. Defaults to "". + interactive (bool, optional): Whether to run in interactive mode. Defaults to True. + max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100. + temperature (float, optional): Temperature for sampling. Defaults to 1.0. + """ + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + if world_size > 1: + dist.init_process_group("nccl") + global print + if rank != 0: + print = lambda *_, **__: None + torch.cuda.set_device(local_rank) + torch.set_default_dtype(torch.bfloat16) + torch.set_num_threads(8) + torch.manual_seed(33377335) + with open(config) as f: + args = ModelArgs(**json.load(f)) + print(args) + with torch.device("cuda"): + model = Transformer(args) + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) + print("load model") + load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) + print("I'm DeepSeek ๐Ÿ‘‹") + + if interactive: + messages = [] + while True: + if world_size == 1: + prompt = input(">>> ") + elif rank == 0: + prompt = input(">>> ") + objects = [prompt] + dist.broadcast_object_list(objects, 0) + else: + objects = [None] + dist.broadcast_object_list(objects, 0) + prompt = objects[0] + if prompt == "/exit": + break + elif prompt == "/clear": + messages.clear() + continue + messages.append({"role": "user", "content": prompt}) + prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + completion_tokens = generate(model, [prompt_tokens], max_new_tokens, + tokenizer.eos_token_id, temperature) + completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True) + print(completion) + messages.append({"role": "assistant", "content": completion}) + else: + with open(input_file) as f: + prompts = f.read().split("\n\n") + assert len( + prompts + ) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})" + prompt_tokens = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True) for prompt in prompts + ] + completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, + temperature) + completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + for prompt, completion in zip(prompts, completions): + print("Prompt:", prompt) + print("Completion:", completion) + print() + + if world_size > 1: + dist.destroy_process_group() + + +if __name__ == "__main__": + """ + Command-line interface for distributed text generation. + + Arguments: + --ckpt-path (str): Path to the model checkpoint directory. + --config (str): Path to the model configuration file. + --input-file (str, optional): File containing prompts for batch processing. + --interactive (bool, optional): Enable interactive mode for generating text. + --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200. + --temperature (float, optional): Temperature for sampling. Defaults to 0.2. + + Raises: + AssertionError: If neither input-file nor interactive mode is specified. + """ + parser = ArgumentParser() + parser.add_argument("--ckpt-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--input-file", type=str, default="") + parser.add_argument("--interactive", action="store_true") + parser.add_argument("--max-new-tokens", type=int, default=200) + parser.add_argument("--temperature", type=float, default=0.6) + args = parser.parse_args() + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" + main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, + args.temperature) diff --git a/examples/deepseek_v32/inference/kernel.py b/examples/deepseek_v32/inference/kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..25abf15d597caea97d0a890ca09a9cf73c7aa084 --- /dev/null +++ b/examples/deepseek_v32/inference/kernel.py @@ -0,0 +1,268 @@ +import torch +import tilelang +import tilelang.language as T +from typing import Tuple, Optional + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, +} + +FP8 = T.float8_e4m3fn +BF16 = T.bfloat16 +FP32 = T.float32 + + +def fast_log2_ceil(x): + bits_x = T.reinterpret(T.uint32, x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + +def fast_pow2(x): + bits_x = (x + 127) << 23 + return T.reinterpret(T.float32, bits_x) + + +def fast_round_scale(amax, fp8_max_inv): + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + +@tilelang.jit(pass_configs=pass_configs) +def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False): + M = T.dynamic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], + ): + with T.Kernel( + T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), in_dtype) + amax_local = T.alloc_fragment((blk_m,), scale_dtype) + s_local = T.alloc_fragment((blk_m,), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + +def act_quant(x: torch.Tensor, + block_size: int = 128, + scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})") + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + + +@tilelang.jit(pass_configs=pass_configs) +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32): + assert out_dtype in [BF16, T.float32] + + M = T.dynamic("M") + group_size = 128 + block_M = 32 + block_N = 128 + block_K = 128 + + @T.prim_func + def fp8_gemm_kernel_( + A: T.Tensor[(M, K), FP8], + B: T.Tensor[(N, K), FP8], + C: T.Tensor[(M, N), out_dtype], + scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], + scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32], + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), FP8) + B_shared = T.alloc_shared((block_N, block_K), FP8) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + Scale_C_shared = T.alloc_shared((block_M), FP32) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return fp8_gemm_kernel_ + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, + b_s: torch.Tensor) -> torch.Tensor: + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), ( + "Scaling factor tensors must be contiguous") + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + kernel = fp8_gemm_kernel(N, K) + kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) + return c + + +@tilelang.jit(out_idx=[4], pass_configs=pass_configs) +def fp8_index_kernel(h: int, d: int): + b = T.dynamic("b") + m = T.dynamic("m") + n = T.dynamic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Args: + q (torch.Tensor): The Q tensor, must be contiguous. + q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. + k (torch.Tensor): The K tensor, must be contiguous. + k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. + + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) diff --git a/examples/deepseek_v32/inference/model.py b/examples/deepseek_v32/inference/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e7468f0587e03de4d36d637fa2e149624c904b --- /dev/null +++ b/examples/deepseek_v32/inference/model.py @@ -0,0 +1,972 @@ +import math +from dataclasses import dataclass +from typing import Tuple, Optional, Literal + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + +from kernel import act_quant, fp8_gemm, fp8_index + +world_size = 1 +rank = 0 +block_size = 128 + + +@dataclass +class ModelArgs: + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + scale_fmt (Optional[str]): Format for quantization scale. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + index_head_dim (int): Dimension for index head. + index_topk (int): Top-k for index head. + """ + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + scale_fmt: Optional[str] = None + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + # moe + n_routed_experts: int = 64 + n_shared_experts: int = 2 + n_activated_experts: int = 6 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1. + # mla + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1. + # index + index_n_heads: int = 64 + index_head_dim: int = 128 + index_topk: int = 2048 + + +class ParallelEmbedding(nn.Module): + """ + Embedding layer with parallelism support across distributed processes. + + Args: + vocab_size (int): Vocabulary size. + dim (int): Embedding dimension. + """ + + def __init__(self, vocab_size: int, dim: int): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" + self.part_vocab_size = (vocab_size // world_size) + self.vocab_start_idx = rank * self.part_vocab_size + self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for parallel embedding layer. + + Args: + x (torch.Tensor): Input tensor containing token indices. + + Returns: + torch.Tensor: Embedded representations. + + Raises: + ValueError: If `world_size` is not defined. + """ + if world_size > 1: + mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) + x = x - self.vocab_start_idx + x[mask] = 0 + y = F.embedding(x, self.weight) + if world_size > 1: + y[mask] = 0 + dist.all_reduce(y) + return y + + +def linear(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + scale_fmt: Optional[str] = None) -> torch.Tensor: + """ + Applies a linear transformation to the incoming data: y = xA^T + b. + This function supports specialized implementations based on quantization + and tensor formats. + + Args: + x (torch.Tensor): The input tensor. + weight (torch.Tensor): The weight tensor. It may be quantized and + requires dequantization for certain cases. + bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None. + scale_fmt (Optional[str]): The format of scaling factors. + + Returns: + torch.Tensor: The result of the linear transformation, which may involve + quantization-aware computations depending on the input parameters. + + Notes: + - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version + is used for computation. + - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. + """ + assert bias is None + + if weight.dtype != torch.float8_e4m3fn: + return F.linear(x, weight) + else: + x, scale = act_quant(x, block_size, scale_fmt) + return fp8_gemm(x, scale, weight, weight.scale) + + +class Linear(nn.Module): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + dtype = torch.bfloat16 + scale_fmt: Optional[str] = None + + def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + self.weight.scale = self.scale = nn.Parameter( + torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) + else: + self.register_parameter("scale", None) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. + """ + return linear(x, self.weight, self.bias, self.scale_fmt) + + +class ColumnParallelLinear(Linear): + """ + Linear layer with column parallelism, splitting output features across distributed processes. + + Args: + in_features (int): Number of input features. + out_features (int): Total number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None): + assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" + self.part_out_features = out_features // world_size + super().__init__(in_features, self.part_out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for column parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with column-parallel computation. + """ + y = linear(x, self.weight, self.bias, self.scale_fmt) + return y + + +class RowParallelLinear(Linear): + """ + Linear layer with row parallelism, splitting input features across distributed processes. + + Args: + in_features (int): Total number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = False, + reduce_output=True, + dtype=None): + assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" + self.part_in_features = in_features // world_size + self.reduce_output = reduce_output + super().__init__(self.part_in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for row parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with row-parallel computation. + """ + y = linear(x, self.weight, None, self.scale_fmt) + if self.reduce_output and world_size > 1: + y = y.float() + dist.all_reduce(y) + if self.bias is not None: + y += self.bias + return y.type_as(x) + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (RMSNorm). + + Args: + dim (int): Dimension of the input tensor. + eps (float): Epsilon value for numerical stability. Defaults to 1e-6. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None): + """ + Forward pass for RMSNorm. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor with the same shape as input. + """ + dtype = x.dtype + if residual is None: + x = x.float() + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return (self.weight * x).to(dtype) + else: + x = residual = x.float() + residual.float() + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.eps) + return (self.weight * x).to(dtype), residual.to(dtype) + + +class LayerNorm(nn.Module): + """ + Layer Normalization. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x) + + +def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim(num_rotations, dim, base, max_seq_len): + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + freqs = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if seqlen > args.original_seq_len: + low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + hidden_size = x.size(-1) + return hadamard_transform(x, scale=hidden_size**-0.5) + + +class Indexer(torch.nn.Module): + + def __init__(self, args: ModelArgs): + super().__init__() + self.dim: int = args.dim + self.n_heads: int = args.index_n_heads + self.n_local_heads = args.index_n_heads // world_size + self.head_dim: int = args.index_head_dim + self.rope_head_dim: int = args.qk_rope_head_dim + self.index_topk: int = args.index_topk + self.q_lora_rank: int = args.q_lora_rank + self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim) + self.wk = Linear(self.dim, self.head_dim) + self.k_norm = LayerNorm(self.head_dim) + self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype()) + self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = args.scale_fmt + + self.register_buffer( + "k_cache", + torch.zeros( + args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), + persistent=False) + self.register_buffer( + "k_scale_cache", + torch.zeros( + args.max_batch_size, + args.max_seq_len, + self.head_dim // block_size, + dtype=torch.float32), + persistent=False) + + def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor]): + bsz, seqlen, _ = x.size() + end_pos = start_pos + seqlen + q = self.wq_b(qr) + q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_pe, q_nope], dim=-1) + k = self.wk(x) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2) + k = torch.cat([k_pe, k_nope], dim=-1) + q = rotate_activation(q) + k = rotate_activation(k) + q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt) + self.k_cache[:bsz, start_pos:end_pos] = k_fp8 + self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale + weights = self.weights_proj(x) * self.n_heads**-0.5 + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + index_score = fp8_index(q_fp8.contiguous(), weights, + self.k_cache[:bsz, :end_pos].contiguous(), + self.k_scale_cache[:bsz, :end_pos].contiguous()) + if mask is not None: + index_score += mask + topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1] + topk_indices_ = topk_indices.clone() + dist.broadcast(topk_indices_, src=0) + assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}" + return topk_indices + + +def weight_dequant(weight, scale): + shape = weight.shape + assert weight.dim() == 2 + weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, + block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size) + weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view( + shape[0] // block_size, shape[1] // block_size, block_size, + block_size).transpose(1, 2).contiguous().view(shape) + return weight + + +class MLA(nn.Module): + """ + Multi-Head Latent Attention (MLA) Layer. + + Attributes: + dim (int): Dimensionality of the input features. + n_heads (int): Number of attention heads. + n_local_heads (int): Number of local attention heads for distributed systems. + q_lora_rank (int): Rank for low-rank query projection. + kv_lora_rank (int): Rank for low-rank key/value projection. + qk_nope_head_dim (int): Dimensionality of non-positional query/key projections. + qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections. + qk_head_dim (int): Total dimensionality of query/key projections. + v_head_dim (int): Dimensionality of value projections. + softmax_scale (float): Scaling factor for softmax in attention computation. + """ + + def __init__(self, args: ModelArgs): + super().__init__() + self.dim = args.dim + self.n_heads = args.n_heads + self.n_local_heads = args.n_heads // world_size + self.q_lora_rank = args.q_lora_rank + self.kv_lora_rank = args.kv_lora_rank + self.qk_nope_head_dim = args.qk_nope_head_dim + self.qk_rope_head_dim = args.qk_rope_head_dim + self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim + self.v_head_dim = args.v_head_dim + + self.wq_a = Linear(self.dim, self.q_lora_rank) + self.q_norm = RMSNorm(self.q_lora_rank) + self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim) + self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) + self.kv_norm = RMSNorm(self.kv_lora_rank) + self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) + self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) + self.softmax_scale = self.qk_head_dim**-0.5 + if args.max_seq_len > args.original_seq_len: + mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.indexer = Indexer(args) + + self.register_buffer( + "kv_cache", + torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), + persistent=False) + self.register_buffer( + "pe_cache", + torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), + persistent=False) + self.dequant_wkv_b = None + + def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor]): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + start_pos (int): Starting position in the sequence for caching. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + end_pos = start_pos + seqlen + qr = self.q_norm(self.wq_a(x)) + q = self.wq_b(qr) + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + kv = self.wkv_a(x) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv = self.kv_norm(kv) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + self.kv_cache[:bsz, start_pos:end_pos] = kv + self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + if mask is not None: # MHA prefill + q = torch.cat([q_nope, q_pe], dim=-1) + kv = self.wkv_b(kv) + kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) + scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale + + # indexer + topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask) + index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), + device=x.device).scatter_(-1, topk_indices, 0) + index_mask += mask + scores += index_mask.unsqueeze(2) + + scores = scores.softmax(dim=-1, dtype=torch.float32) + x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v) + else: # MHA decode + if self.dequant_wkv_b is None and self.wkv_b.scale is not None: + self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale) + wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b + wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) + q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) + scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), + self.kv_cache[:bsz, :end_pos].float()) + + torch.einsum("bshr,btr->bsht", q_pe.float(), + self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale + + # indexer + topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask) + index_mask = torch.full((bsz, 1, end_pos), float("-inf"), + device=x.device).scatter_(-1, topk_indices, 0) + scores += index_mask.unsqueeze(2) + + scores = scores.softmax(dim=-1, dtype=torch.float32) + x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos]) + x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) + x = self.wo(x.flatten(2)) + return x + + +class MLP(nn.Module): + """ + Multi-Layer Perceptron (MLP) used as a feed-forward layer. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ + + def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True): + """ + Initializes the MLP layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ + super().__init__() + self.w1 = ColumnParallelLinear(dim, inter_dim) + self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output) + self.w3 = ColumnParallelLinear(dim, inter_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the MLP layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after MLP computation. + """ + return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x)) + + +class Gate(nn.Module): + """ + Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. + + Attributes: + dim (int): Dimensionality of input features. + topk (int): Number of top experts activated for each input. + n_groups (int): Number of groups for routing. + topk_groups (int): Number of groups to route inputs to. + score_func (str): Scoring function ('softmax' or 'sigmoid'). + route_scale (float): Scaling factor for routing weights. + weight (torch.nn.Parameter): Learnable weights for the gate. + bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. + """ + + def __init__(self, args: ModelArgs): + """ + Initializes the Gate module. + + Args: + args (ModelArgs): Model arguments containing gating parameters. + """ + super().__init__() + self.dim = args.dim + self.topk = args.n_activated_experts + self.n_groups = args.n_expert_groups + self.topk_groups = args.n_limited_groups + self.score_func = args.score_func + self.route_scale = args.route_scale + self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) + self.bias = nn.Parameter(torch.empty(args.n_routed_experts, + dtype=torch.float32)) if self.dim == 7168 else None + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the gating mechanism. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices. + """ + scores = linear(x.float(), self.weight.float()) + if self.score_func == "softmax": + scores = scores.softmax(dim=-1) + else: + scores = scores.sigmoid() + original_scores = scores + if self.bias is not None: + scores = scores + self.bias + if self.n_groups > 1: + scores = scores.view(x.size(0), self.n_groups, -1) + if self.bias is None: + group_scores = scores.amax(dim=-1) + else: + group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) + indices = group_scores.topk(self.topk_groups, dim=-1)[1] + mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) + scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) + indices = scores.topk(self.topk, dim=-1)[1] + weights = original_scores.gather(1, indices) + if self.score_func == "sigmoid": + weights /= weights.sum(dim=-1, keepdim=True) + weights *= self.route_scale + return weights, indices + + +class Expert(nn.Module): + """ + Expert layer for Mixture-of-Experts (MoE) models. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ + + def __init__(self, dim: int, inter_dim: int): + """ + Initializes the Expert layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ + super().__init__() + self.w1 = Linear(dim, inter_dim) + self.w2 = Linear(inter_dim, dim) + self.w3 = Linear(dim, inter_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the Expert layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert computation. + """ + return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x)) + + +class MoE(nn.Module): + """ + Mixture-of-Experts (MoE) module. + + Attributes: + dim (int): Dimensionality of input features. + n_routed_experts (int): Total number of experts in the model. + n_local_experts (int): Number of experts handled locally in distributed systems. + n_activated_experts (int): Number of experts activated for each input. + gate (nn.Module): Gating mechanism to route inputs to experts. + experts (nn.ModuleList): List of expert modules. + shared_experts (nn.Module): Shared experts applied to all inputs. + """ + + def __init__(self, args: ModelArgs): + """ + Initializes the MoE module. + + Args: + args (ModelArgs): Model arguments containing MoE parameters. + """ + super().__init__() + self.dim = args.dim + assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" + self.n_routed_experts = args.n_routed_experts + self.n_local_experts = args.n_routed_experts // world_size + self.n_activated_experts = args.n_activated_experts + self.experts_start_idx = rank * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + self.gate = Gate(args) + self.experts = nn.ModuleList([ + Expert(args.dim, args.moe_inter_dim) + if self.experts_start_idx <= i < self.experts_end_idx else None + for i in range(self.n_routed_experts) + ]) + self.shared_experts = MLP( + args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the MoE module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert routing and computation. + """ + shape = x.size() + x = x.view(-1, self.dim) + weights, indices = self.gate(x) + y = torch.zeros_like(x, dtype=torch.float32) + counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() + for i in range(self.experts_start_idx, self.experts_end_idx): + if counts[i] == 0: + continue + expert = self.experts[i] + idx, top = torch.where(indices == i) + y[idx] += expert(x[idx]) * weights[idx, top, None] + y += self.shared_experts(x) + if world_size > 1: + dist.all_reduce(y) + return y.type_as(x).view(shape) + + +class Block(nn.Module): + """ + Transformer block combining attention and feed-forward layers. + + Attributes: + attn (nn.Module): Attention layer (MLA). + ffn (nn.Module): Feed-forward network (MLP or MoE). + attn_norm (nn.Module): Layer normalization for attention. + ffn_norm (nn.Module): Layer normalization for feed-forward network. + """ + + def __init__(self, layer_id: int, args: ModelArgs): + """ + Initializes the Transformer block. + + Args: + layer_id (int): Layer index in the transformer. + args (ModelArgs): Model arguments containing block parameters. + """ + super().__init__() + self.attn = MLA(args) + self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) + self.attn_norm = RMSNorm(args.dim) + self.ffn_norm = RMSNorm(args.dim) + + def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int, + freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position in the sequence. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + + Returns: + torch.Tensor: Output tensor after block computation. + """ + if residual is None: + x, residual = self.attn_norm(x), x + else: + x, residual = self.attn_norm(x, residual) + x = self.attn(x, start_pos, freqs_cis, mask) + x, residual = self.ffn_norm(x, residual) + x = self.ffn(x) + return x, residual + + +class Transformer(nn.Module): + """ + Transformer model with positional embeddings, multiple layers, and output projection. + + Attributes: + max_seq_len (int): Maximum sequence length for the transformer. + embed (nn.Module): Embedding layer for input tokens. + layers (torch.nn.ModuleList): List of transformer blocks. + norm (nn.Module): Layer normalization applied after all blocks. + head (nn.Module): Output projection layer mapping to vocabulary size. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + """ + + def __init__(self, args: ModelArgs): + """ + Initializes the Transformer model. + + Args: + args (ModelArgs): Model arguments containing transformer parameters. + """ + global world_size, rank + world_size = dist.get_world_size() if dist.is_initialized() else 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 + Linear.scale_fmt = args.scale_fmt + super().__init__() + self.max_seq_len = args.max_seq_len + self.embed = ParallelEmbedding(args.vocab_size, args.dim) + self.layers = torch.nn.ModuleList() + for layer_id in range(args.n_layers): + self.layers.append(Block(layer_id, args)) + self.norm = RMSNorm(args.dim) + # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later. + self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32) + self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int = 0): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ + seqlen = tokens.size(1) + freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen] + mask = torch.full( + (seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None + h, residual = self.embed(tokens), None + for layer in self.layers: + h, residual = layer(h, residual, start_pos, freqs_cis, mask) + h, _ = self.norm(h, residual) + logits = self.head(h[:, -1].float()) + if world_size > 1: + all_logits = [torch.empty_like(logits) for _ in range(world_size)] + dist.all_gather(all_logits, logits) + logits = torch.cat(all_logits, dim=-1) + return logits + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.manual_seed(0) + args = ModelArgs() + x = torch.randint(0, args.vocab_size, (2, 128)) + model = Transformer(args) + print(model(x).size()) diff --git a/examples/deepseek_v32/inference/requirements.txt b/examples/deepseek_v32/inference/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..604fed552ca3f44307e1fe3a27bab5ba01c3bc9e --- /dev/null +++ b/examples/deepseek_v32/inference/requirements.txt @@ -0,0 +1,5 @@ +torch +transformers +safetensors +fast_hadamard_transform +tilelang==0.1.6 \ No newline at end of file diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d8035c1ba0e10d4aa7cfadae0d8017ba1333f097 --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -0,0 +1,341 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + B, + S, + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + shape = [B, S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + B, + S_kv, + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + dkv_shape = [B, S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): + T.copy( + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, + }, +) +def bwd( + B, + S, + S_kv, + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=256, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) + + H_kv = H // kv_group + q_shape = [B, S, H, D + D_tail] + k_shape = [B, S_kv, kv_group, D + D_tail] + o_shape = [B, S, H, D] + indices_shape = [B, S, kv_group, topk] + delta_shape = [B, S, H] + lse_shape = [B, S, H] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) + + max_kv_i = s_i + + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + mask[bi_i] = Indices[by, s_i, bz, i_i * BS + bi_i] <= max_kv_i + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i] + + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + T.clear(acc_dkv_tail) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert kv.shape[0] == B + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (B, S, kv_group, topk) + assert lse.shape == (B, S, H) + + # Get kernels + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True): + # Prepare data + q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda") + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + + tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c4d2f0463da163b05832f600ccc44c66ded58a --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -0,0 +1,296 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from utils import assert_tensors_similar + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=2, + threads=256, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + O_shared = T.alloc_shared([H_per_block, D], dtype) + Lse_shared = T.alloc_shared([H_per_block], accum_dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, O_shared) + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + T.copy(sumexp, Lse_shared) + T.copy(sumexp, Lse[b_i, s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + kernel = sparse_mla_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) + out, lse = kernel(q, kv, indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + SKV=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, + ) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..54e1a72090152dbb0d1e325784774f4f42efb5a9 --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -0,0 +1,438 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from tilelang.engine.callback import register_cuda_postproc_callback +import argparse + + +@tilelang.jit( + out_idx=[-2, -1], + compile_flags=[ + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + ], +) +def sparse_mla_fwd( + batch, + seq_len, + seq_len_kv, + heads, + dim, + tail_dim, + topk, + kv_stride, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=0, + threads=384, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + assert NI % 2 == 0, "NI should be a multiple of 2" + D = dim + D_tail = tail_dim + KV_stride = kv_stride + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): + Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) + Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) + K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) + K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + is_kv_valid = T.alloc_shared([BI], "bool", scope="shared") + + acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + indices_local = T.alloc_local([1], indices_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + b_i, g_i = by, bz + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + q_i = q_start_index_s[0] + s_i + max_kv_i = (q_i + 1 - KV_stride) // KV_stride + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + tx = T.get_thread_binding() + + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(H_per_block): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + if is_kv_valid[r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + if is_kv_valid[r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v + ] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v + ] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ + b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v + ] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = 512 + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + if q_start_index_s != 0: + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) + CP0 = q_start_index_s == 0 + + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) + if print_kernel: + print(kernel.get_kernel_source()) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + if return_kernel: + return kernel + if q_start_index_s == 0 and kv_stride > 1: + out[:, : kv_stride - 1, :, :] = 0 + return out, lse + + +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + if q_start_index_s is None: + q_start_index_s = sk * kv_stride - sq + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + num_kv_per_index = 1 + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : kv_stride - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd_pipelined( + B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True +): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + + def fn(): + out, lse = kernel(q, kv, indices, q_start_s_index_t) + if q_start_s_index == 0 and KV_stride > 1: + out[:, : KV_stride - 1, :, :] = 0 + return out, lse + + tl_out, tl_lse = fn() + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) + # print(f"tl_out: {tl_out}") + # print(f"ref_out: {ref_out}") + + torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=10, + warmup=10, + ) + print(f"Average time: {ms:.3f} ms") + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test_correctness", action="store_true") + args = parser.parse_args() + if args.test_correctness: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + else: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness) diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7e879ba6e4fc1155660ad6be10da917c7a5ad5 --- /dev/null +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -0,0 +1,41 @@ +# ruff: noqa +import tilelang +import tilelang.testing + +import topk_selector +import fp8_lighting_indexer +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import sparse_mla_bwd + + +def test_example_topk_selector(): + topk_selector.test_topk_selector() + + +def test_example_fp8_lighting_indexer(): + fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_sparse_mla_fwd(): + # small shapes for testing + sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_sparse_mla_fwd_pipelined(): + # small shapes for testing + sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_sparse_mla_bwd(): + sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..244f74c69615430ca32ea847529639a1ccf65cca --- /dev/null +++ b/examples/deepseek_v32/topk_selector.py @@ -0,0 +1,244 @@ +import torch +import tilelang +import tilelang.language as T + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, +} + + +def convert_to_uint16(x): + hval = T.Cast(T.float16, x) + bits_uint = T.reinterpret(T.uint16, hval) + bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000)) + return bits_uint >> 8 + + +def convert_to_uint32(x): + bits_uint = T.reinterpret(T.uint32, x) + bits_uint = T.if_then_else( + x < 0, + ~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)), + bits_uint | T.Cast(T.uint32, (0x80000000)), + ) + return bits_uint + + +@tilelang.jit(pass_configs=pass_configs) +def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + RADIX = 1 << 8 + BLOCK_SIZE = 1024 + SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K + + @T.prim_func + def tl_topk_kernel( + input: T.Tensor[(batch, seq_len), in_dtype], + index: T.Tensor[(batch, topk), out_dtype], + starts: T.Tensor[(batch), out_dtype], + ends: T.Tensor[(batch), out_dtype], + ): + with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): + tx = T.get_thread_binding() + + s_threshold_bin_id = T.alloc_shared([1], T.int32) + s_histogram = T.alloc_shared([RADIX + 1], T.int32) + s_num_input = T.alloc_shared([2], T.int32) + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) + + l_threshold_bin_id = T.alloc_var(T.int32) + l_new_topk = T.alloc_var(T.int32) + l_num_input = T.alloc_var(T.int32) + l_bin_id32 = T.alloc_var(T.int32) + l_val = T.alloc_var(T.int32) + l_start_pos = T.alloc_var(T.int32) + l_start_idx = T.alloc_var(T.int32) + l_end_idx = T.alloc_var(T.int32) + l_out_pos = T.alloc_var(T.int32) + + l_new_topk = topk + l_start_idx = starts[bx] + l_end_idx = ends[bx] + + # stage 1: use 8bit to do quick topk + T.fill(s_histogram, 0) + T.fill(s_num_input[0], 0) + + T.sync_threads() + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + input_idx = s * BLOCK_SIZE + tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + inval_int16 = convert_to_uint16(input[bx, input_idx]) + T.atomic_add(s_histogram[inval_int16], 1) + T.sync_threads() + + # cumsum + if tx < RADIX: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + s_histogram[tx] = l_val + + # find threshold bin id + T.sync_threads(3, RADIX) + if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + # collect all elements with exponent โ‰ฅ threshold + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + T.sync_threads() + input_idx = s * BLOCK_SIZE + tx + if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: + bin_id = convert_to_uint16(input[bx, input_idx]) + l_bin_id32 = T.Cast(T.int32, bin_id) + if l_bin_id32 > l_threshold_bin_id: + # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + index[bx, pos] = input_idx + + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + # pos = s_num_input[0] + pos = T.atomic_add(s_num_input[0], 1, return_prev=True) + s_input_idx[0, pos] = input_idx + + # stage 2: tail pass + for round in T.serial(4): + if l_new_topk <= 0: + T.loop_break() + + r_idx = round % 2 + l_start_pos = topk - l_new_topk + + T.sync_threads() + T.fill(s_histogram, 0) + if tx == 0: + s_num_input[r_idx ^ 1] = 0 + T.sync_threads() + + l_num_input = s_num_input[r_idx] + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) + T.atomic_add(s_histogram[l_bin_id32], 1) + T.sync_threads() + # cumsum + if tx < RADIX: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + s_histogram[tx] = l_val + + # find threshold bin id + T.sync_threads(3, RADIX) + if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + T.sync_threads() + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + ) + if l_bin_id32 > l_threshold_bin_id: + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + if round == 3: + l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos + if l_out_pos < topk: + index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + else: + pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True) + s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx] + + return tl_topk_kernel + + +def tl_topk(input, starts, ends, topk): + batch, seq_len = input.shape + indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device) + kernel = tl_topk_impl(topk) + kernel(input, indexes, starts, ends) + return indexes + + +def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): + batch = 64 + seq_len = 32 * 1024 + topk = 2048 + torch.manual_seed(1) + input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() + starts = torch.zeros(batch, dtype=torch.int32).cuda() + ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len + + indexes = tl_topk(input, starts, ends, topk) + print(indexes) + + indexes_ref = torch.topk(input, topk, dim=-1)[1] + print(indexes_ref) + + # indexes_ref = fast_topk(input, topk) + # print(indexes_ref) + + # Calculate intersection of out_ref and out_trt + for i in range(batch): + ref_np = indexes_ref[i].cpu().to(torch.int32).numpy() + trt_np = indexes[i].cpu().to(torch.int32).numpy() + + set_ref = set(ref_np) + set_trt = set(trt_np) + intersection = set_ref & set_trt + print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + # Performance test with CUDA events + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Warmup + for _ in range(5): + _ = tl_topk(input, starts, ends, topk) + torch.cuda.synchronize() + + n_iters = 20 + start_event.record() + for _ in range(n_iters): + _ = tl_topk(input, starts, ends, topk) + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms") + + # Torch topk time + start_event.record() + for _ in range(n_iters): + _ = torch.topk(input, topk, dim=-1)[1] + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms") + + +if __name__ == "__main__": + test_topk_selector() diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7252e171108aa13396f6d3e91d84d04de1d3c17 --- /dev/null +++ b/examples/deepseek_v32/utils.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +import contextlib +import functools +import logging +import os +import sys +from enum import Enum +from functools import lru_cache +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +from packaging import version + + +def _is_equal(a, b): + if isinstance(a, torch.Tensor): + return a is b + # Whitelist of types that are safe to compare by value for caching. + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): + return a == b + # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. + return False + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: Optional[Tuple] = None + last_kwargs: Optional[Dict] = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if last_args is not None and last_kwargs is not None: + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): + # For Tensors, check for object identity. For other types, check for equality. + # Python caches small integers, so `is` works for them but not for large integers like 4096. + if ( + all(_is_equal(a, b) for a, b in zip(args, last_args)) + and set(kwargs.keys()) == set(last_kwargs.keys()) + and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): + seq_idx = cu_seqlens.new_zeros(seq_len + 1) + seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx)) + seq_idx.cumsum_(0) + return seq_idx[:-1] + + +@tensor_cache +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i + return seq_idx_for_q + + +@tensor_cache +def cal_cu_seqlen_ks_for_q( + cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int +) -> torch.IntTensor: + cu_seqlen_ks_for_each_q = torch.gather( + input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + return cu_seqlen_ks_for_each_q.int() + + +@tensor_cache +def cal_cu_seqlen_ke_for_q( + cu_seqlens_qs: torch.LongTensor, + cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, + cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, + seq_len: int, + kv_stride: int, +) -> torch.IntTensor: + cu_seqlen_ke_for_each_q = torch.gather( + input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(), + ) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = ( + torch.arange( + q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device + ) + + 1 + ) // kv_stride + cu_seqlens_ks[i] + cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) + return cu_seqlen_ke_for_each_q.int() + + +@tensor_cache +def cal_ks_ke_from_cu_seqlen_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False, +): + """ + seq_len: seq len per cp rank + balanced cp slice assignment: 0 1 2 3 3 2 1 0 + """ + n_seq = len(cu_seqlens_q) - 1 + assert n_seq > 0 + assert cu_seqlens_q.shape == (n_seq + 1,) + seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len * cp_size) + qs = cu_seqlens_q.gather(0, seq_idx) + pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs + if offs_q is not None: + assert offs_q.shape == (n_seq,), offs_q.shape + qoff = offs_q.gather(0, seq_idx) + pos += qoff + if cu_seqlens_k is None or cu_seqlens_k is cu_seqlens_q: + ks = qs + else: + assert cu_seqlens_k.shape == (n_seq + 1,) + ks = cu_seqlens_k.gather(0, seq_idx) + ke = ks + (pos + 1) // kv_stride + + if cp_size == 1: + pass + elif balanced_cp: + assert cp_size % 2 == 0, cp_size + + def f(x: torch.Tensor): + chunks = x.chunk(cp_size * 2) + return torch.cat( + [ + chunks[cp_rank], + chunks[cp_size - cp_rank - 1], + ] + ) + + ks = f(ks) + ke = f(ke) + else: + ks = ks.chunk(cp_size)[cp_rank] + ke = ke.chunk(cp_size)[cp_rank] + + return ks, ke + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512): + total_seqlen = per_cp_seqlen * cp_size + + cu_seqlens = torch.randint(0, average_q_len * 2, (total_seqlen // average_q_len * 2,)).cuda() + last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0] + cu_seqlens = cu_seqlens[:last_seq_id] + + if cu_seqlens.sum() < total_seqlen: + cu_seqlens = torch.cat([cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()]) + + cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0) + cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0) + cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]]) + cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]]) + cu_seqlens_qe = cu_seqlens_cumsum.clone() + cu_seqlens_ke = cu_seqlens_k_cumsum.clone() + + cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + seq_len=total_seqlen, + ) + cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + cu_seqlens_ke=cu_seqlens_ke, + q_start_idxs=torch.zeros_like(cu_seqlens_qs), + seq_len=total_seqlen, + kv_stride=kv_stride, + ) + + assert per_cp_seqlen % 2 == 0 + per_chunk_seqlen = per_cp_seqlen // 2 + slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen) + slice_long = slice( + total_seqlen - (cp_rank + 1) * per_chunk_seqlen, + total_seqlen - cp_rank * per_chunk_seqlen, + ) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) + assert len(ks) == len(ke) == per_cp_seqlen + return ks, ke + + +def calculate_tensor_similarity(x, y, name="tensor"): + """ + Calculate similarity between two tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. + + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 + + +if __name__ == "__main__": + seq_len = 32768 + cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") + last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] + cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + + from tilelang.profiler import do_bench + + fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len) # noqa: E731 + ms = do_bench(fn, warmup=25, rep=100) diff --git a/examples/dequantize_gemm/README.md b/examples/dequantize_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0c6116775e57b9d02df3c5f49763fb9d9df509fc --- /dev/null +++ b/examples/dequantize_gemm/README.md @@ -0,0 +1,39 @@ + +### Dequantization GEMM + +An example of implementing a dequantization GEMM: + +```python +@T.prim_func +def dequant_matmul( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), +): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + + T.clear(Ct_local) + for k in T.Pipelined( + T.ceildiv(K, block_K), + num_stages=num_stages + ): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert("int", 8)( + num_bits, + B_local[i, j // 2], + j % 2, + dtype=in_dtype, + ) + T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, Ct[bx * block_N, by * block_M]) +``` + +**Notes:** Dequantize GEMM with magic layout transformations to get optimal performance can be found at project [BitBLAS](https://github.com/microsoft/BitBLAS), example kernels can be found at `testing/python/kernel/test_tilelang_dequantize_gemm.py`, detailed explanation and examples is coming soon. diff --git a/examples/dequantize_gemm/dequantize_utils.py b/examples/dequantize_gemm/dequantize_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90a6265ffa4bf22c3d583e74e53066161c80a37a --- /dev/null +++ b/examples/dequantize_gemm/dequantize_utils.py @@ -0,0 +1,148 @@ +import torch + + +def torch_convert_bit_twiddling(tensor): + """ + This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`. + + Parameters: + tensor (torch.Tensor): 2-D input tensor with dtype `torch.uint8`. Shape (N, K). + + Returns: + torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2), where each input column pair produces two bf16 output columns. + + Raises: + AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`. + """ + assert tensor.dim() == 2 and tensor.dtype == torch.uint8 + N, K = tensor.shape + assert K % 2 == 0, "Number of columns must be even" + + # Combine pairs of uint8 values into uint32 for safe bitwise ops on CUDA + val0 = tensor[:, 0::2].to(torch.int32) + val1 = tensor[:, 1::2].to(torch.int32) + val_concat = (val0 << 8) | val1 # (N, K//2), uint32 + + # Expand to match output shape where each pair generates 4 values + val_concat_expanded = val_concat.repeat_interleave(4, dim=1) # (N, K//2*4) + + # Positional encoding for bit-twiddling logic + pos = torch.arange(K * 2, device=tensor.device) % 4 # (K*2,) + + # Bit masks for decoding (as uint32 for CUDA compatibility) + mask = 0b1000000111000000 + mask1 = 0b1000000000000000 + mask2 = 0b0000000110000000 + mask3 = 0b0000000001000000 + + # Calculate results for all 4 positions in parallel + res0 = val_concat_expanded & mask + res1 = (val_concat_expanded << 3) & mask + res2 = (val_concat_expanded << 6) & mask + res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3) + + # Select the correct result based on position + bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) + + # Convert to uint16 for .view(torch.bfloat16) + bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) + bf16_bf16 = bf16_uint16.view(torch.bfloat16) + + # Avoid integer overflow by using a float32 multiplier for the exponent scaling + bf16_new = bf16_bf16 * (2.0**126) + + return bf16_new + + +def torch_convert(tensor, scale_size=None, Scale=None): + """ + Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding. + + Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input. + + Parameters: + tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values. + scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale. + Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size]. + + Returns: + torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values. + """ + + def _convert(val, pos, scale=None): + assert val.dtype == torch.uint8 + # val = val.view(torch.int8) + mask = (1 << 4) - 1 + f4 = ((val >> (pos * 4)) & mask).to(torch.int16) + s = f4 >> 3 + e_f4 = (f4 & 6) >> 1 + e_f16 = e_f4 + 126 + if scale is not None: + e_f16 = min(e_f16 + scale, (1 << 8) - 1) + m_f4 = f4 & 1 + m_f16 = m_f4 + val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF + lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) + return lower_16_bits.view(torch.bfloat16) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + if scale_size is not None: + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size]) + else: + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +def print_bit(name, val): + """ + Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor. + + Converts `val` to CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`. + + Parameters: + name (str): Label printed before the binary representation. + val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. + """ + val_cpu = val.cpu().item() + binary_repr = f"{val_cpu:032b}" + print(name, binary_repr) + + +def print_red_warning(message): + print(f"\033[31mWARNING: {message}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): + x_mask = torch.isfinite(x) + y_mask = torch.isfinite(y) + if not torch.all(x_mask == y_mask): + print_red_warning(f"{name} Error: isfinite mask mismatch") + if raise_assert: + raise AssertionError + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") + if raise_assert: + raise AssertionError + x = x.masked_fill(~x_mask, 0) + y = y.masked_fill(~y_mask, 0) + sim = calc_sim(x, y, name) + diff = (1.0 - sim).item() + print(f"{diff=}") + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff=}") + if raise_assert: + raise AssertionError diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9c945b36083073e020b424725da949614d4df0 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -0,0 +1,443 @@ +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tvm import tir +import torch +from dequantize_utils import torch_convert_bit_twiddling, torch_convert + + +def get_configs(): + """ + Return a list of tuning configuration dictionaries for the autotuned matmul kernel. + + Each dictionary is a single combination (Cartesian product) of the following parameters: + - block_M: tile size for M dimension (one of 64, 128, 256) + - block_N: tile size for N dimension (one of 64, 128, 256) + - block_K: tile size for K dimension + - num_stages: pipeline stages for K-loop (0 or 2) + - threads: number of threads to launch (128, 256, or 512) + - split: K-splitting factor (1 or 2) + + Returns: + list[dict]: List of configuration dicts usable by the autotuner, where each dict maps + the parameter name to its chosen value. + """ + import itertools + + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[128], + num_stages=[0, 2], + threads=[128, 256, 512], + split=[1, 2], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + fast_dequant=True, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T. + + This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts: + - A: dense input of shape (M, K) with dtype `in_dtype`. + - B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`. + - C: output of shape (M, N) with dtype `out_dtype`. + + The generated kernel supports two dequantization paths: + - fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group. + - simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element. + + Important behavior and requirements: + - num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits. + - QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes. + - Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid. + - When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group. + - The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages. + + Parameters that alter kernel layout/behavior (brief): + - block_M, block_N, block_K: tile sizes for M, N, and K dimensions. + - num_stages: number of software pipeline stages for the K-loop. + - threads: number of threads used per kernel block. + - split: extra K-splitting factor; K must be divisible by block_K * split. + - source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics. + + Returns: + A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. + """ + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shape = (M, K) + B_shape = (N, QK) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + # It requires that the 2 consecutive uint8 elements (16bits) contains 4 fp4 elements in a bit-twiddling way. + # The bit-twiddling way is shown here: The pair (x,y) shows that the bit in this position is the y-th bit of the x-th fp4 element. + # (0,0)(3,0)(3,3)(1,0)(3,1)(3,2)(2,0)(0,1)(0,2)(0,3)(1,1)(1,2)(1,3)(2,1)(2,2)(2,3) + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin. + + This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which: + - Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads). + - Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers. + - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel. + + Notes and preconditions: + - Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`. + - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel. + - The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly. + - The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): + # import fast_dequantize plugin + """ + Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer. + + This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters. + + Parameters: + B_shared: Shared-memory buffer containing packed quantized values (packed FP4 layout). + B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine). + + Side effects: + - Imports the external dequantization plugin via `import_source` and invokes `func_name`. + - Writes dequantized BF16 results into `B_dequantize_shared`. + + Notes: + - This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`). + - No value is returned; results are produced by mutation of `B_dequantize_shared`. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + for v in T.vectorized(0, local_compress_size): + index = i * threads * local_compress_size + tx * local_compress_size + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16. + + The returned macro (named `simple_dequant_bf16_fp4`) expects B_shared and B_dequantize_shared buffers (shapes and a few loop/constant names like + `B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It: + - Unpacks 4-bit FP values from the packed uint8 representation in B_shared. + - Converts each 4-bit value to a bfloat16 element using an internal helper `_tir_u8_to_f4_to_bf16`. + - Writes the dequantized bfloat16 block into B_dequantize_shared. + + Constraints: + - Supports only in_dtype="fp4" and out_dtype=T.bfloat16. + - The helper assumes nbit == 4 and produces bfloat16 values. + - The macro uses a fixed test-scale of 0 (no per-element scaling) as written. + + Returns: + A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): + """ + Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. + + This helper extracts the 4-bit field located at the bit position `pos` within the + byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an + exponent `scale` offset to align it with bfloat16 exponent bias, clamps the + resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. + + Parameters: + nbit (int): Number of bits in the packed element; must be 4. + val (tir.PrimExpr): A uint8 value containing packed FP4 elements. + pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. + scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. + dtype (str): Target dtype string; must be T.bfloat16. + + Returns: + tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. + + Notes: + - The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8. + - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 + bit fields and clamps the computed exponent to fit into 8 bits. + """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we use the max function to limit the exponential part to 8 bits + e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) + return val_bf16 + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared): + """ + Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer. + + This helper: + - Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared. + - Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`. + - Uses a fixed scale of 0 in the conversion (placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope. + + Parameters: + B_shared: shared-memory buffer containing packed FP4 data (uint8-packed). + B_dequantize_shared: shared-memory buffer to receive BF16 dequantized values. + + Side effects: + Writes dequantized BF16 values into B_dequantize_shared. No return value. + """ + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_shared[i, j // num_elems_per_byte], + j % num_elems_per_byte, + 0, # No scale for test + dtype=out_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), + ): + """ + Kernel entry for the tiled, pipelined matmul used by the generated prim_func. + + This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: + - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. + - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. + - Pipelines over K in chunks of `block_K` for `num_stages` stages: + - Loads A and packed B tiles into shared memory. + - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. + - Performs a GEMM accumulating into C_local with B transposed. + - Stores the accumulated block from C_local back to the global output C via C_shared. + + Parameters: + - A: input tile of shape (M, K) with dtype `in_dtype`. + - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). + - C: output tensor of shape (M, N) with dtype `out_dtype`. + + Side effects: + - Writes the computed output block into the global tensor `C`. + - Uses and updates shared memory buffers and per-thread accumulators. + + No value is returned. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.annotate_layout( + { + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + + T.clear(C_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared) + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + + return main + + +def ref_program_twiddling(A, qB): + """ + Compute reference BF16 matrix multiply using bit-twiddled FP4 quantized B. + + Converts qB (a bit-twiddled, packed FP4 representation of matrix B) back to floating, + performs C = A @ B^T in full precision, and returns the result converted to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K). Treated as floating-point (converted to torch.float for compute). + qB (torch.Tensor): Bit-twiddled, packed FP4 representation of B (quantized). Shape corresponds to B's packed layout. + + Returns: + torch.Tensor: Result matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple(A, qB): + """ + Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB. + + Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning. + + Parameters: + A (torch.Tensor): Left input matrix with shape (M, K). + qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and represent a matrix whose transpose will be multiplied by A. + + Returns: + torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def main(m=256, n=256, k=256, fast_dequant=True, tune=False): + """ + Run and benchmark the tiled, optionally autotuned FP4->BF16 GEMM kernel and validate results against a PyTorch reference. + + This function builds a matmul kernel (either with autotuning or fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints measured latency (ms) and effective TFLOPs. + + Parameters: + m (int): Number of rows of A and output C (default 256). + n (int): Number of columns of B and output C (default 256). + k (int): Inner dimension (columns of A, rows of B) (default 256). + fast_dequant (bool): If True use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True). + tune (bool): If True build the kernel with autotuning configurations; if False use a fixed tiling and threading configuration for reproducible benchmarking (default False). + + Side effects: + - Prints latency and TFLOPs to stdout. + - Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01). + """ + total_flops = 2 * m * n * k + if tune: + kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant) + else: + kernel = matmul( + m, + n, + k, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=4, + fast_dequant=fast_dequant, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + if fast_dequant: + profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + main(256, 256, 256, True) + main(256, 256, 256, False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0375a1db8d89c732f3053a8a3280ccfb7df940 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -0,0 +1,547 @@ +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tvm import tir +import torch +from dequantize_utils import torch_convert_bit_twiddling, torch_convert + + +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): + """ + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be T.bfloat16). + + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we may use the min function to limit the exponential part to 8 bits + # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) + return val_bf16 + + +def get_configs(): + """ + Generate a list of hyperparameter configuration dictionaries for tuning. + + Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', + 'num_stages', 'threads', and 'split'. The function returns the Cartesian + product of the parameter value lists: + - block_M, block_N, block_K: tiling sizes (64, 128, 256) + - num_stages: pipeline stages (0, 2) + - threads: thread counts (128, 256, 512) + - split: K-splitting factor (1, 2) + + Returns: + List[dict]: A list of configuration dictionaries covering all combinations. + """ + import itertools + + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[64, 128, 256], + num_stages=[0, 2], + threads=[128, 256, 512], + split=[1, 2], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. + + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. + """ + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shape = (M, K) + B_shape = (N, QK) + Bias_shape = (M, N) + Scale_shape = (N, K // scale_size) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + Bias_shared_shape = (block_M, block_N) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. + + The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: + - Loads packed FP4 elements from B_shared into per-thread local registers. + - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. + - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). + - Writes the scaled BF16 results into B_dequantize_shared. + + Notes: + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. + - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): + # import fast_dequantize plugin + """ + Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 + in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, + applying per-block scale factors from Scale. + + This routine is a tiled, thread-parallel helper that: + - Imports and calls an external dequantization function (via `import_source`/`func_name`) + to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. + - Loads the corresponding per-block scale entry, interprets it as an exponent bias + (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. + - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. + + Parameters: + - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). + - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. + - Scale: per-block scale tensor; entries are interpreted such that the multiplicative scale + = 2^(Scale - 127). + - k: block index along the K dimension used to select the appropriate Scale entries. + + Side effects: + - Mutates B_dequantize_shared in shared memory. + - Calls an external intrinsic function (must be provided by the environment via `import_source` + and `func_name`) to perform the low-level unpacking/dequantization. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + bx = T.get_block_binding(0) + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + Scale_local_thread = T.alloc_local((1,), storage_dtype) + Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) + + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + index_base = i * threads * local_compress_size + tx * local_compress_size + for v in T.vectorized(0, local_compress_size): + index = index_base + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + index_scale = index_base // (scale_size // num_elems_per_byte) + si = index_scale // (block_K // scale_size) + sj = index_scale % (block_K // scale_size) + Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj] + Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.Parallel(local_size): + B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. + + Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. + + Notes: + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. + - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): + """ + Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents. + + Per-element behavior: + - Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte). + - Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16. + - Writes the dequantized BF16 block into B_dequantize_shared. + + Parameters: + - B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout). + - B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results. + - Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element. + - k: current block index along the K dimension (used to select the appropriate slice of Scale). + + Side effects: + - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. + """ + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + + bx = T.get_block_binding(0) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale[ + bx * block_N + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 + dtype=out_dtype, + ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), + ): + """ + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A ยท B^T after dequantization). + - The function writes results in-place into C. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + + if with_bias: + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) + + if threads == 512: + T.disable_warp_group_reg_alloc() + + if with_bias: + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared) + T.copy(Bias_shared, C_local) + else: + T.clear(C_local) + + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale, k) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale, k) + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + + return main + + +def ref_program_twiddling(A, qB, Scale, Bias=None): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A ยท B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_twiddling_with_bias(A, qB, Scale, Bias): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A ยท B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + Bias (torch.Tensor): Bias tensor with shape (M, N). + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple(A, qB, Scale, Bias=None): + """ + Compute a BF16 matrix product A ยท B^T from a quantized B with simple (non-twiddling) dequantization. + + Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A ยท B^T and returns the result converted to bfloat16. + + Parameters: + - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). + - qB: Quantized representation of B accepted by `torch_convert`. + - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. + + Returns: + - 2D bfloat16 tensor C containing the matrix product A ยท B^T. + + No in-place modification is performed on inputs (a local floating copy of B is scaled). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple_with_bias(A, qB, Scale, Bias): + """ + Compute a BF16 matrix product A ยท B^T from a quantized B with simple (non-twiddling) dequantization. + + Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A ยท B^T and returns the result converted to bfloat16. + + Parameters: + + Returns: + - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). + - qB: Quantized representation of B accepted by `torch_convert`. + - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. + - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul). + + + Returns: + - 2D bfloat16 tensor C containing the matrix product A ยท B^T. + + No in-place modification is performed on inputs (a local floating copy of B is scaled). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): + """ + Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS. + + Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS. + + Parameters: + m (int): Number of rows of A / output rows. Default 256. + n (int): Number of columns of B / output columns. Default 256. + k (int): Reduction dimension. Default 256. + scale_size (int): Size of the per-block scale vector used for dequantization. Default 32. + fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True. + tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False. + + Returns: + None + """ + total_flops = 2 * m * n * k + + if tune: + kernel = matmul( + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) + else: + kernel = matmul( + m, + n, + k, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + + if fast_dequant: + if with_bias: + profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) + else: + if with_bias: + profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + M, N, K = 256, 256, 256 + scale_size = 32 + main(M, N, K, scale_size, fast_dequant=True, with_bias=True) + main(M, N, K, scale_size, fast_dequant=False, with_bias=True) + main(M, N, K, scale_size, fast_dequant=True, with_bias=False) + main(M, N, K, scale_size, fast_dequant=False, with_bias=False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py new file mode 100644 index 0000000000000000000000000000000000000000..9e90418bc75e73a4345525e7ecc339b254a0ab0c --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -0,0 +1,563 @@ +import tilelang +import tilelang.language as T +from tilelang import tvm as tvm +from tvm import DataType +from tvm import tir +import torch +from dequantize_utils import torch_convert_bit_twiddling, torch_convert + + +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): + """ + Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by + `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + + Parameters: + nbit (int): Number of bits in the packed field (must be 4). + val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. + pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). + scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). + dtype (str): Destination dtype string (must be T.bfloat16). + + Returns: + tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. + + Notes: + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. + - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + """ + assert nbit == 4 + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 + e_bf16 = e_f4 + tir.const(126, T.uint16) + # Scale is the exponential part, within the representation of uint8 + # To handle the overflow, we may use the min function to limit the exponential part to 8 bits + # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret( + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), + ) + return val_bf16 + + +def get_configs(): + """ + Generate a list of hyperparameter configuration dictionaries for tuning. + + Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', + 'num_stages', 'threads', and 'split'. The function returns the Cartesian + product of the parameter value lists: + - block_M, block_N, block_K: tiling sizes (64, 128, 256) + - num_stages: pipeline stages (0, 2) + - threads: thread counts (128, 256, 512) + - split: K-splitting factor (1, 2) + + Returns: + List[dict]: A list of configuration dictionaries covering all combinations. + """ + import itertools + + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[64, 128, 256], + num_stages=[0, 1, 2], + threads=[128, 256, 512], + split=[1, 2], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit( + out_idx=[-1], +) +def matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. + + The generated kernel accepts: + - A: dense matrix with element type `in_dtype`. + - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)). + - Scale: per-block scale/exponent information used to dequantize B. + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). + in_dtype (str): element type of A (e.g., "fp4" in this file). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the tiled, pipelined GEMM that: + - loads tiled blocks of A and packed B to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - writes the final MxN block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. + """ + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shape = (M, K) + B_shape = (N, QK) + Bias_shape = (M, N) + Scale_shape = (N, K // scale_size) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + Bias_shared_shape = (block_M, block_N) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. + + The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: + - Loads packed FP4 elements from B_shared into per-thread local registers. + - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. + - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). + - Writes the scaled BF16 results into B_dequantize_shared. + + Notes: + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. + - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k): + # import fast_dequantize plugin + """ + Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 + in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, + applying per-block scale factors from Scale. + + This routine is a tiled, thread-parallel helper that: + - Imports and calls an external dequantization function (via `import_source`/`func_name`) + to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. + - Loads the corresponding per-block scale entry, interprets it as an exponent bias + (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. + - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. + + Parameters: + - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). + - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. + - Scale: per-block scale tensor; entries are interpreted such that the multiplicative scale + = 2^(Scale - 127). + - k: block index along the K dimension used to select the appropriate Scale entries. + + Side effects: + - Mutates B_dequantize_shared in shared memory. + - Calls an external intrinsic function (must be provided by the environment via `import_source` + and `func_name`) to perform the low-level unpacking/dequantization. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + bx = T.get_block_binding(0) # noqa: F841 + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + Scale_local_thread = T.alloc_local((1,), storage_dtype) + Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) + + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + index_base = i * threads * local_compress_size + tx * local_compress_size + for v in T.vectorized(0, local_compress_size): + index = index_base + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + index_scale = index_base // (scale_size // num_elems_per_byte) + si = index_scale // (block_K // scale_size) + sj = index_scale % (block_K // scale_size) + Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj] + Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.Parallel(local_size): + B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. + + Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. + + Notes: + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. + - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): + """ + Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents. + + Per-element behavior: + - Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte). + - Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16. + - Writes the dequantized BF16 block into B_dequantize_shared. + + Parameters: + - B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout). + - B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results. + - Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element. + - k: current block index along the K dimension (used to select the appropriate slice of Scale). + + Side effects: + - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. + """ + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + + bx = T.get_block_binding(0) # noqa: F841 + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale_shared[ + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 + dtype=out_dtype, + ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Scale: T.Tensor(Scale_shape, storage_dtype), + Bias: T.Tensor(Bias_shape, out_dtype), + C: T.Tensor((M, N), out_dtype), + ): + """ + Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. + + This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. + + Parameters are self-descriptive in the signature; notable behaviors: + - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. + - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. + - The GEMM uses transpose_B=True (i.e., multiplies A ยท B^T after dequantization). + - The function writes results in-place into C. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + # To use 1D TMA, the last dim of Scale_shared must have stride=1 + # May use much more shared memory than necessary + Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + + if with_bias: + T.annotate_layout( + { + Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), + } + ) + + if threads == 512: + T.disable_warp_group_reg_alloc() + + if with_bias: + # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], + # Bias_shared) + # T.copy(Bias_shared, C_local) + T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local) + else: + T.clear(C_local) + + # Use 1D TMA to load Scale + T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared) + + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + + return main + + +def ref_program_twiddling(A, qB, Scale, Bias=None): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A ยท B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_twiddling_with_bias(A, qB, Scale, Bias): + """ + Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results. + + Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127) (where Scale indexes are grouped by 32 columns of B), computes the matrix product A ยท B^T in float, and casts the result to bfloat16. + + Parameters: + A (torch.Tensor): Left operand with shape (M, K), used in floating precision. + qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling. + Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B. + Bias (torch.Tensor): Bias tensor with shape (M, N). + + Returns: + torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. + """ + dtypeC = T.bfloat16 + B = torch_convert_bit_twiddling(qB) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple(A, qB, Scale, Bias=None): + """ + Compute a BF16 matrix product A ยท B^T from a quantized B with simple (non-twiddling) dequantization. + + Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A ยท B^T and returns the result converted to bfloat16. + + Parameters: + - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). + - qB: Quantized representation of B accepted by `torch_convert`. + - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. + + Returns: + - 2D bfloat16 tensor C containing the matrix product A ยท B^T. + + No in-place modification is performed on inputs (a local floating copy of B is scaled). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def ref_program_simple_with_bias(A, qB, Scale, Bias): + """ + Compute a BF16 matrix product A ยท B^T from a quantized B with simple (non-twiddling) dequantization. + + Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A ยท B^T and returns the result converted to bfloat16. + + Parameters: + + Returns: + - A: 2D tensor representing the left operand (will be cast to float32 for the matmul). + - qB: Quantized representation of B accepted by `torch_convert`. + - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32. + - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul). + + + Returns: + - 2D bfloat16 tensor C containing the matrix product A ยท B^T. + + No in-place modification is performed on inputs (a local floating copy of B is scaled). + """ + dtypeC = T.bfloat16 + B = torch_convert(qB) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32])) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias + C = C.to(torch.__getattribute__(dtypeC)) + return C + + +def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): + """ + Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS. + + Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS. + + Parameters: + m (int): Number of rows of A / output rows. Default 256. + n (int): Number of columns of B / output columns. Default 256. + k (int): Reduction dimension. Default 256. + scale_size (int): Size of the per-block scale vector used for dequantization. Default 32. + fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True. + tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False. + + Returns: + None + """ + total_flops = 2 * m * n * k + + if tune: + kernel = matmul( + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + ) + else: + kernel = matmul( + m, + n, + k, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + + if fast_dequant: + if with_bias: + profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) + else: + if with_bias: + profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01) + else: + profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + M, N, K = 256, 256, 256 + scale_size = 32 + main(M, N, K, scale_size, fast_dequant=True, with_bias=True) + main(M, N, K, scale_size, fast_dequant=False, with_bias=True) + main(M, N, K, scale_size, fast_dequant=True, with_bias=False) + main(M, N, K, scale_size, fast_dequant=False, with_bias=False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py new file mode 100644 index 0000000000000000000000000000000000000000..37826874bc3dfe2ee980f3b8274d56cca8bed3c6 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -0,0 +1,434 @@ +import torch +import torch.backends +import tilelang.testing +from tilelang import tvm as tvm +from tvm import DataType +import tilelang.language as T + +tilelang.testing.set_random_seed(0) + + +@tilelang.jit(out_idx=[2]) +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + num_bits=4, +): + from tilelang.quantize import _tir_packed_to_unsigned_convert + + num_elems_per_byte = 8 // num_bits + storage_dtype = T.int8 + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + tx = T.get_thread_binding() + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + tx * local_size_compressed + v + vi = index // (block_K // num_elems_per_byte) + vj = index % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, local_size): + B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // block_K + vj = index % block_K + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + kernel = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + + out = profiler.run_once() + assert out is not None + + def ref_program(A, qB): + import torch + + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program) + + +@tvm.testing.requires_package("bitblas") +def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + transform_b, +): + from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout + from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitterWithLadderTransform, + ) + + from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + num_bits = 4 + num_elems_per_byte = 8 // num_bits + storage_dtype = T.int8 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + + warp_rows = 4 + warp_cols = 4 + warp_row_tiles = micro_size_x * warp_rows + warp_col_tiles = micro_size_y * warp_cols + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + reduce_k = 1 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = 32 if in_dtype == T.float16 else 64 + chunk = block_K // reduce_k + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + transform_kind_b=transform_b, + num_elems_per_byte=num_elems_per_byte, + ) + + vec_load_qb = 16 + if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) + B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + reduced_accum_res = T.alloc_local(0, accum_dtype) + thread_binding = T.get_thread_binding(0) + rk = T.get_thread_binding(1) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + } + ) + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, (block_K // reduce_k)): + vk = rk * (block_K // reduce_k) + k + A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] + + # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + t = thread_binding + idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v + vkk = idx % (micro_size_k // num_elems_per_byte) + vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % ( + block_N // micro_size_y + ) + B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk] + + for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + rk=rk, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + rk=rk, + ) + + for j in T.serial(warp_cols): + local_size_b = mma_emitter.local_size_b + T.call_extern( + "handle", + "decode_i4u_to_f16", + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), + 8, + ) + + mma_emitter.mma(A_local, B_dequantize_local, C_local) + + if reduce_k > 1: + for n in T.serial(warp_rows * warp_cols * local_size): + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_local[n], + True, + reduced_accum_res[0], + rk, + dtype="handle", + ) + ) + if rk == 0: + C_local[n] = reduced_accum_res[0] + + if rk == 0: + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + for i, j in T.Parallel(block_M, (block_N // reduce_k)): + vj = rk * (block_N // reduce_k) + j + C[by * block_M + i, bx * block_N + vj] = C_shared[ + i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y + ] + + return main + + +def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + transform_b, +): + import bitblas + + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) + + kernel = tilelang.compile(matmul, out_idx=[2]) + src_code = kernel.get_kernel_source() + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + + # src_code is the generated cuda source + assert src_code is not None + num_bits = 4 + num_elems_per_byte = 8 // num_bits + storage_dtype = T.int8 + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=transform_b, + transpose_matrix=True, + dequantize_bits=num_bits, + storage_dtype=storage_dtype, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype=in_dtype, + dequantize_bits=num_bits, + storage_dtype=storage_dtype, + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + QLB = ladder_permutate(qB.cpu()).cuda() + QLB = lop3_permutate(QLB.cpu()).cuda() + + kernel(A, QLB, C) + + latency = profiler.do_bench(warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("Ref C: ", ref_c) + print("C: ", C) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@tilelang.testing.requires_package("bitblas") +def test_run_dequantize_gemm(): + run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128) + + +@tilelang.testing.requires_package("bitblas") +def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3) + + +def main(): + test_run_dequantize_gemm() + + +if __name__ == "__main__": + main() diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..79345771d6bd2e565547016c39e6ee2a4611201d --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -0,0 +1,284 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import * +from tvm import tir +import itertools +import torch +import argparse + + +def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == T.float16 + assert val.dtype == T.uint8 + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14 + # s1e2m1 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + e_f16 = e_f4 + tir.const(14, T.uint16) + m_f4 = f4 & tir.const(1, T.uint16) + m_f16 = m_f4 + val_f16 = tir.reinterpret( + T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16) + ) + # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16) + return val_f16 + + +def torch_convert(tensor): + def print_bit(name, val): + val_cpu = val.cpu().item() + binary_repr = f"{val_cpu:032b}" + print(name, binary_repr) + + def _convert(val, pos): + assert val.dtype == torch.uint8 + val = val.view(torch.int8) + mask = (1 << 4) - 1 + f4 = ((val >> (pos * 4)) & mask).to(torch.int16) + s = f4 >> 3 + e_f4 = (f4 & 6) >> 1 + e_f16 = e_f4 + 14 + m_f4 = f4 & 1 + m_f16 = m_f4 + val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF + lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) + return lower_16_bits.view(torch.float16) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +@tilelang.jit(out_idx=[1]) +def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + B_shape = (N, K // num_elems_per_byte) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) + + return main + + +def test_fp4_fp16_convert_close(): + N, K = 256, 256 + block_N, block_K = 64, 64 + kernel = test_convert( + N, + K, + block_N, + block_K, + T.float16, + ) + + B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) + tl_out = kernel(B) + ref_out = torch_convert(B) + assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out) + print("Pass") + + +def get_configs(): + block_M = [64, 128] + block_N = [64, 128] + block_K = [128, 256] + num_stages = [1, 2] + threads = [128, 256] + splits = [1] + _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) + + configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs] + return configs + + +def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): + @tilelang.jit(out_idx=[2]) + def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + KK = K // split + + @T.prim_func + def main_split( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) + + T.clear(Ct_local) + for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): + T.copy(A[by * block_M, KK * bz + k * block_K], A_shared) + T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): + acc = T.alloc_fragment((block_N, block_M), out_dtype) + T.clear(acc) + for k in range(split): + for i, j in T.Parallel(block_N, block_M): + acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j] + T.copy(acc, Ct[bx * block_N, by * block_M]) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) + + T.clear(Ct_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, Ct_shared) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) + + if split == 1: + return main + else: + return main_split + + if tune: + + @autotune(configs=get_configs(), warmup=10, rep=10) + @tilelang.jit(out_idx=[2]) + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): + return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func + + return kernel() + else: + + def kernel(block_M, block_N, block_K, num_stages, threads, split=1): + return kernel_func(block_M, block_N, block_K, num_stages, threads, split) + + return kernel + + +def ref_program(A, qB): + dtypeC = T.float16 + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C.transpose(0, 1) + + +def main(m=256, n=256, k=256, tune=False): + total_flops = 2 * m * n * k + + if not tune: + kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune) + best_latency = best_result.latency + best_config = best_result.config + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=256, help="M") + parser.add_argument("--n", type=int, default=256, help="N") + parser.add_argument("--k", type=int, default=256, help="K") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + main(M, N, K, args.tune) diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py new file mode 100644 index 0000000000000000000000000000000000000000..61baa668e6eb0853c4a6c2d93a5e05dc3f254e46 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -0,0 +1,198 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import * +from tvm import tir +import itertools +import torch +import argparse + + +def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == T.int8 + assert val.dtype == T.uint8 + + mask = tir.const((1 << nbit) - 1, T.uint8) + + i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask + + i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8)) + i8 = i8_shifted >> tir.const(4, T.int8) + return i8 + + +def get_configs(): + iter_params = dict( + block_M=[64, 128], + block_N=[64, 128], + block_K=[128, 256], + num_stages=[1, 2], + threads=[128, 256, 512], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@tilelang.jit(out_idx=[1]) +def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + B_shape = (N, K // num_elems_per_byte) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + B: T.Tensor(B_shape, storage_dtype), + C: T.Tensor((N, K), in_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) + + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1): + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, C[bx * block_N, k * block_K]) + + return main + + +def torch_convert(tensor): + def _convert(val, pos): + assert val.dtype == torch.uint8 + val = val.view(torch.int8) + mask = (1 << 4) - 1 + i4_shifted = (val >> (pos * 4)) & mask + i4 = (i4_shifted << 4) >> 4 + + return i4.view(torch.int8) + + N = tensor.shape[0] + K = tensor.shape[1] + new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device) + for i in range(new_tensor.shape[0]): + for j in range(new_tensor.shape[1]): + new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) + return new_tensor + + +def ref_program(A, qB): + dtypeC = T.int32 + B = torch_convert(qB) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C.transpose(0, 1) + + +def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): + @tilelang.jit(out_idx=[2]) + def kernel_func(block_M, block_N, block_K, num_stages, threads): + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_local_shape = (block_N, block_K) + + assert K % (block_K) == 0 + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, storage_dtype), + Ct: T.Tensor((N, M), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype) + B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype) + Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) + Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) + + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), + } + ) + + T.clear(Ct_local) + for k in T.Pipelined(K // block_K, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=in_dtype, + ) + T.copy(B_dequantize_local, B_dequantize_prev_local) + T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) + T.copy(Ct_local, Ct_shared) + T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) + + return main + + if tune: + + @autotune(configs=get_configs(), warmup=10, rep=10) + @tilelang.jit(out_idx=[2]) + def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None): + return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func + + return kernel() + + else: + + def kernel(block_M, block_N, block_K, num_stages, threads): + return kernel_func(block_M, block_N, block_K, num_stages, threads) + + return kernel + + +def main(m=128, n=256, k=256, tune=False): + total_flops = 2 * m * n * k + if not tune: + kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 + ) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) + print("All checks pass.") + + latency = profiler.do_bench(warmup=50) + print(f"Tilelang: {latency} ms") + + else: + best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune) + best_latency = best_result.latency + best_config = best_result.config + print(f"Bset latency: {best_latency}") + print(f"Best config: {best_config}") + print(f"Best tflops: {total_flops / best_latency * 1e-9}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=512, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=512, help="Matrix dimension K") + parser.add_argument("--tune", action="store_true", help="Enable tuning") + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + main(M, N, K, args.tune) + # main(M, N, K, True) diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py new file mode 100644 index 0000000000000000000000000000000000000000..dea2e5ddd8a1e763cac46c1e07199b5c6077830c --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -0,0 +1,221 @@ +import tilelang +from tilelang import language as T +from typing import Optional, Callable, Any +import torch +from tilelang import DataType +from tilelang.quantize import ( + _tir_packed_int_to_int_convert, +) + + +@tilelang.jit +def dequantize_gemv( + M: int, + N: int, + K: int, + in_dtype: str, + out_dtype: str, + accum_dtype: str, + num_bits: int = 4, + storage_dtype: T.dtype = T.int8, + source_format: str = "uint", + n_partition: int = 4, + reduce_thread: int = 32, + fast_decoding: bool = False, + trans_A: bool = False, + trans_B: bool = True, + group_size: int = -1, + with_scaling: bool = False, +) -> Callable[..., Any]: + assert n_partition is not None, "n_partition must be provided" + assert reduce_thread is not None, ( + "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented" + ) + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + storage_type = "".join(c for c in storage_dtype if not c.isdigit()) + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + micro_size_k_compressed = micro_size_k // num_elems_per_byte + block_K = reduce_thread * micro_size_k + + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + C_shape = (M, N) + + dp4a_size = 4 + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + # Lazy import to decrease the startup time + # as intrin registry may take a while to load + from tilelang.quantize import get_lop3_intrin_group + + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=with_scaling, + with_zeros=False, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = import_source + + @T.prim_func + def main( + A: T.Tensor[A_shape, in_dtype], + B: T.Tensor[B_shape, storage_dtype], + C: T.Tensor[C_shape, out_dtype], + ): + with T.Kernel( + T.ceildiv(N, n_partition), + M, + threads=(reduce_thread, n_partition), + ) as ( + bx, + by, + ): + A_local = T.alloc_local((micro_size_k,), in_dtype) + B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([micro_size_k], in_dtype) + accum_res = T.alloc_local((1,), accum_dtype) + reduced_accum_res = T.alloc_local((1,), accum_dtype) + + kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x") + ni = T.thread_binding(0, n_partition, thread="threadIdx.y") + + T.import_source(import_source) + + T.clear(accum_res) + for ko in T.serial(T.ceildiv(K, block_K)): + for v in T.vectorized(micro_size_k): + A_local[v] = A[by, ko * block_K + kr * micro_size_k + v] + + for v in T.vectorized(micro_size_k_compressed): + B_quant_local[v] = B[ + bx * n_partition + ni, + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v, + ] + + if fast_decoding: + T.call_extern( + func_name, + T.address_of(B_quant_local[0]), + T.address_of(B_dequantize_local[0]), + dtype=in_dtype, + ) + else: + for ki in T.serial(micro_size_k): + B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype + ) + + if use_dp4a: + for ki in T.serial(micro_size_k // dp4a_size): + T.dp4a( + A_local[ki * dp4a_size], + B_dequantize_local[ki * dp4a_size], + accum_res[0], + ) + else: + for ki in T.serial(micro_size_k): + accum_res[0] += A_local[ki] * B_dequantize_local[ki] + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + accum_res[0], + True, + reduced_accum_res[0], + kr, + dtype="handle", + ) + ) + if kr == 0: + C[by, bx * n_partition + ni] = reduced_accum_res[0] + + return main + + +def main() -> None: + M = 1 + N = 1024 + K = 1024 + in_dtype = T.float16 + out_dtype = T.float16 + accum_dtype = T.float16 + num_bits = 4 + storage_dtype = T.int8 + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + kernel = dequantize_gemv( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + num_bits, + storage_dtype, + source_format, + n_partition, + reduce_thread, + fast_decoding, + trans_A, + trans_B, + group_size, + with_scaling, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + # int4 reference + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + for j in range(B.shape[1]): + B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("C: ", C) + print("Ref C: ", ref_c) + # doesn't apply scaling, the absolute error is large + torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) + + +if __name__ == "__main__": + main() diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..9921c6bfe2dcc9265fe4875d21835d93baef6b90 --- /dev/null +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -0,0 +1,522 @@ +import tilelang +import tilelang.language as T +from tilelang.quantize import _tir_u8_to_f4_to_bf16 +from tilelang import tvm as tvm +from tvm import DataType +import torch +from dequantize_utils import torch_convert_bit_twiddling, assert_similar +from tilelang.autotuner import set_autotune_inputs +import argparse + + +def get_configs(): + """ + Generate a list of hyperparameter configuration dictionaries for tuning. + + Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K', + 'num_stages', 'threads', and 'split'. The function returns the Cartesian + product of the parameter value lists: + - block_M, block_N, block_K: tiling sizes + - num_stages: pipeline stages + - threads: thread counts + - split: K-splitting factor + + Returns: + List[dict]: A list of configuration dictionaries covering all combinations. + """ + import itertools + + iter_params = dict( + block_M=[128], + block_N=[64, 128, 256], + block_K=[128], + num_stages=[0, 1, 2], + threads=[128, 256, 512], + split=[1], + ) + return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[-1]) +def matmul( + M, + N, + K, + topk, + E, + padding_M, + in_dtype, + out_dtype, + accum_dtype, + source_format=T.uint32, + num_bits=4, + scale_size=32, + fast_dequant=True, + with_bias=False, + block_M=128, + block_N=256, + block_K=128, + num_stages=2, + threads=256, + split=1, +): + """ + Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. + + The generated kernel accepts: + - A: dense matrix with element type `in_dtype` and shape (M, K). + - B: packed quantized matrix for all experts, stored as uint8 with `num_bits` bits per element, shape (E, N, QK), where QK = K / (8/num_bits). + - Scale: per-expert, per-block scale/exponent information for dequantizing B, shape (E, N, K // scale_size). + - Bias: per-expert, per-output bias, shape (E, N). + - topk_weights: router weights for the top-k experts for each token, shape (M, topk). + - sorted_token_ids: flattened and padded tensor of token indices, shape (padding_M,). + - expert_ids: expert id for each token in the padded batch, shape (padding_M // block_M,). + - C: output tensor, shape (M, topk, N). + + The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths: + - fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization. + - fast_dequant (False): uses a simple elementwise dequantization helper. + + Parameters: + M, N, K (int): matrix dimensions (A is MxK, result is (M, topk, N)). K must be divisible by (block_K * split). + topk (int): number of experts selected per token. + E (int): number of experts. + padding_M (int): padded number of tokens after grouping and block alignment. + in_dtype (str): element type of A (e.g., T.bfloat16). + out_dtype (str): output tensor element type (e.g., T.bfloat16). + accum_dtype (str): accumulation type used for the inner GEMM. + source_format (str, optional): format string passed to intrinsic selector (default "uint"). + num_bits (int, optional): number of bits per quantized element in B (default 4). + scale_size (int, optional): number of elements grouped per scale entry (default 32). + fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True). + block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128). + num_stages (int, optional): pipelining stages for K loop (default 2). + threads (int, optional): threads per block used by the kernel (default 256). + split (int, optional): split factor along K used by the scheduler (default 1). + with_bias (bool, optional): whether to add Bias to the output (default False). + + Returns: + A T.prim_func implementing the grouped, pipelined GEMM that: + - loads tiled blocks of A and packed B for each expert to shared memory, + - dequantizes B via the chosen path into a shared dequantized tile, + - performs a tiled GEMM accumulating into local fragments, + - applies per-token topk weights and bias, + - writes the final (M, topk, N) block to the global output tensor. + + Notes: + - The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. + - The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. + - An assertion enforces that K % (block_K * split) == 0. + """ + + num_elems_per_byte = 8 // num_bits + storage_dtype = T.uint8 + QK = K // num_elems_per_byte + Block_QK = block_K // num_elems_per_byte + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, Block_QK) + Bias_shared_shape = block_N + B_dequantize_shared_shape = (block_N, block_K) + assert K % (block_K * split) == 0 + + from tilelang.quantize import get_mxfp_intrin_group + + # fast_dequant_bf16_fp4_twiddling + mxfp_intrin_info = get_mxfp_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + use_twiddling=True, + ) + import_source = mxfp_intrin_info["c_source"] + func_name = mxfp_intrin_info["func_name"] + assert import_source is not None, "mxfp_intrin_info is not found" + assert func_name is not None, "mxfp_intrin_info is not found" + import_source = import_source + + # the dequant part is the same as in dequant_gemm + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): + """ + Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. + The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: + - Loads packed FP4 elements from B_shared into per-thread local registers. + - Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values. + - Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two). + - Writes the scaled BF16 results into B_dequantize_shared. + + Notes: + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. + - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. + - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. + """ + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + # Some variables for dequantization in each thread + MAX_TRANSACTION_SIZE_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits + local_compress_size = local_size // num_elems_per_byte + + @T.macro + def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k): + # import fast_dequantize plugin + """ + Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16 + in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4, + applying per-block scale factors from Scale. + + This routine is a tiled, thread-parallel helper that: + - Imports and calls an external dequantization function (via `import_source`/`func_name`) + to expand compressed uint8-packed FP4 values into BF16 fragments in-thread. + - Loads the corresponding per-block scale entry, interprets it as an exponent bias + (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor. + - Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place. + + Parameters: + - B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout). + - B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values. + - Scale_shared: per-block scale tensor; entries are interpreted such that the multiplicative scale + = 2^(Scale - 127). + - k: block index along the K dimension used to select the appropriate Scale entries. + + Side effects: + - Mutates B_dequantize_shared in shared memory. + - Calls an external intrinsic function (must be provided by the environment via `import_source` + and `func_name`) to perform the low-level unpacking/dequantization. + """ + T.import_source(import_source) + + tx = T.get_thread_binding() + + B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) + B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) + Scale_local_thread = T.alloc_local((1,), storage_dtype) + Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) + + for i in T.serial(0, block_N * block_K // threads // local_size): + # First, load data from share memory to register. + # Prepare for dequant. + index_base = i * threads * local_compress_size + tx * local_compress_size + for v in T.vectorized(0, local_compress_size): + index = index_base + v + B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK] + index_scale = index_base // (scale_size // num_elems_per_byte) + si = index_scale // (block_K // scale_size) + sj = index_scale % (block_K // scale_size) + Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj] + Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) + + # Then, dequant. + T.call_extern( + func_name, + T.address_of(B_local_thread[0]), + T.address_of(B_dequantize_local_thread[0]), + 1, + dtype=out_dtype, + ) + + # Finally, store the dequantized data to shared memory. + for v in T.Parallel(local_size): + B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0] + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v] + + return fast_dequant_bf16_fp4_twiddling + + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): + assert in_dtype in ["fp4"] + assert out_dtype in [T.bfloat16] + + @T.macro + def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): + B_local = T.alloc_fragment(B_shared_shape, storage_dtype) + B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) + + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + Scale_shared[ + i, k * block_K // scale_size + j // scale_size + ], # Scale is the exponential part, within the representation of uint8 + dtype=out_dtype, + ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) + T.copy(B_dequantize_local, B_dequantize_shared) + + return simple_dequant_bf16_fp4 + + @T.prim_func + def main( + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((E, N, QK), storage_dtype), + Scale: T.Tensor((E, N, K // scale_size), storage_dtype), + Bias: T.Tensor((E, N), out_dtype), + # Add fusedmoe tensors + topk_weights: T.Tensor((M * topk), out_dtype), + sorted_token_ids: T.Tensor((padding_M), T.int32), + expert_ids: T.Tensor((padding_M // block_M), T.int32), + C: T.Tensor((M, topk, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + topk_weights_shared = T.alloc_shared((block_M), out_dtype) + sorted_token_ids_shared = T.alloc_shared((block_M), T.int32) + expert_id = T.alloc_local((1), T.int32) # the expert id for the current block + # To use 1D TMA, the last dim of Scale_shared must have stride=1 + # May use much more shared memory than necessary + Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + C_shared: tilelang.layout.make_swizzled_layout(C_shared), + } + ) + T.use_swizzle(10) + + if threads == 512: + T.disable_warp_group_reg_alloc() + + T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared) + expert_id[0] = expert_ids[by] + + # Get the topk weights of each token in the current block + for i in T.Parallel(block_M): + if sorted_token_ids_shared[i] != -1: + topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]] + + # Get bias and scale based on the expert id + if with_bias: + T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared) + else: + T.clear(Bias_shared) + + T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared) + + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = Bias_shared[j] + + tx = T.get_thread_binding() + + for k in T.Pipelined(K // block_K, num_stages=num_stages): + # Each thread copies 4 bytes, local size is 16 + for copy_i in T.serial(block_M * block_K // threads // 16): + base = copy_i * threads * 16 + tx * 16 + if sorted_token_ids_shared[base // block_K] != -1: + for copy_j in T.vectorized(16): + A_shared[base // block_K, base % block_K + copy_j] = A[ + sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j + ] + + T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) + if fast_dequant: + get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k) + else: + get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = C_local[i, j] * topk_weights_shared[i] + + T.copy(C_local, C_shared) + for copy_i in T.serial(block_M * block_N // threads // 16): + base = copy_i * threads * 16 + tx * 16 + if sorted_token_ids_shared[base // block_N] != -1: + for copy_j in T.vectorized(16): + C[ + sorted_token_ids_shared[base // block_N] // topk, + sorted_token_ids_shared[base // block_N] % topk, + bx * block_N + base % block_N + copy_j, + ] = C_shared[base // block_N, base % block_N + copy_j] + + return main + + +def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): + dtypeC = T.bfloat16 + M, K = A.shape + E, N, QK = qB.shape + topk = topk_weights.shape[0] // M + scale_size = K // Scale.shape[2] + assert scale_size == 32 # MXFP4 + + # Initialize output tensor + C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") + + # Iterate over sorted_token_ids + for idx in range(len(sorted_token_ids)): # padding_M + token_id = sorted_token_ids[idx] + if token_id == -1: + continue + expert_id = expert_ids[idx // block_M] + topk_idx = token_id % topk + + # Get the token embedding + token_embedding = A[token_id // topk] + + # Dequantize the expert weights + B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) + B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) + + # Compute the output for this token-expert pair + # token_embedding @ B.T + bias + output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] + output = output.to(torch.__getattribute__(dtypeC)) + + # Apply the topk weight + weight = topk_weights[token_id] + output = output * weight + + # Store the result + C[token_id // topk, topk_idx] = output + + return C + + +def get_data(m, n, k, qk, scale_size, topk, E, block_M): + A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts. + Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda") + Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + + weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1) + # topk_weights: Router weights for the top-k experts for each token. + # Shape: (m, topk) + # tokens_experts: A flattened tensor of expert assignments for each token. + # For each of m tokens, topk unique experts are chosen. Shape: (m * topk,) + topk_weights, tokens_experts = torch.topk(weights, topk, dim=-1) + tokens_experts = tokens_experts.reshape(m * topk) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.reshape(m * topk) + + sorted_expert_vals, sorted_indices = torch.sort(tokens_experts, stable=True) + sorted_token_ids = sorted_indices + unique_expert_ids, counts = torch.unique_consecutive(sorted_expert_vals, return_counts=True) + expert_ids = [] + padded_token_ids = [] + start = 0 + for eid, cnt in zip(unique_expert_ids.tolist(), counts.tolist()): + end = start + cnt + group_token_ids = sorted_token_ids[start:end] + pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt + if pad_len > 0: + # -1 for padding (`M` instead in vLLM moe_align_block_size()) + group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")]) + padded_token_ids.append(group_token_ids) + expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) + start = end + + # sorted_token_ids: The final flattened and padded tensor of token indices. + sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) + # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. + expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # ๏ผˆpadding_M,๏ผ‰ + padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding + + return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M + + +def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): + # Tunable parameters + block_M, block_N, block_K = 128, 256, 128 # noqa: F841 + num_stages = 1 # noqa: F841 + threads = 512 # noqa: F841 + split = 1 # noqa: F841 + + total_flops = 2 * m * n * k * topk + num_bits = 4 + num_elems_per_byte = 8 // num_bits + qk = k // num_elems_per_byte + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M) + + if tune: + with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): + # Autotune with inputs manually composed + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + else: + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + T.bfloat16, + T.bfloat16, + T.float32, + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, + ) + + output = kernel( + A, + qB, + Scale, + Bias, + topk_weights, + sorted_token_ids, + expert_ids, + ) + + print("Tilelang kernel run finished.") + + ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... + + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + diff = (output - ref_output).abs() + max_val = diff.max() + max_idx = diff.argmax() + print(f"max abs diff: {max_val} at index: {max_idx}") + assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference + print("All checks pass. โœ…") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm + parser.add_argument("--N", type=int, default=5760, help="N") + parser.add_argument("--K", type=int, default=2944, help="K") + parser.add_argument("--scale_size", type=int, default=32, help="scale size") + parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token + parser.add_argument("--E", type=int, default=32, help="E") # number of experts + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + + main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune) diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..01bc40e6c944b9c4b1f9fff36da132b002055a2b --- /dev/null +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -0,0 +1,47 @@ +import tilelang.testing + +import example_dequant_gemv_fp16xint4 +import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper_tma +import example_dequant_groupedgemm_bf16_mxfp4_hopper +import example_dequant_gemm_w4a8 + + +@tilelang.testing.requires_cuda +def test_example_dequant_gemv_fp16xint4(): + example_dequant_gemv_fp16xint4.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_fp4_hopper(): + example_dequant_gemm_fp4_hopper.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_bf16_mxfp4_hopper(): + example_dequant_gemm_bf16_mxfp4_hopper.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_bf16_mxfp4_hopper_tma(): + example_dequant_gemm_bf16_mxfp4_hopper_tma.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): + example_dequant_groupedgemm_bf16_mxfp4_hopper.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_dequant_gemm_w4a8(): + example_dequant_gemm_w4a8.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/dsa_sparse_finetune/dsa.py b/examples/dsa_sparse_finetune/dsa.py new file mode 100644 index 0000000000000000000000000000000000000000..9fae8e5e3d698c9d7763b707fa2b2fd7506257c2 --- /dev/null +++ b/examples/dsa_sparse_finetune/dsa.py @@ -0,0 +1,223 @@ +from typing import Optional +import torch +import torch.nn.functional as F +from indexer_topk_reducesum import indexer_topk_reducesum_interface +from indexer_bwd import indexer_bwd_interface +from sparse_mla_fwd import sparse_mla_fwd_interface +from sparse_mla_bwd import sparse_mla_bwd +from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface +from einops import einsum, repeat +from utils import get_abs_err, get_err_ratio + + +class RegsiterLossFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + ctx.save_for_backward(loss) + return x + + @staticmethod + def backward(ctx, grad): + loss = ctx.saved_tensors + return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device) + + +register_loss = RegsiterLossFunction.apply + + +def ref_deepseek_sparse_attention_innner( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + dtype = q.dtype + q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights)) + + index_sm_scale = index_q.shape[-1] ** -0.5 + b, s = index_q.shape[:2] + + # tl_topk_indices = tl_topk_indices.to(torch.int64) + # tl_topk_indices[tl_topk_indices == -1] = s + + casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2") + index_logits = F.relu(index_logits) + index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale + index_logits = torch.where(casual_mask, index_logits, float("-inf")) + topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices + topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices) + topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) + index_topk_score = topk_score + + if sm_scale is None: + sm_scale = kv.shape[-1] ** -0.5 + + h = q.shape[-2] + index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_( + dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool) + )[:, :, :-1] + mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h) + k, v = kv, kv[..., :dim_v] + logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d") + + attn_score = attn_score.sum(dim=-2) # [b, s1, s2] + attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) + attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) + + loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum") + o = register_loss(o, loss) + + return o.to(dtype), topk_indices + + +def ref_deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + index_sm_scale: Optional[float] = None, +): + all_o, all_topk_indices = [], [] + for i in range(offsets.shape[0] - 1): + o, topk_indices = ref_deepseek_sparse_attention_innner( + q[None, offsets[i] : offsets[i + 1]], + kv[None, offsets[i] : offsets[i + 1]], + index_q[None, offsets[i] : offsets[i + 1]], + index_k[None, offsets[i] : offsets[i + 1]], + weights[None, offsets[i] : offsets[i + 1]], + topk, + dim_v, + sm_scale, + index_sm_scale, + ) + all_o.append(o.squeeze(0)) + all_topk_indices.append(topk_indices.squeeze(0)) + o = torch.cat(all_o, dim=0) + topk_indices = torch.cat(all_topk_indices, dim=0) + return o, topk_indices + + +class DSAFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, + ): + # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) + topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets) + o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v) + ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets) + ctx.topk = topk + ctx.dim_v = dim_v + ctx.sm_scale = sm_scale + return o, topk_indices + + @staticmethod + def backward( + ctx, + do: torch.Tensor, + _1: torch.Tensor, + ): + q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors + attn_score = sparse_mla_topk_reducesum_interface( + q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v + ).squeeze(-2) + dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale) + dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets) + return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None + + +def deepseek_sparse_attention( + q: torch.Tensor, + kv: torch.Tensor, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + offsets: torch.Tensor, + topk: int, + dim_v: int, + sm_scale: Optional[float] = None, +): + return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale) + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + index_D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_() + index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_() + weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_() + index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_() + do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_() + offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda() + + o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + o.backward(do) + q_grad, q.grad = q.grad, None + kv_grad, kv.grad = kv.grad, None + index_q_grad, index_q.grad = index_q.grad, None + index_k_grad, index_k.grad = index_k.grad, None + weights_grad, weights.grad = weights.grad, None + + ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D) + ref_o.backward(do) + ref_q_grad, q.grad = q.grad, None + ref_kv_grad, kv.grad = kv.grad, None + ref_index_q_grad, index_q.grad = index_q.grad, None + ref_index_k_grad, index_k.grad = index_k.grad, None + ref_weights_grad, weights.grad = weights.grad, None + + print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") + print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}") + print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}") + print( + f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" + ) + print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}") + print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}") + + intersections = [] + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + mask = trt_np != -1 + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + intersections.append(len(intersection) / len(set_ref)) + print("average intersections: {:.4f}".format(sum(intersections) / len(intersections))) + + +test_kernel() diff --git a/examples/dsa_sparse_finetune/index.py b/examples/dsa_sparse_finetune/index.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4800411004e5890faba0578cf83f09e27f2dc9 --- /dev/null +++ b/examples/dsa_sparse_finetune/index.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..68508ad4e45104b3b5717c95ef30ebfe1caaccd4 --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -0,0 +1,254 @@ +import torch +import torch.nn.functional as F +from einops import einsum, repeat + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_I: int = 32, + num_stages: int = 0, + num_threads: int = 128, +): + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert heads <= 64 and heads % 8 == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [seq_len, topk] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.prim_func + def tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, dtype), + dIndexK: T.Tensor(index_k_shape, dtype), + AttnScore: T.Tensor(shape_p, FP32), + IndexScore: T.Tensor(shape_p, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos = Offsets[i_b] + num_blocks = T.ceildiv(topk, block_I) + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + weights_shared = T.alloc_shared([heads], dtype=dtype) + + d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype) + + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.copy(Weights[bos + i_t, :], weights_shared) + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + + for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + i_st = bi_i * block_I + i_ed = (bi_i + 1) * block_I + + indices_shared = T.alloc_shared([block_I], dtype=INT32) + T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared) + + index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0) + + attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) + for i in T.Parallel(block_I): + attn_score_shared[i] = AttnScore[bos + i_t, i_st + i] + index_score_shared[i] = IndexScore[bos + i_t, i_st + i] + + logits = T.alloc_fragment((block_I, heads), accum_dtype) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + # dw + d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) + d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype) + d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype) + + for i, j in T.Parallel(block_I, heads): + d_relu = T.alloc_var(accum_dtype) + if logits[i, j] > 0: + d_relu = 1.0 + else: + d_relu = 0.0 + d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j] + + # dq + T.copy(d_logits_qk, d_logits_qk_cast1) + T.gemm( + d_logits_qk_cast1, # [BS, HQ] + index_k_shared, # [BS, K] + d_index_q_frag, # [HQ, K] + transpose_A=True, + transpose_B=False, + clear_accum=False, + ) + + # dk + T.copy(d_logits_qk, d_logits_qk_cast2) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + T.gemm( + d_logits_qk_cast2, # [BS, HQ] + index_q_shared, # [HQ, K] + d_index_k_frag, # [BS, K] + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + pos = indices_shared[i] + if (pos > -1) & (pos <= i_t): + T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) + + for i, j in T.Parallel(heads, dim): + d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale + + T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :]) + T.copy(d_weights_frag, dWeights[bos + i_t, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + attn_score: torch.Tensor, + index_score: torch.Tensor, + topk_indices: torch.Tensor, + offsets: torch.Tensor, +): + _, heads, dim, topk = *q.shape, topk_indices.shape[-1] + token_indices = prepare_token_indices(offsets) + dq = torch.zeros_like(q) + dweights = torch.zeros_like(weights) + dk = torch.zeros_like(k) + kernel = tl_indexer_bwd_impl(heads, dim, topk) + kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices) + return dq, dweights, dk + + +def ref_indexer_bwd( + Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor +) -> torch.Tensor: + Q.requires_grad_(True) + Weights.requires_grad_(True) + K.requires_grad_(True) + softmax_scale = Q.shape[-1] ** -0.5 + all_loss = [] + all_log_topk_prob = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + attn_score = AttnScore[offsets[i] : offsets[i + 1]] + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale + logits = F.relu(logits) + score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) + score = torch.where(mask, score, float("-inf")) + topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) + log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) + loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum") + all_loss.append(loss) + all_log_topk_prob.append(log_topk_prob) + loss = torch.stack(all_loss).sum() + loss.backward() + log_topk_prob = torch.cat(all_log_topk_prob, dim=0) + return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad + + +def test_kernel( + B=1, + S=2048, + H=16, + D=128, + topk=64, +): + torch.manual_seed(42) + q = torch.randn((S, H, D)).cuda().bfloat16() + w = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + all_attn_score = [] + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) + logits = torch.ones(seq_len, topk).cuda() + logits = torch.where(mask, logits, float("-inf")) + attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) + all_attn_score.append(attn_score) + attn_score = torch.cat(all_attn_score, dim=0) + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets) + + dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) + + print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}") + print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}") + print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py new file mode 100644 index 0000000000000000000000000000000000000000..d76eb027247b9ce8fdf4cd20f422d7a79304eb3b --- /dev/null +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -0,0 +1,273 @@ +import math +import torch +import torch.nn.functional as F +from einops import einsum + +import tilelang as tl +import tilelang.language as T +from typing import Optional +from index import prepare_token_indices + +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_topk_reducesum_impl( + heads: int, + dim: int, + topk: int, + sm_scale: Optional[float] = None, + block_K: int = 32, + dtype: str = FP32, + num_stages: int = 0, + num_threads: int = 128, +): + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_K == 0 + assert heads <= 64 and heads % 8 == 0 + assert num_stages == 0 + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + index_q_shape = [seq_len, heads, dim] + weights_shape = [seq_len, heads] + index_k_shape = [seq_len, dim] + topk_indices_shape = [seq_len, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + N = 2 * topk + num_iters = int(round(math.log2(N))) + if sm_scale is None: + sm_scale = dim**-0.5 + + @T.macro + def bitonic_sort( + topk_index_shared: T.SharedBuffer([N], dtype=INT32), + topk_value_shared: T.SharedBuffer([N], dtype=FP32), + ): + T.sync_threads() + for i1 in T.serial(num_iters): + for i2 in T.serial(i1 + 1): + for i in T.Parallel(N): + ascending = (i & (1 << (i1 + 1))) != 0 + j = i ^ (1 << (i1 - i2)) + if i < j and ( + (ascending and topk_value_shared[i] > topk_value_shared[j]) + or (not ascending and topk_value_shared[i] < topk_value_shared[j]) + ): + val = topk_value_shared[i] + topk_value_shared[i] = topk_value_shared[j] + topk_value_shared[j] = val + idx = topk_index_shared[i] + topk_index_shared[i] = topk_index_shared[j] + topk_index_shared[j] = idx + T.sync_threads() + + @T.prim_func + def tl_indexer_topk_reducesum_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + Weights: T.Tensor(weights_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + ReduceSum: T.Tensor(topk_indices_shape, FP32), + Offsets: T.Tensor(offsets_shape, INT32), + TokenIndices: T.Tensor(token_indices_shape, INT32), + ): + with T.Kernel(seq_len, threads=num_threads) as (bx): + i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] + bos, eos = Offsets[i_b], Offsets[i_b + 1] + num_blocks = T.ceildiv(i_t + 1, block_K) + + topk_index_shared = T.alloc_shared([N], dtype=INT32) + topk_value_shared = T.alloc_shared([N], dtype=FP32) + + T.fill(topk_index_shared, -1) + T.fill(topk_value_shared, float("-inf")) + T.sync_threads() + + index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) + T.copy(IndexQ[bos + i_t, :, :], index_q_shared) + T.sync_threads() + + weights_frag = T.alloc_shared([heads], dtype=dtype) + T.copy(Weights[bos + i_t, :], weights_frag) + T.sync_threads() + + for i, j in T.Parallel(heads, dim): + index_q_shared[i, j] = index_q_shared[i, j] * sm_scale + T.sync_threads() + + for bk_i in T.Pipelined(num_blocks, num_stages=num_stages): + k_st = bk_i * block_K + k_ed = T.min((bk_i + 1) * block_K, eos - bos) + + index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) + for i, j in T.Parallel(block_K, dim): + index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0) + T.sync_threads() + + logits = T.alloc_fragment((block_K, heads), FP32) + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + T.sync_threads() + + for i, j in T.Parallel(block_K, heads): + logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j] + T.sync_threads() + + logits_sum = T.alloc_fragment(block_K, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + T.sync_threads() + + offset = T.alloc_var(INT32) + if k_st >= topk: + offset = topk + (k_st % topk) + else: + offset = k_st + T.sync_threads() + for i in T.Parallel(block_K): + if k_st + i > i_t: + logits_sum[i] = float("-inf") + j = offset + i + topk_index_shared[j] = k_st + i + topk_value_shared[j] = logits_sum[i] + T.sync_threads() + + if k_ed > topk and k_ed % topk == 0: + bitonic_sort(topk_index_shared, topk_value_shared) + + bitonic_sort(topk_index_shared, topk_value_shared) + + logits_max_frag = T.alloc_fragment([1], dtype=FP32) + logits_frag = T.alloc_fragment([topk], dtype=FP32) + reducesum_shared = T.alloc_shared([topk], dtype=FP32) + + T.copy(topk_value_shared[:topk], logits_frag) + T.sync_threads() + + T.reduce_max(logits_frag, logits_max_frag, dim=-1) + T.sync_threads() + + for i in T.Parallel(topk): + logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0]) + T.sync_threads() + + lse_frag = T.alloc_fragment([1], dtype=FP32) + T.reduce_sum(logits_frag, lse_frag) + T.sync_threads() + + for i in T.Parallel(topk): + reducesum_shared[i] = logits_frag[i] / lse_frag[0] + T.sync_threads() + + # for i in T.Parallel(topk): + # reducesum_shared[i] = logits_frag[i] + # T.sync_threads() + + for i in T.Parallel(topk): + if topk_index_shared[i] > i_t: + topk_index_shared[i] = -1 + T.sync_threads() + + T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :]) + T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :]) + + return tl_indexer_topk_reducesum_kernel + + +def indexer_topk_reducesum_interface( + q: torch.Tensor, + weights: torch.Tensor, + k: torch.Tensor, + topk: int, + offsets: torch.Tensor, + dtype: str = BF16, +): + seq_len, heads, dim = q.shape + kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype) + token_indices = prepare_token_indices(offsets) + topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32) + topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32) + kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices) + return topk_indices, topk_score + + +def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor: + all_topk_indices = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + assert (offsets[i + 1] - offsets[i]).item() >= topk + q = Q[offsets[i] : offsets[i + 1]] + weights = Weights[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + softmax_scale = q.shape[-1] ** -0.5 + s = q.shape[0] + mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) + logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") + logits = F.relu(logits) + logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale + logits = torch.where(mask, logits, float("-inf")) + topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) + topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) + all_topk_indices.append(topk_indices) + all_topk_score.append(topk_score) + topk_indices = torch.cat(all_topk_indices, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return topk_indices, topk_score + + +def test_kernel( + B=1, + S=2048, + H=64, + D=128, + topk=64, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D)).cuda().bfloat16() + weights = torch.randn((S, H)).cuda().bfloat16() + k = torch.randn((S, D)).cuda().bfloat16() + offsets = torch.tensor([0, S], dtype=torch.int32).cuda() + + ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets) + + topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets) + + for j in range(S): + ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() + trt_np = topk_indices[j].cpu().to(torch.int32).numpy() + + ref_np_val = ref_topk_score[j] + trt_np_val = topk_score[j] + + mask = (ref_np_val > 0).cpu().numpy() + + set_ref = set(ref_np[mask]) + set_trt = set(trt_np[mask]) + intersection = set_ref & set_trt + + print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) + + print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8b76dbca1c5fa483f57399e701a17bff870edd80 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -0,0 +1,354 @@ +# ruff: noqa +import tilelang +from tilelang import language as T +import torch +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + S = T.symbolic("S") + + shape = [S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o) + T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + S_kv = T.symbolic("S_kv") + + dkv_shape = [S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): + T.copy( + dKV[bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def bwd( + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + + B_plus_one = T.symbolic("B_plus_one") + S = T.symbolic("S") + + H_kv = H // kv_group + q_shape = [S, H, D + D_tail] + k_shape = [S, kv_group, D + D_tail] + o_shape = [S, H, D] + indices_shape = [S, kv_group, topk] + delta_shape = [S, H] + lse_shape = [S, H] + offsets_shape = [B_plus_one] + token_indices_shape = [S, 2] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + Offsets: T.Tensor(offsets_shape, indices_dtype), + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): + Q_shared = T.alloc_shared([padded_H, D], dtype) + Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([padded_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dP_shared_cast = T.alloc_shared([padded_H, BS], dtype) + dQ_shared = T.alloc_shared([padded_H, D], dtype) + dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype) + + acc_p = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([padded_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) + acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + + max_kv_i = s_i + + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared) + T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared) + T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + T.annotate_layout( + { + dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), + dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), + } + ) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1) + + # Compute attention scores + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i] + + T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i] + T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i]) + + T.copy(acc_p, P_shared_cast) + + T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + + for h_i, bi_i in T.Parallel(padded_H, BS): + acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) + T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + T.clear(acc_dkv_tail) + T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D]) + T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None): + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + S, H, dim_plus_tail_dim = q.shape + S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert S == S_kv + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (S, kv_group, topk) + assert lse.shape == (S, H) + + token_indices = prepare_token_indices(offsets) + + # Get kernels + preprocess_kernel = preprocess(H, D) + bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv) + dkv = postprocess_kernel(dkv) + + return dq, dkv + + +def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True): + from sparse_mla_fwd import ref_sparse_mla_fwd_interface + + q = q.detach().clone() + kv = kv.detach().clone() + q.requires_grad = True + kv.requires_grad = True + o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_casual) + o.backward(do) + return q.grad, kv.grad + + +def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True): + # Prepare data + q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True) + do = torch.randn((S, H, DV), dtype=dtype, device="cuda") + offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + # Forward + from sparse_mla_fwd import sparse_mla_fwd_interface + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) + + tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets) + + if check_correctness: + assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq") + assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") + print("assert_tensors_similar passed") + + per_token_flop = 2 * sum( + [ + H * DV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DQKV * topk, + H * DV * topk, + ] + ) + from tilelang.profiler import do_bench + + def fn(): + return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) + + ms = do_bench(fn, rep=100, warmup=250) + print(f"Average time: {ms:.3f} ms") + print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) + print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d87523695240ce3029c29e84c10c50cbfc4a39c8 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -0,0 +1,310 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from index import prepare_token_indices + +from utils import assert_tensors_similar + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + else: + sm_scale = sm_scale + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + + head_kv = heads // kv_group + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len, kv_group, dim + tail_dim] + o_shape = [seq_len, heads, dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[bos + s_i, H0:H1, :]) + T.copy(sumexp, Lse[bos + s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128 +): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + seq_len, heads, dim_plus_tail_dim = q.shape + seq_len_kv, kv_group, _ = kv.shape + assert seq_len == seq_len_kv + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + _, _, topk = indices.shape + assert indices.shape == (seq_len, kv_group, topk) + + token_indices = prepare_token_indices(offsets) + + kernel = sparse_mla_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads + ) + out, lse = kernel(q, kv, indices, offsets, token_indices) + return out, lse + + +def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True): + Q = Q.float() + KV = KV.float() + all_o = [] + for i in range(offsets.shape[0] - 1): + q = Q[None, offsets[i] : offsets[i + 1]] + kv = KV[None, offsets[i] : offsets[i + 1]] + indices = Indices[None, offsets[i] : offsets[i + 1]].clone() + + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda" + ).view(1, -1) + + indices[indices > sk] = sk + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + all_o.append(o.squeeze(0)) + o = torch.cat(all_o, dim=0) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, +): + torch.random.manual_seed(0) + q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda") + + indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda") + for i in range(offsets.shape[0] - 1): + seq_len = (offsets[i + 1] - offsets[i]).item() + assert seq_len >= topk + for t in range(seq_len): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[offsets[i] + t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + if check_correctness: + # otherwise may cause out of memory + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets) + assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out") + print("assert_tensors_similar passed") + + def fn(): + return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_mla_fwd( + B=1, + S=4096, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=1024, + dtype=torch.bfloat16, + check_correctness=True, + block_I=64, + num_stages=2, + threads=256, + ) diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py new file mode 100644 index 0000000000000000000000000000000000000000..a03bc74f51e254b8cd9232eebc91bc9c6f0fa4c9 --- /dev/null +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -0,0 +1,226 @@ +# ruff: noqa +import torch +import torch.nn as nn +import torch.nn.functional as F +import tilelang +from tilelang import language as T +from einops import repeat, rearrange, einsum +from index import prepare_token_indices +from utils import get_abs_err, get_err_ratio + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tilelang.jit(pass_configs=pass_configs) +def tl_sparse_mla_topk_reducesum_impl( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + block_I=32, + num_stages=2, + threads=128, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 + + batch_plus_one = T.symbolic("batch_plus_one") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + head_kv = heads // kv_group + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + q_shape = [seq_len, heads, dim + tail_dim] + kv_shape = [seq_len_kv, kv_group, dim + tail_dim] + indices_shape = [seq_len, kv_group, topk] + lse_shape = [seq_len, heads] + reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk] + offsets_shape = [batch_plus_one] + token_indices_shape = [seq_len, 2] + + @T.prim_func + def tl_sparse_mla_topk_reducesum_kernel( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore + TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore + ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as ( + bx, + by, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + reducesum = T.alloc_fragment([BI], accum_dtype) + lse = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(lse, 0) + + b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] + bos, eos = Offsets[b_i], Offsets[b_i + 1] + r_i = bx % REPLICATE_H + g_i = by + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[bos + s_i, H0:H1, :D], Q_shared) + T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) + T.copy(Lse[bos + s_i, H0:H1], lse) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1) + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) + T.reduce_sum(acc_s, reducesum, dim=0) + T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI]) + + return tl_sparse_mla_topk_reducesum_kernel + + +def sparse_mla_topk_reducesum_interface( + q: torch.Tensor, + kv: torch.Tensor, + topk_indices: torch.Tensor, + lse: torch.Tensor, + offsets: torch.Tensor, + dim_v: int, +): + assert kv.shape[-2] == 1 + seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1] + REPLICATE_H = max(heads // 64, 1) + tail_dim = dim_plus_tail_dim - dim_v + token_indices = prepare_token_indices(offsets) + + reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device) + kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk) + kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum) + reducesum = reducesum.sum(dim=-2) # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk] + attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True) + + return attn_score + + +def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor): + # q: [batch, seq_len, heads, dim] + # k: [batch, seq_len, dim] + sm_scale = Q.shape[-1] ** -0.5 + all_lse = [] + all_topk_score = [] + for i in range(offsets.shape[0] - 1): + q = Q[offsets[i] : offsets[i + 1]] + k = K[offsets[i] : offsets[i + 1]] + topk_indices = TopkIndices[offsets[i] : offsets[i + 1]] + seq_len = q.shape[0] + mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() + logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale + logits = torch.where(mask, logits, float("-inf")) + score = F.softmax(logits, dim=-1, dtype=torch.float32) + score_sum = score.sum(dim=-2) + topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) + topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) + max_logits = logits.amax(dim=-1).to(torch.float32) + lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits + all_lse.append(lse) + all_topk_score.append(topk_score) + lse = torch.cat(all_lse, dim=0) + topk_score = torch.cat(all_topk_score, dim=0) + return lse, topk_score + + +def test_kernel( + B=1, + S=2048, + H=16, + D=512, + tail_D=64, + topk=128, +): + torch.manual_seed(42) + + q = torch.randn((S, H, D + tail_D)).cuda().bfloat16() + kv = torch.randn((S, D + tail_D)).cuda().bfloat16() + offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() + + topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous() + + lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) + + kv = kv.unsqueeze(-2) + topk_indices = topk_indices.unsqueeze(-2) + + attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) + print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}") + + +if __name__ == "__main__": + test_kernel() diff --git a/examples/dsa_sparse_finetune/utils.py b/examples/dsa_sparse_finetune/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96afd064dc0f83f0e813fa4093f10d2fd309dfce --- /dev/null +++ b/examples/dsa_sparse_finetune/utils.py @@ -0,0 +1,73 @@ +import torch + + +def get_abs_err(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + return (x - y).flatten().abs().max().item() + + +def get_err_ratio(y, x): + x = x.to(torch.float32) + y = y.to(torch.float32) + err = (x - y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / base + + +def calculate_tensor_similarity(x, y, name="tensor"): + """ + Calculate similarity between two tensors using a normalized dot product metric. + + Unlike torch.testing.assert_close which uses absolute/relative tolerance based on + element-wise differences, this function computes a global similarity score: + sim = 2 * / (||x||^2 + ||y||^2) + + This metric is scale-invariant and measures the cosine-like similarity normalized + by the magnitude of both tensors. It returns 1 for identical tensors and values + closer to 0 for dissimilar ones. This is particularly useful for comparing tensors + with varying magnitudes where relative errors matter more than absolute differences. + + Args: + x: First tensor to compare + y: Second tensor to compare + name: Name of the tensor for logging purposes + + Returns: + Similarity score in range [0, 1] where 1 means identical + """ + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print(f"\033[33mWARNING: {name} all zero\033[0m") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + """ + Assert that two tensors are similar using a global similarity metric. + + Key differences from torch.testing.assert_close: + - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking + that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers + and requires all elements to satisfy the tolerance. + - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the + normalized dot product. It's more robust to outliers and focuses on overall + tensor similarity rather than element-wise precision. This is better suited for + comparing large tensors where a few outlier elements shouldn't fail the test. + + Args: + x: First tensor to compare + y: Second tensor to compare + eps: Maximum allowed difference (1 - similarity), default 1e-8 + name: Name of the tensor for error messages + raise_assert: Whether to raise assertion error on failure + """ + sim = calculate_tensor_similarity(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m") + if raise_assert: + assert False # noqa: B011 diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py new file mode 100644 index 0000000000000000000000000000000000000000..f075c64fd669ff4ce15ae518b3e17998aad8edae --- /dev/null +++ b/examples/elementwise/example_elementwise_add.py @@ -0,0 +1,62 @@ +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T + + +def ref_program(x, y): + return x + y + + +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[-1]) +def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): + @T.prim_func + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), in_dtype) + B_shared = T.alloc_shared((block_M, block_N), in_dtype) + C_local = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(B[by * block_M, bx * block_N], B_shared) + for local_y, local_x in T.Parallel(block_M, block_N): + C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return elem_add + + +def main(M=1024, N=1024, use_autotune=False): + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if use_autotune: + kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32) + else: + # Default config + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) + + out = kernel(a, b) + torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + main(args.m, args.n, args.use_autotune) diff --git a/examples/elementwise/test_example_elementwise.py b/examples/elementwise/test_example_elementwise.py new file mode 100644 index 0000000000000000000000000000000000000000..24f675cd6a3778280ce1a52c1b6e6ca54aa8393c --- /dev/null +++ b/examples/elementwise/test_example_elementwise.py @@ -0,0 +1,14 @@ +import tilelang.testing +import example_elementwise_add + + +def test_example_elementwise_add(): + example_elementwise_add.main() + + +def test_example_elementwise_add_autotune(): + example_elementwise_add.main(use_autotune=True) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/flash_attention/README.md b/examples/flash_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..633727ec4e9270b66176db82d3e13f430895c33a --- /dev/null +++ b/examples/flash_attention/README.md @@ -0,0 +1,111 @@ +# FlashAttention + +Using tile-lang, we can define buffers at different memory layers. For instance, `Q_shared`, `K_shared`, and `V_shared` can be defined in shared memory, while `acc_s` and `acc_o` can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way. + +```python +@T.prim_func +def flash_attention( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), +): + # Launch a specialized T.Kernel with 3D mapping: (bx, by, bz) + # bx: block index in sequence dimension + # by: block index in "heads" dimension + # bz: block index in "batch" dimension + # threads=thread_num means how many threads per block + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): + # Allocate shared memory for Q, K, V to reduce global memory accesses + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + # Allocate buffers on register + # acc_s: buffer to hold intermediate attention scores + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + # acc_s_cast: buffer for storing casted/adjusted scores + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + # acc_o: partial accumulation of output + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + # Buffers to track per-row maximum score and related stats + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + # Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + + # Copy a block of Q from global memory to Q_shared + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + + # Initialize accumulators + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = ( + T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + ) + + # Pipeline the loop to overlap copies/gemm stages + for k in T.Pipelined(loop_range, num_stages=num_stages): + # Copy K block into shared memory + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype) + ) + else: + T.clear(acc_s) + + # Perform the Q*K^T multiplication, Here, transpose_B=True indicates that K_shared is transposed, + # policy=T.GemmWarpPolicy.FullRow means each warp is responsible for computing an entire row + # of acc_s, and the resulting acc_s is retained in registers. + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Copy V block into shared memory + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + for i, j in T.Parallel(block_M, dim): + acc_s[i, j] *= scale + + # Save old scores_max, then reset scores_max + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # Compute the maximum value per row on dimension 1 (block_N) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Compute the factor by which we need to rescale previous partial sums + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + + # Rescale the partial output accumulation to keep exponents consistent + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + # Exponentiate (scores - max) for the new block + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + + # Make a cast of acc_s to fp16 for the next GEMM + T.copy(acc_s, acc_s_cast) + + # Multiply the attention acc_s_cast by V and add to partial output (acc_o) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + # Update the "logsum" tracker with the newly accumulated sum + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + # Final step: divide each partial output by logsum (completing the softmax) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + + # Write back the final output block from acc_o to the Output buffer + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) +``` \ No newline at end of file diff --git a/examples/flash_attention/bert_padding.py b/examples/flash_attention/bert_padding.py new file mode 100644 index 0000000000000000000000000000000000000000..15c4097ce77a21ebcd2060b53c629e7a89972b88 --- /dev/null +++ b/examples/flash_attention/bert_padding.py @@ -0,0 +1,205 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py +# ruff: noqa +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. + # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): + """ + Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). + The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). + + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + length = attention_mask_in_length.sum(dim=-1) + seqlen = attention_mask_in_length.size(-1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) + real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() + seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] + indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..89c1166693c672a8fd0021419837c950a0651df9 --- /dev/null +++ b/examples/flash_attention/example_gqa_bwd.py @@ -0,0 +1,514 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_qk] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel + dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + + T.copy(dv, dv_shared) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + T.copy(dk, dk_shared) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + ctx.use_atomic = use_atomic + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD_QK = q.shape + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] + groups = H // HEAD_KV + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK) + delta = mod_prep(o, do) + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + else: + kernel = flashattn_bwd_split( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel + shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) + dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk, dv = dk.sum(0), dv.sum(0) + + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups, use_atomic) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") + args = parser.parse_args() + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..07586f99fdd7d55d52a59359337c423dcca96a6f --- /dev/null +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -0,0 +1,535 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.contrib import nvcc +import argparse + +tilelang.disable_cache() + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops + # We should set it to negative large number instead + T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # bshd -> bhld to use tma reduction instruction + return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d]) + + +@tilelang.jit( + out_idx=[3, 4, 5], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :]) + with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :]) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.copy(dq, dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True) + + return flash_bwd + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel + dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + + T.copy(dv, dv_shared) + T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + T.copy(dk, dk_shared) + T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :]) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + ctx.use_atomic = use_atomic + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD_QK = q.shape + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] + groups = H // HEAD_KV + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V) + delta = mod_prep(o, do) + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq, dk, dv = mod_post(dq, dk, dv) + else: + kernel = flashattn_bwd_split_novarlen( + BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups + ) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel + shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) + dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) + dk, dv = dk.sum(0), dv.sum(0) + + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups, use_atomic) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + arch = nvcc.get_target_compute_version() + print(f"Detected GPU compute capability: {arch}") + assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") + args = parser.parse_args() + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..cc88b64da7a44ffc9b95f09bb8a1cd45eb681136 --- /dev/null +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -0,0 +1,730 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.contrib import nvcc +import argparse +from einops import rearrange, repeat +from bert_padding import pad_input, unpad_input + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +@tilelang.jit( + out_idx=[5, 6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + o_shape = [total_q, heads, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + k_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + + for i, d in T.Parallel(block_M, dim_qk): + if bx * block_M + i < q_current_seqlen: + Q_shared[i, d] = Q[q_start_idx + bx * block_M + i, by, d] + else: + Q_shared[i, d] = 0.0 + + T.fill(acc_o, 0.0) + T.fill(logsum, 0.0) + # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops + # We should set it to negative large number instead + T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + loop_range = T.ceildiv(k_current_seqlen, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + for i, d in T.Parallel(block_N, dim_qk): + if k * block_N + i < k_current_seqlen: + K_shared[i, d] = K[k_start_idx + k * block_N + i, by // groups, d] + else: + K_shared[i, d] = 0.0 + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), + 0, + T.Cast(accum_dtype, -1e30), + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30) + ) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, d in T.Parallel(block_N, dim_v): + if k * block_N + i < k_current_seqlen: + V_shared[i, d] = V[k_start_idx + k * block_N + i, by // groups, d] + else: + V_shared[i, d] = 0.0 + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + + for i, d in T.Parallel(block_M, dim_v): + if bx * block_M + i < q_current_seqlen: + Output[q_start_idx + bx * block_M + i, by, d] = acc_o[i, d] + + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + if bx * block_M + i < q_current_seqlen: + lse[bz, by, bx * block_M + i] = logsum[i] + + return flash_fwd + + +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [total_q, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + q_end_idx = cu_seqlens_q[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + for i, j in T.Parallel(blk, blk): + if by * blk + i < q_current_seqlen and k * blk + j < dim_v: + o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j] + do[i, j] = dO[q_start_idx + by * blk + i, bx, k * blk + j] + else: + o[i, j] = 0.0 + do[i, j] = 0.0 + + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + + for i in T.Parallel(blk): + if by * blk + i < q_current_seqlen: + Delta[bz, bx, by * blk + i] = delta[i] + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # bshd -> bhsd to use tma reduction instruction + return T.Layout(dQ.shape, lambda l, h, d: [h, l, d]) + + +@tilelang.jit( + out_idx=[3, 4, 5], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(q_shape, dtype), # type: ignore + dK_out: T.Tensor(k_shape, dtype), # type: ignore + dV_out: T.Tensor(v_shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :]) + with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): + T.annotate_layout( + { + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + } + ) + T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :]) + T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :]) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_atomic_add( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + do_shape = [total_q, heads, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) + + q_start_idx = cu_seqlens_q[bz] + k_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + dK: make_dq_layout(dK), + dV: make_dq_layout(dV), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + } + ) + + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) + + T.clear(dv) + T.clear(dk) + + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 + loop_ed = T.ceildiv(q_current_seqlen, block_N) + + for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) + + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) + + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) + T.clear(dsT) + # dsT: (block_kv, block_q) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.copy(dq, dq_shared) + T.atomic_add( + dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :], + dq_shared, + memory_order="relaxed", + use_tma=True, + ) + + T.copy(dv, dv_shared) + T.atomic_add( + dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], + dv_shared, + memory_order="relaxed", + use_tma=True, + ) + T.copy(dk, dk_shared) + T.atomic_add( + dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :], + dk_shared, + memory_order="relaxed", + use_tma=True, + ) + + return flash_bwd + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd_split( + batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1 +): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [total_q, heads, dim_qk] + k_shape = [total_kv, head_kv, dim_qk] + v_shape = [total_kv, head_kv, dim_v] + do_shape = [total_q, heads, dim_v] + dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel + dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor(do_shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(dk_shape, dtype), # type: ignore + dV: T.Tensor(dv_shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], dtype) + + q_start_idx = cu_seqlens_q[bz] + k_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + + T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) + + T.clear(dv) + T.clear(dk) + loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0 + loop_ed = T.ceildiv(q_current_seqlen, block_N) + + for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + # Note: The padding zero of varlen should be considered in T.copy + T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q) + + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do) + + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k_base * block_N + j) + and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0 + ) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + if k_base * block_N + i < q_current_seqlen: + T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed") + + T.copy(dv, dv_shared) + T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) + T.copy(dk, dk_shared) + T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :]) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward( + ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True + ): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) + + total_q = q_unpad.shape[0] + total_kv = k_unpad.shape[0] + + mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) + o = pad_input(o_unpad, indices_q, BATCH, N_CTX) + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) + ctx.batch = BATCH + ctx.causal = causal + ctx.use_atomic = use_atomic + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.indices_q = indices_q + ctx.indices_k = indices_k + return o + + @staticmethod + def backward(ctx, do): + N_CTX = do.shape[1] + q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + # lse_clone = lse.clone() + do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) + total_q, H, D_HEAD_QK = q.shape + total_kv, HEAD_KV, D_HEAD_V = v.shape + groups = H // HEAD_KV + BATCH = len(cu_seqlens_q) - 1 + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V) + mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) + delta = mod_prep(o, do, cu_seqlens_q) + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add( + BATCH, + total_q, + total_kv, + N_CTX, + H, + ctx.max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups, + ) + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.zeros_like(k, dtype=torch.float32) + dv = torch.zeros_like(v, dtype=torch.float32) + kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq, dk, dv = mod_post(dq, dk, dv) + else: + kernel = flashattn_bwd_split( + BATCH, + total_q, + total_kv, + N_CTX, + H, + ctx.max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups, + ) + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) + dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32)) + dk, dv = dk.sum(0), dv.sum(0) + + dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) + dk = pad_input(dk, ctx.indices_k, BATCH, N_CTX) + dv = pad_input(dv, ctx.indices_k, BATCH, N_CTX) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, padding_mask, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + # To handle precision issue + Q, K, V = Q.float(), K.float(), V.float() + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if padding_mask is not None: + scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + if padding_mask is not None: + output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) + return output + + +def main( + BATCH: int = 1, + H: int = 32, + N_CTX: int = 256, + D_HEAD_QK: int = 192, + D_HEAD_V: int = 128, + groups: int = 16, + causal: bool = False, + use_atomic: bool = True, +): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") + seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) + cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) + max_seqlen_q = seqlens_q.max().item() + + # In training backward pass, seqlens_k should be the same as seqlens_q + seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q + + O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, padding_mask, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + print( + "Note: this varlen kernel performance is as good as the non-varlen kernel shown in Nsight-Compute. As you may observe that the TFLOPS is a bit lower, that's because the unpad operation is included in the above benchmark." + ) + + +if __name__ == "__main__": + arch = nvcc.get_target_compute_version() + print(f"Detected GPU compute capability: {arch}") + assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV") + parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV") + args = parser.parse_args() + # Can be set to True/False for testing + args.causal = True + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e2de27752cee99c77fab091b599a4de5e65928 --- /dev/null +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -0,0 +1,353 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim_qk], dtype) + K_shared = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim_v): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim_v] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim_v, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk) ** 0.5 + scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.wait_wgmma(1) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.wait_wgmma(0) + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) + T.wait_wgmma(0) + T.copy(dq, dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@torch.compile +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): + BATCH, N_CTX, H, D_HEAD_QK = q.shape + D_HEAD_V = v.shape[-1] + block_M = 128 + block_N = 64 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + ctx.use_atomic = use_atomic + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD_QK = q.shape + ( + HEAD_KV, + D_HEAD_V, + ) = v.shape[-2], v.shape[-1] + groups = H // HEAD_KV + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + delta = mod_prep(o, do) + + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = dq.to(torch.float16) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + + return dq, dk, dv, None, None, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D_QK] + # K: [B, T, HK, D_QK] + # V: [B, T, HV, D_V] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim_qk = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False): + flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK + flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V + total_flops = 3 * flops_per_qk + 2 * flops_per_v + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + + head_kv = H // groups + K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_() + V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_() + O = attention(Q, K, V, causal, groups) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal, groups) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K") + parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V") + parser.add_argument("--causal", action="store_true", help="Causal flag") + parser.add_argument("--groups", type=int, default=16, help="groups") + args = parser.parse_args() + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py new file mode 100644 index 0000000000000000000000000000000000000000..5005435eaf7cb2349f09d8900a4718e9f84f52e6 --- /dev/null +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -0,0 +1,256 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +class FlashAttentionTuneSpace: + def __init__( + self, + block_sizes=(64, 128, 256), + thread_options=(128, 256, 512), + num_stages_range=(2, 3), + max_shared_mem=100 * 1024, + warp_alignment=16, + dim=128, + dtype_bytes=2, + ): + self.block_sizes = block_sizes + self.thread_options = thread_options + self.num_stages_range = num_stages_range + self.max_shared_mem = max_shared_mem + self.warp_alignment = warp_alignment + self.dim = dim + self.dtype_bytes = dtype_bytes + + +def get_configs(user_config=None): + config = user_config or FlashAttentionTuneSpace() + valid_configs = [] + + for block_M, block_N in itertools.product(config.block_sizes, repeat=2): + for threads in config.thread_options: + assert threads % 32 == 0 + warp_count = threads // 32 + warp_M = block_M // warp_count + warp_N = block_N // warp_count + + if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0: + continue + + shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) + if shared_mem > config.max_shared_mem: + continue + + for num_stages in config.num_stages_range: + valid_configs.append( + { + "block_M": block_M, + "block_N": block_N, + "num_stages": num_stages, + "threads": threads, + } + ) + return valid_configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D] + # K: [B, T, HK, D] + # V: [B, T, HV, D] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7a71b1780ba0ea370d01f5f72379e59491f76e --- /dev/null +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -0,0 +1,243 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict( + block_M=[128], + block_N=[128], + num_stages=[2], + threads=[256], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_configs(), + warmup=10, + rep=10, +) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=1, + block_M=64, + block_N=64, + num_stages=0, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D] + # K: [B, T, HK, D] + # V: [B, T, HV, D] + # HQ = HKV * groups + assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + parser.add_argument("--groups", type=int, default=16, help="groups") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..b02345d93084843074d0924a4e945424bf104ca7 --- /dev/null +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -0,0 +1,253 @@ +# ruff: noqa +import argparse +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from einops import rearrange, repeat +from tilelang.profiler import do_bench +from varlen_utils import generate_random_padding_mask, generate_qkv + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), + upcast=True, +): + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + b, T, Hq, D = q.shape + S = k.shape[1] + scale = (1.0 / D) ** 0.5 + k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) + scores = torch.einsum("bthd,bshd->bhts", q, k) + left, right = window_size + left = S if left is None or left < 0 else int(left) + right = S if right is None or right < 0 else int(right) + t_idx = torch.arange(T, device=scores.device)[:, None] + s_idx = torch.arange(S, device=scores.device)[None, :] + visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right)) + visible_mask = visible_ts.unsqueeze(0).unsqueeze(0) + if key_padding_mask is not None: + k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s") + visible_mask = visible_mask & k_keep + neg_inf = torch.finfo(scores.dtype).min + scores = scores * scale + scores = scores.masked_fill(~visible_mask, neg_inf) + attention = torch.softmax(scores, dim=-1).to(v.dtype) + if query_padding_mask is not None: + q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1") + attention = attention.masked_fill(~q_keep, 0.0) + output = torch.einsum("bhts,bshd->bthd", attention, v) + if query_padding_mask is not None: + output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) + if is_causal + else T.ceildiv(kv_current_seqlen, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i < k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), + -1e9, + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0 + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + return main + + +def main( + batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False +): + assert heads % groups == 0, "heads must be divisible by groups" + + flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim + total_flops = 2 * flops_per_matmul + + tilelang.testing.set_random_seed(0) + + if is_causal: + total_flops *= 0.5 + + tilelang.testing.set_random_seed(0) + + dtype = torch.float16 + device = torch.device("cuda") + + head_kv = heads // groups + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + UQ = q_unpad.shape[0] + UKV = k_unpad.shape[0] + + kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + out = output_pad_fn(out_unpad) + + out_ref, _ = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + causal=is_causal, + ) + torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") + args = parser.parse_args() + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal) diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..835a315965db00752f4096a60a2de4a4db10bf68 --- /dev/null +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -0,0 +1,363 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + # Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + # T.copy(Q_shared, Q_local) + # for i, j in T.Parallel(block_M, dim): + # Q_local[i, j] *= scale + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, heads, seq_len, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, by, bx * blk : (bx + 1) * blk, :], + dQ_out[bz, by, bx * blk : (bx + 1) * blk, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + } + ) + T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared) + T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :]) + T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :]) + + return flash_bwd + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal): + BATCH, H, N_CTX, D_HEAD = q.shape + block_M = 64 + block_N = 64 if D_HEAD <= 128 else 32 + o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, H, N_CTX, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + shape = [BATCH, H, N_CTX, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + return dq, dk, dv, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(2) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + BATCH: int = 8, + H: int = 32, + N_CTX: int = 1024, + D_HEAD: int = 64, + causal: bool = False, +): + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 5 * flops_per_matmul + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() + K = torch.empty_like(Q).normal_().requires_grad_() + V = torch.empty_like(Q).normal_().requires_grad_() + dO = torch.randn_like(Q) + O = attention(Q, K, V, causal) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + + print("All checks passed.โœ…") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py new file mode 100644 index 0000000000000000000000000000000000000000..c0620bde0e95480d907fe94d9909ee3c30348860 --- /dev/null +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -0,0 +1,354 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + # Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment + return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dQ_out: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) + + return flash_bwd + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal): + BATCH, N_CTX, H, D_HEAD = q.shape + block_M = 64 + block_N = 64 if D_HEAD <= 128 else 32 + o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) + delta = kernel_prep(o, do) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + return dq, dk, dv, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 8, + H: int = 32, + N_CTX: int = 1024, + D_HEAD: int = 64, + causal: bool = False, +): + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 5 * flops_per_matmul + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() + K = torch.empty_like(Q).normal_().requires_grad_() + V = torch.empty_like(Q).normal_().requires_grad_() + dO = torch.randn_like(Q) + O = attention(Q, K, V, causal) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + from tilelang.profiler import do_bench + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..34a8d69ce475f7126e2635ae91f9a71159aa0d91 --- /dev/null +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -0,0 +1,331 @@ +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse + + +@tilelang.jit( + out_idx=[3, 4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_fwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + Output: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return flash_fwd + + +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.float16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return flash_bwd_prep + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, dtype), # type: ignore + dV: T.Tensor(shape, dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + # should not store K to local if dim is large + # K_local = T.alloc_fragment([block_M, dim], dtype) + # K_local_T = T.alloc_fragment([block_M, dim], dtype) + # V_local = T.alloc_fragment([block_M, dim], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + dq_shared = T.alloc_shared([block_N, dim], accum_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + } + ) + + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.wait_wgmma(1) + + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + # We don't need to handle OOB positions for non-causal cases, + # since OOB values won't affect other positions here. + T.wait_wgmma(0) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) + T.wait_wgmma(0) + T.copy(dq, dq_shared) + T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) + + return flash_bwd + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal): + BATCH, N_CTX, H, D_HEAD = q.shape + block_M = 64 + block_N = 64 if D_HEAD <= 128 else 32 + mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + o, lse = mod(q, k, v) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + BATCH, N_CTX, H, D_HEAD = q.shape + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] + block_M = 128 + block_N = 128 if D_HEAD <= 64 else 32 + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + delta = mod_prep(o, do) + mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + mod(q, k, v, do, lse, delta, dq, dk, dv) + dq = dq.to(torch.float16) + return dq, dk, dv, None + + +attention = _attention.apply + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + BATCH: int = 8, + H: int = 32, + N_CTX: int = 1024, + D_HEAD: int = 64, + causal: bool = False, +): + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 5 * flops_per_matmul + if causal: + total_flops *= 0.5 + Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() + K = torch.empty_like(Q).normal_().requires_grad_() + V = torch.empty_like(Q).normal_().requires_grad_() + dO = torch.randn_like(Q) + O = attention(Q, K, V, causal) + O.backward(dO, retain_graph=True) + dQ, Q.grad = Q.grad.clone(), None + dK, K.grad = K.grad.clone(), None + dV, V.grad = V.grad.clone(), None + + O_ref = ref_program(Q, K, V, causal) + O_ref.backward(dO, retain_graph=True) + dQ_ref, Q.grad = Q.grad.clone(), None + dK_ref, K.grad = K.grad.clone(), None + dV_ref, V.grad = V.grad.clone(), None + + assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print("All checks passed.โœ…") + + def run(): + O_ref.backward(dO, retain_graph=True) + + def run1(): + O.backward(dO, retain_graph=True) + + latency = do_bench(run, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(run1, warmup=500) + print("tilelang: {:.2f} ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="Batch size") + parser.add_argument("--h", type=int, default=32, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") + parser.add_argument("--d_head", type=int, default=64, help="Head dimension") + parser.add_argument("--causal", type=bool, default=False, help="Causal flag") + args = parser.parse_args() + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..e70d17bf8c9adc9b12d90111e3fe8906cccc5ba0 --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -0,0 +1,220 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=1, help="heads") + parser.add_argument("--seq_q", type=int, default=256, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal", default=False) + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c4d81ece8607c4499798d58421c067dc60b518 --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -0,0 +1,224 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length") + parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py new file mode 100644 index 0000000000000000000000000000000000000000..248073f797d7b41d6b223f6de25c02f5b486de5b --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -0,0 +1,205 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2aab44f496f7ae422d17f5d2a1baf0c057116e --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -0,0 +1,211 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + # We shall fill -inf for OOB positions + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]], + ): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + ref_program_processed = partial(ref_program, is_causal=is_causal) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--is_causal", action="store_true", help="causal") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba2e8ab472d371cedcf1fe2d7203e25c85ec099 --- /dev/null +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -0,0 +1,288 @@ +# ruff: noqa +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +import argparse + +import torch +from einops import rearrange, repeat +from varlen_utils import generate_random_padding_mask, generate_qkv + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + dim = q.shape[-1] + scale = (1.0 / dim) ** 0.5 # log2(e) + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + scores = torch.einsum("bthd,bshd->bhts", q, k) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) + scores = scores * scale + attention = torch.softmax(scores, dim=-1).to(v.dtype) + + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + output = torch.einsum("bhts,bshd->bthd", attention, v) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +@tilelang.jit( + out_idx=[6], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [UQ, heads, dim] + k_shape = [UKV, heads, dim] + v_shape = [UKV, heads, dim] + o_shape = [UQ, heads, dim] + + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(k_shape, dtype), + V_unpad: T.Tensor(v_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") + K_shared = T.alloc_shared([block_N, dim], dtype, "shared") + V_shared = T.alloc_shared([block_N, dim], dtype, "shared") + O_shared = T.alloc_shared([block_M, dim], dtype, "shared") + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + batch_idx = bz + head_idx = by + + q_start_idx = cu_seqlens_q[batch_idx] + k_start_idx = cu_seqlens_k[batch_idx] + v_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + v_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + k_current_seqlen = k_end_idx - k_start_idx + v_current_seqlen = v_end_idx - v_start_idx + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d] + else: + Q_shared[i, d] = 0 + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(k_current_seqlen, block_N) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + # Q * K + for i, d in T.Parallel(block_N, dim): + if k * block_N + i < k_current_seqlen: + K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d] + else: + K_shared[i, d] = 0 + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= k * block_N + j) + and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0 + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + # Rescale + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + # V * softmax(Q * K) + for i, d in T.grid(block_N, dim): + if k * block_N + i < v_current_seqlen: + V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d] + else: + V_shared[i, d] = 0 + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + return main + + +def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + + tilelang.testing.set_random_seed(0) + + causal = False + if causal: + total_flops *= 0.5 + + dtype = torch.float16 + device = torch.device("cuda") + window_size = (-1, -1) + + q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + + query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + UQ = q_unpad.shape[0] # unpadded query length + UK = k_unpad.shape[0] # unpadded key length + UKV = k_unpad.shape[0] # unpadded query key length + + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + out = output_pad_fn(out_unpad) + + out_ref, _ = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + causal=causal, + ) + torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) + + import flash_attn + + fla_out_unpad = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=causal, + ) + fla_out = output_pad_fn(fla_out_unpad) + torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2) + + print("All checks passed.โœ…") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="heads") + parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + + args = parser.parse_args() + main(args.batch, args.heads, args.seq_len, args.dim) diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..da172bb62a4dee0ada293fc1249812905ded8de1 --- /dev/null +++ b/examples/flash_attention/test_example_flash_attention.py @@ -0,0 +1,101 @@ +import tilelang.testing + +import example_gqa_bwd +import example_gqa_bwd_wgmma_pipelined +import example_mha_bwd_bshd +import example_mha_bwd_bhsd +import example_mha_fwd_bhsd_wgmma_pipelined +import example_gqa_fwd_bshd +import example_mha_fwd_bshd +import example_gqa_fwd_bshd_wgmma_pipelined +import example_mha_fwd_bshd_wgmma_pipelined +import example_mha_fwd_varlen +import example_mha_bwd_bshd_wgmma_pipelined +import example_mha_fwd_bhsd +import example_gqa_bwd_tma_reduce_varlen + + +@tilelang.testing.requires_cuda +def test_example_gqa_bwd_tma_reduce_varlen(): + example_gqa_bwd_tma_reduce_varlen.main() + + +@tilelang.testing.requires_cuda +def test_example_gqa_bwd(): + example_gqa_bwd.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_bwd_wgmma_pipelined(): + example_gqa_bwd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_bwd(): + example_mha_bwd_bshd.main( + BATCH=1, + H=16, + N_CTX=512, + D_HEAD=64, + causal=False, + ) + + +@tilelang.testing.requires_cuda +def test_example_mha_bwd_bhsd(): + example_mha_bwd_bhsd.main( + BATCH=1, + H=16, + N_CTX=512, + D_HEAD=64, + causal=False, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_bwd_wgmma_pipelined(): + example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_gqa_fwd_bshd_wgmma_pipelined(): + example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + + +@tilelang.testing.requires_cuda +def test_example_gqa_fwd_bshd(): + example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_fwd_bhsd_wgmma_pipelined(): + example_mha_fwd_bhsd_wgmma_pipelined.main() + + +@tilelang.testing.requires_cuda +def test_example_mha_fwd_bhsd(): + example_mha_fwd_bhsd.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_mha_fwd_bshd_wgmma_pipelined(): + example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256) + + +@tilelang.testing.requires_cuda +def test_example_mha_fwd_bshd(): + example_mha_fwd_bshd.main(batch=1, seq_len=256) + + +@tilelang.testing.requires_cuda +def test_example_mha_fwd_varlen(): + example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/flash_attention/varlen_utils.py b/examples/flash_attention/varlen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43e21cc3b80ce72eaa582407024ec2c42015731e --- /dev/null +++ b/examples/flash_attention/varlen_utils.py @@ -0,0 +1,108 @@ +# ruff: noqa +import torch +from einops import rearrange, repeat +from bert_padding import pad_input, unpad_input + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device) + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) diff --git a/examples/flash_decoding/README.md b/examples/flash_decoding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a1b416125dd31d705d50327940dd62ba3ee2a2a4 --- /dev/null +++ b/examples/flash_decoding/README.md @@ -0,0 +1 @@ +# Flash Decoding diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..ee42df2080f3f3ed468b522e9df0db25377144de --- /dev/null +++ b/examples/flash_decoding/example_gqa_decode.py @@ -0,0 +1,495 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse +import itertools +from functools import lru_cache +from typing import Tuple, Dict + +torch.random.manual_seed(0) + + +def get_configs(): + block_N = [64, 128] + block_H = [64] + num_split = [1, 2, 4, 8] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +@lru_cache(maxsize=1) +def get_heuristic_config() -> Tuple[Dict, int]: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version == 89: + cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128) + else: + cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128) + return cfg, sm_version + + +# TODO(lei): fix warp specialized and tma lower pass +def get_pass_configs(): + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) +def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [batch, seqlen_kv, groups, dim] + shape_v = [batch, seqlen_kv, groups, dim] + shape_o = [batch, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // groups + + part_shape = [batch, heads, num_split, dim] + valid_block_H = min(block_H, kv_group_num) + valid_block_N = min(block_N, seqlen_kv // num_split) + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + mask_local = T.alloc_fragment([block_N], "uint8") + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + mask_local = T.alloc_fragment([block_N], "uint8") + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + K[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + K_shared, + ) + T.copy( + mask[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + ], + mask_local, + ) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy( + V[ + bid, + (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, + :, + ], + V_shared, + ) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 128], dtype) + lse_logsum_local = T.alloc_fragment([128], accum_dtype) + lse_max_local = T.alloc_fragment([128], accum_dtype) + scale_local = T.alloc_fragment([128], accum_dtype) + + T.annotate_layout( + { + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), + # lse_local: (local_id, thread_id) + lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k, j in T.Parallel(num_split, 128): + lse_local[k, j] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.serial(num_split): + for j in T.Parallel(128): + lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j]) + for j in T.Parallel(128): + lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + for j in T.Parallel(128): + scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j]) + # Note: Pay attention to dim and the number of threads in Parallel + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[i] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def flashattn_gqa_decode_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, mask, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn(Q, K, V, mask, Output) + + if num_split > 1: + return flashattn_gqa_decode_split + else: + return flashattn_gqa_decode_no_split + + +def ref_program(query, key, value, mask, glse, Output_partial): + # """ + # Inputs: + # - query (Tensor): [batch, heads, dim] + # - key (Tensor): [batch, seqlen_kv, groups, dim] + # - value (Tensor): [batch, seqlen_kv, groups, dim] + # - mask (Tensor): [batch, seqlen_kv, groups] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = query.shape[-1] + num_head_groups = query.shape[1] // key.shape[2] + scale = dim**0.5 + key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + if mask is not None: + mask = rearrange(mask, "b s h -> b h s") + mask = mask.unsqueeze(1) + scores = scores.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def flash_split_ref(Q, K, V, mask): + num_split = 16 + batch = Q.size(0) + nheads = Q.size(1) + groups = K.size(2) + dim = Q.size(-1) + block_N = 32 + seqlen_kv = K.size(1) + num_head_groups = nheads // groups + + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) + + Q_ = Q * scale + Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups) + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum( + "bghd,bkhd->bghk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, nheads, block_N] + if mask is not None: + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :] + mask_local = rearrange(mask_local, "b s h -> b h s") + mask_local = mask_local.unsqueeze(1) + acc_s = acc_s.masked_fill(mask_local == 0, float("-inf")) + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] + scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] + acc_o *= scores_scale[:, :, :, None] + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) + acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] + acc_o += torch.einsum( + "bghk,bkhd->bghd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o_out = rearrange(acc_o, "b g h d->b (h g) d") + logsum_out = rearrange(logsum, "b g h->b (h g)") + acc_o_out /= logsum_out[:, :, None] + logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)") + gacc_o[ks, :, :, :] = acc_o_out + glogsum[ks, :, :] = logsum_out + + return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3) + + +def reduce_ref(Q, K, V, mask, glse, Output_partial): + num_split = 16 + o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) # [batch, heads] + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks] + scale = torch.exp2(lse - lse_logsum) # [batch, heads] + o += Output_partial[:, :, ks, :] * scale[:, :, None] + return o.to(torch.float16) + + +def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None): + glse_, Output_partial_ = flash_split_ref(Q, K, V, mask) + return reduce_ref(Q, K, V, mask, glse_, Output_partial_) + + +def print_red_warning(msg): + print(f"\033[91m{msg}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): + sim = calc_sim(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff}") + if assert_: + raise AssertionError(f"{name} Error: {diff}") + else: + if print_: + print(f"passed: {name} diff={diff}") + + +def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False): + batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim + qk_flops = 2 * batch * heads * kv_seqlen * dim + pv_flops = 2 * batch * heads * kv_seqlen * dim + total_flops = qk_flops + pv_flops + + if not tune: + config, sm_version = get_heuristic_config() + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + + q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16) + k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) + v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) + mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8) + split = config["num_split"] + glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16) + Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16) + o = kernel(q, k, v, mask, glse, Output_partial) + o_ref = ref_program(q, k, v, mask, glse, Output_partial) + o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial) + + print(o) + print(o_ref) + + assert_similar(o, o_ref, name="o_ref") + assert_similar(o, o_ref_split, name="o_ref_split") + + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = flashattn(batch, heads, groups, kv_seqlen, dim) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=32, help="heads") + parser.add_argument("--groups", type=int, default=8, help="groups") + parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length") + parser.add_argument("--dim", type=int, default=128, help="dim") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3d8baed6fce23f474597d29aeb365e611e629e --- /dev/null +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -0,0 +1,909 @@ +import torch +import triton +import triton.language as tl +import math +import argparse +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +torch.manual_seed(0) +tilelang.disable_cache() + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +@triton.jit +def _fwd_inner( + q, + k_ptrs, + v_ptrs, + s_ptrs, + m_i, + l_i, + acc, + offs_h, + mask_h, + offs_n, + seqlen, + softmax_scale, + lo, + hi, + stride_kt, + stride_vt, + stride_sh, + stride_sn, + BLOCK_N: tl.constexpr, +): + """Inner loop computation for attention""" + + for blk_idx in tl.range(lo, hi): + start_n = blk_idx * BLOCK_N + k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen) + v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen) + + qk = tl.dot(q, k) + qk *= softmax_scale + qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9) + + row_max = tl.max(qk, 1) + tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h) + + m_ij = tl.maximum(m_i, row_max) + qk -= m_ij[:, None] + p = tl.math.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + m_i = m_ij + acc *= alpha[:, None] + p = p.to(v.type.element_ty) + acc += tl.dot(p, v) + + return m_i, l_i, acc + + +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], + key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"], +) +@triton.jit +def _fwd_kernel_varlen( + Q, # [token_q = b, h_q, dim] + K, # [token_k, h_kv, dim] + V, + O, + S, + s_aux, + softmax_scale, + cu_seqlens_k, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_sb, + stride_sh, + stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] + gqa_group_size: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(0) + off_h_for_kv = tl.program_id(1) + off_h_q = off_h_for_kv * gqa_group_size + + cu_k_start = tl.load(cu_seqlens_k + off_z) + cu_k_end = tl.load(cu_seqlens_k + off_z + 1) + + seqlen_k = cu_k_end - cu_k_start + + offs_h = tl.arange(0, BLOCK_H) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh + K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh + V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh + O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh + S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh + + mask_h = offs_h < gqa_group_size + q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) + + if s_aux is not None: + sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) + l_i = tl.zeros([BLOCK_H], dtype=tl.float32) + m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink + else: + l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) + m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) + + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N) + m_i, l_i, acc = _fwd_inner( + q, + k_ptrs, + v_ptrs, + S_ptrs, + m_i, + l_i, + acc, + offs_h, + mask_h, + offs_n, + seqlen_k, + softmax_scale, + lo, + hi, + stride_kt, + stride_vt, + stride_sh, + stride_sn, + BLOCK_N, + ) + + if s_aux is not None: + sink = tl.math.exp(sink - m_i) + l_i = l_i + sink + acc = acc / l_i[:, None] + + else: + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + for blk_idx in tl.range(lo, hi): + s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h) + s = tl.exp(s - m_i) / l_i + tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h) + + acc = acc.to(O.dtype.element_ty) + + tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) + + +def get_configs(): + import itertools + + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") +def flashattn( + batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // k_heads + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + + T.annotate_layout( + { + # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + # K_shared: tilelang.layout.make_swizzled_layout(K_shared), + # V_shared: tilelang.layout.make_swizzled_layout(V_shared), + # O_shared: tilelang.layout.make_swizzled_layout(O_shared), + # S_shared: tilelang.layout.make_swizzled_layout(S_shared), + } + ) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], + # -T.infinity(accum_dtype)) + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + # T.copy(S_shared, S_fragment) + # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + # S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + # T.copy(S_fragment, S_shared) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def flash_attn_with_attn_pool_decode( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + BLOCK_D = head_size + BLOCK_N = block_size + BLOCK_H = 64 + + O = torch.zeros_like(Q) + S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) + + def grid(META): + return (batch, k_h) + + with torch.cuda.device(Q.device.index): + _fwd_kernel_varlen[grid]( + Q, + K, + V, + O, + S, + s_aux, + softmax_scale, + cu_seqlens_k, + *Q.stride(), + *K.stride(), + *V.stride(), + *O.stride(), + *S.stride(), + gqa_group_size, + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + ) + + if use_per_kv_head_sparse_index: + S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O, S + + +def test_equal_seqlen_decode_main(args): + """Test decode kernel with equal sequence lengths""" + print("Testing decode kernel with equal sequence lengths") + + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + # For decode, query is just 1 token per batch + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + softmax_scale = 1.0 / math.sqrt(head_size) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Convert to varlen format for K, V + k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) + v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) + + # Generate cumulative sequence lengths + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) + max_seqlen_k = k_seqlen + + print(f"q shape: {q.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Compute torch reference + q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] + k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + + if sink is None: + # Standard scaled dot-product attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + attn_weights = torch.softmax(logits, dim=-1) + O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] + + # Compute attention score pooling + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, k_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(torch.float16) + + print("S_tilelang", S_tilelang) + print("attn_score_pooled", attn_score_pooled) + + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) + max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) + + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + print("โœ… All tests passed!") + + +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: {q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float("-inf") + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float("-inf") + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) + + print("โœ… All tests passed!") + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. + """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: {block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + + # Benchmark + print("โšก Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("โšก Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + + print(f"Speedup: {(triton_time / tilelang_time):.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + args = parser.parse_args() + args.test_sink = True + args.test_varlen = False + args.dtype = T.float16 + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + elif args.test_varlen: + test_varlen_decode_main(args) + else: + test_equal_seqlen_decode_main(args) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..0984e707531fc49e3f7c2130b1299b9826c4ea53 --- /dev/null +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -0,0 +1,679 @@ +import torch +import math +import argparse +import tilelang +import tilelang.language as T +from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench + +torch.manual_seed(0) + + +def get_configs(): + import itertools + + block_N = [64, 128] + block_H = [64] + num_split = [1] + num_stages = [1, 2, 3] + threads = [128] + _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) + + configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs] + return configs + + +# @autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") +def flashattn( + batch, + heads, + k_heads, + max_seqlen_kv, + total_seqlen_k, + dim, + has_sink, + page_block_size, + block_N=128, + block_H=64, + num_split=1, + num_stages=1, + threads=128, +): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, dim] + shape_k = [total_seqlen_k, k_heads, dim] + shape_v = [total_seqlen_k, k_heads, dim] + shape_o = [batch, heads, dim] + shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // k_heads + assert page_block_size >= block_N and page_block_size % block_N == 0, ( + "page_block_size must be larger than block_N and a multiple of block_N" + ) + + valid_block_H = min(block_H, kv_group_num) + # TODO: check if max_seqlen_kv is correct for varlen case + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32), + Output: T.Tensor([batch, heads, dim], dtype), + S: T.Tensor(shape_s, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) + s_aux_shared = T.alloc_shared([block_H], T.float32) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + cur_start_k = cu_seqlens_k[bid] + cur_end_k = cu_seqlens_k[bid + 1] + cur_seqlen_k = cur_end_k - cur_start_k + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # scores_max_prev is m_i + # scores_max is row_max->m_ij in triton + T.copy(scores_max, S_shared[:, k]) + # scores_scale is alpha in triton + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + # scores_sum is l_ij in triton + # logsum is l_i in triton + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size + T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + if has_sink: + T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared) + for i in T.Parallel(block_H): + logsum[i] += s_aux_shared[i] + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): + S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + + @T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32), + Output: T.Tensor(shape_o, dtype), + S: T.Tensor(shape_s, dtype), + ): + flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S) + + # TODO: split version + return flashattn_gqa_decode_no_split + + +def flash_attn_with_attn_pool_decode_tilelang( + Q: torch.Tensor, ## [tq = b, q_h, q_dim] + K: torch.Tensor, ## [tk, k_h, k_dim] + V: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_k: int, + real_max_k_seqlen: int, + num_split: int, + softmax_scale: float, + s_aux: torch.Tensor = None, + block_size: int = 64, + use_per_kv_head_sparse_index: bool = False, + tl_kernel=None, + block_table: torch.Tensor = None, +): + num_tokens, q_h, head_size = Q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = K.size(1) + + assert Q.dim() == K.dim() == 3 + assert Q.size(2) == K.size(2) + assert cu_seqlens_k.dim() == 1 + assert head_size in {64, 128, 256} + assert Q.is_contiguous() + assert K.is_contiguous() + assert V.is_contiguous() + + gqa_group_size = q_h // k_h + + O_tl = torch.zeros_like(Q) + S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) + O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) + + if use_per_kv_head_sparse_index: + S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + else: + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) + + return O_tl, S_tl + + +def test_equal_seqlen_decode_main(args): + """Test decode kernel with equal sequence lengths""" + print("Testing decode kernel with equal sequence lengths") + + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + # For decode, query is just 1 token per batch + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype) + softmax_scale = 1.0 / math.sqrt(head_size) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Convert to varlen format for K, V + k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() + v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() + + # Generate cumulative sequence lengths + cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32) + max_seqlen_k = k_seqlen + + print(f"q shape: {q.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Compute torch reference + q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] + k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size] + + if sink is None: + # Standard scaled dot-product attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + attn_weights = torch.softmax(logits, dim=-1) + O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k] + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size] + + # Compute attention score pooling + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, k_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(torch.float16) + + print("S_tilelang", S_tilelang) + print("attn_score_pooled", attn_score_pooled) + + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch)) + max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled)) + + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}" + assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}" + print("โœ… All tests passed!") + + +def test_varlen_decode_main(args): + """Test decode kernel with variable sequence lengths""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen # Use as max sequence length + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(f"Using sink attention with sink values: {sink}") + + # Generate variable length k sequences + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + print(f"k_seqlens: {k_seqlens}") + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + print(f"cu_seqlens_k: {cu_seqlens_k}") + + # Generate tensors - Q is [batch_size, q_heads, head_size] for decode + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + print(f"Actual max_seqlen_k: {max_seqlen_k}") + print(f"q_decode shape: {q_decode.shape}") + print(f"k_varlen shape: {k_varlen.shape}") + print(f"v_varlen shape: {v_varlen.shape}") + + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Test our decode kernel + O_triton, S_triton = flash_attn_with_attn_pool_decode( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + ) + O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + real_max_k_seqlen, + args.num_split, + softmax_scale, + s_aux=sink, + block_size=block_size, + tl_kernel=tl_kernel, + block_table=block_table, + ) + for i in range(batch_size): + S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 + + # Create torch reference - pad tensors for comparison + k_padded_list = [] + v_padded_list = [] + + for i in range(batch_size): + actual_k_len = k_seqlens[i] + + # Extract and pad k, v for this batch + k_start = cu_seqlens_k[i] + k_end = cu_seqlens_k[i + 1] + + # Pad to max_seqlen_k + k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) + + k_padded[:actual_k_len] = k_varlen[k_start:k_end] + v_padded[:actual_k_len] = v_varlen[k_start:k_end] + + k_padded_list.append(k_padded) + v_padded_list.append(v_padded) + + # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] + k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] + + # Expand q to match kv heads: [b, q_heads, 1, head_size] + q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] + + print(f"q_expanded shape: {q_expanded.shape}") + print(f"k_padded_batched shape: {k_padded_batched.shape}") + print(f"v_padded_batched shape: {v_padded_batched.shape}") + + # Compute torch reference + k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] + + if sink is None: + # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] + attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_score[i, :, :, actual_k_len:] = float("-inf") + + attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] + else: + # s_aux attention + logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + + # Apply sequence length masking + for i in range(batch_size): + actual_k_len = k_seqlens[i] + logits[i, :, :, actual_k_len:] = float("-inf") + + sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) + sinks = torch.exp(sink_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + attn_weights = unnormalized_scores / normalizer + + # Mask out invalid positions + for i in range(batch_size): + actual_k_len = k_seqlens[i] + attn_weights[i, :, :, actual_k_len:] = 0.0 + + # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] + O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] + + O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] + + # Compute attention score pooling for S + attn_score_pooled = torch.max_pool2d( + attn_weights.squeeze(2), # [b, q_heads, max_seqlen] + kernel_size=(q_heads, block_size), + stride=(q_heads, block_size), + ceil_mode=True, + ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] + + print(f"O_triton shape: {O_triton.shape}") + print(f"O_tilelang shape: {O_tilelang.shape}") + print(f"O_torch shape: {O_torch.shape}") + print(f"S_triton shape: {S_triton.shape}") + print(f"S_tilelang shape: {S_tilelang.shape}") + print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + + # Compare results + max_diff_o = torch.max(torch.abs(O_triton - O_torch)) + max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) + print(f"Max difference in O: {max_diff_o.item()}") + print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") + + max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) + max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled)) + print(f"Max difference in S: {max_diff_s.item()}") + print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") + + assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" + assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" + assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" + assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), ( + f"Score mismatch: {max_diff_s_tl.item()}" + ) + + print("โœ… All tests passed!") + + +def speed_benchmark_decode_comparison(args): + """Speed benchmark for decode kernel""" + batch_size = args.batch_size + q_heads = args.q_heads + kv_heads = args.kv_heads + max_k_seqlen = args.k_seqlen + real_max_k_seqlen = args.k_seqlen + head_size = args.head_size + block_size = args.block_size + page_block_size = args.page_block_size + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 + + print("\n=== Decode Speed Benchmark Comparison ===") + print("Configuration:") + print(f" Batch size: {batch_size}") + print(f" Q heads: {q_heads}, KV heads: {kv_heads}") + print(f" Max K sequence length: {max_k_seqlen}") + print(f" Head size: {head_size}") + print(f" Block size: {block_size}") + print(f" Data type: {dtype}") + print(f" Variable lengths: {args.test_varlen}") + print(f" s_aux attention: {args.test_sink}") + print() + + # Generate input data + if args.test_varlen: + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + else: + k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) + + # Generate cumulative sequence lengths for k + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + total_k_tokens = 0 + for i in range(batch_size): + cu_seqlens_k[i] = total_k_tokens + total_k_tokens += k_seqlens[i] + cu_seqlens_k[batch_size] = total_k_tokens + + # Generate tensors + q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + + softmax_scale = 1.0 / math.sqrt(head_size) + max_seqlen_k = int(k_seqlens.max()) + + # Generate sink values if needed + sink = None + if args.test_sink: + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values + print(" Using sink attention with sink values") + + print("Setup complete:") + print(f" Total K tokens: {total_k_tokens}") + print(f" Actual max K seq len: {max_seqlen_k}") + if args.test_varlen: + print(f" K sequence lengths: {k_seqlens.tolist()}") + + # Warmup + num_tokens, q_h, head_size = q_decode.shape + batch = cu_seqlens_k.size(0) - 1 + k_h = k_varlen.size(1) + tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size) + + block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32) + block_cnt = 0 + for i in range(batch): + cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() + for j in range(math.ceil(cur_seqlen / page_block_size)): + block_table[i, j] = block_cnt + block_cnt += 1 + block_cnt = 0 + + # Benchmark + print("โšก Benchmarking Tilelang kernel (100 iterations)...") + tilelang_time = do_bench( + flash_attn_with_attn_pool_decode_tilelang, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + False, + tl_kernel, + block_table, + ) + print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") + + # Benchmark + print("โšก Benchmarking Triton kernel (100 iterations)...") + triton_time = do_bench( + flash_attn_with_attn_pool_decode, + q_decode, + k_varlen, + v_varlen, + cu_seqlens_k, + max_seqlen_k, + args.k_seqlen, + 1, + softmax_scale, + sink, + block_size, + ) + print(f"Average decode kernel time Triton: {triton_time:.3f} ms") + print(f"Speedup: {(triton_time / tilelang_time):.3f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads") + parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads") + parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") + parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") + parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") + parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") + parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") + parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") + parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits") + parser.add_argument("--page_block_size", type=int, default=128, help="Page block size") + args = parser.parse_args() + args.test_sink = True + args.test_varlen = True + args.dtype = T.float16 + args.num_split = 1 + + if args.benchmark: + speed_benchmark_decode_comparison(args) + elif args.test_varlen: + test_varlen_decode_main(args) + else: + test_equal_seqlen_decode_main(args) diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5b243d695eefd8bb2d0fa0f18e30986d96ca7135 --- /dev/null +++ b/examples/flash_decoding/example_mha_inference.py @@ -0,0 +1,322 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from functools import partial + +num_split = 4 + + +@tilelang.jit(out_idx=[5]) +def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape_q = [batch, seqlen_q, heads, dim] + shape_kv = [batch, seqlen_kv, heads, dim] + part_shape = [batch, seqlen_q, heads, num_split, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.macro + def MMA0( + K: T.Tensor(shape_kv, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + mid: T.int32, + hid: T.int32, + bid: T.int32, + sid: T.int32, + ): + T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared) + # TODO: Handle causal split case + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape_kv, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + hid: T.int32, + bid: T.int32, + sid: T.int32, + ): + T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), + ): + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + mid = bx + hid = by % heads + bid = by // heads + sid = bz + + # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently + # disable relevant tma copy and use SIMT as fallback for now + T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + # TODO: Handle causal split case + loop_range = ( + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N)) + if is_causal + else T.ceildiv((seqlen_kv // num_split), block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=2): + MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_q, dtype), + ): + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): + po_local = T.alloc_fragment([block_M, dim], dtype) + po_shared = T.alloc_shared([block_M, dim], dtype) + o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype) + o_shared = T.alloc_shared([block_M, dim], dtype) + lse_local = T.alloc_fragment([num_split, block_M], dtype) + lse_local_split = T.alloc_fragment([block_M], accum_dtype) + lse_logsum_local = T.alloc_fragment([block_M], accum_dtype) + lse_max_local = T.alloc_fragment([block_M], accum_dtype) + scale_local = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout( + { + o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), + o_shared: tilelang.layout.make_swizzled_layout(o_shared), + po_shared: tilelang.layout.make_swizzled_layout(po_shared), + } + ) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + T.copy( + glse[ + bz, + by, + :, + bx * block_M : (bx + 1) * block_M, + ], + lse_local, + ) + T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) + for k in T.Pipelined(num_split): + T.copy(lse_local[k, :], lse_local_split) + for i in T.Parallel(block_M): + lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i]) + for i in T.Parallel(block_M): + lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] + for k in T.Pipelined(num_split, num_stages=2): + T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True) + T.copy(po_shared, po_local) + for i in T.Parallel(block_M): + lse_local_split[i] = lse_local[k, i] + for i in T.Parallel(block_M): + scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i]) + for i, j in T.Parallel(block_M, dim): + o_accum_local[i, j] += po_local[i, j] * scale_local[i] + T.copy(o_accum_local, o_shared) + T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True) + + @T.prim_func + def flashattn_mha_inference( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_kv, dtype), + V: T.Tensor(shape_kv, dtype), + glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), + Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] + Output: T.Tensor(shape_q, dtype), + ): + flash_attn_split(Q, K, V, glse, Output_partial) + combine(glse, Output_partial, Output) + + return flashattn_mha_inference + + +def ref_program(Q, K, V, glse, Output_partial, causal): + assert causal is False + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def reduce_ref(Q, K, V, glse, Output_partial, causal): + o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0) # [batch, seqlen_q, heads] + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks, :] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks, :] + scale = torch.exp2(lse - lse_logsum) # [batch, heads, seqlen_q] + o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2) + return o.to(torch.float16) + + +def flash_split_ref(Q, K, V, causal): + # [batch, seqlen_q, heads, dim] + batch = Q.size(0) + block_M = Q.size(1) + nheads = Q.size(2) + dim = Q.size(3) + block_N = 128 + seqlen_kv = K.size(1) + + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float) + + Q_ = Q * scale + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float("-inf")) + scores_max_prev.fill_(float("-inf")) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum( + "bqhd,bkhd->bhqk", + Q_, + K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) # [batch, seqlen, nheads, block_N] + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] + scores_scale = torch.exp2(scores_max_prev - scores_max) + acc_o *= scores_scale[:, :, :, None].transpose(1, 2) + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) + acc_s_cast = acc_s.to(torch.float16) + acc_o += torch.einsum( + "bhqk,bkhd->bqhd", + acc_s_cast, + V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :], + ) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, :, None].transpose(1, 2) + logsum = torch.log2(logsum) + scores_max + gacc_o[ks, :, :, :, :] = acc_o + glogsum[ks, :, :, :] = logsum + + return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) + + +def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): + flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + BLOCK_M = 128 + BLOCK_N = 64 # if D_HEAD <= 128 else 32 + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + ref_fn = partial(ref_program, causal=causal) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01) + print("All checks passed!") + + latency = profiler.do_bench(ref_fn, warmup=500) + print("{:.2f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(n_warmup=10, n_repeat=10) + print("{:.4f} ms".format(latency)) + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + main() diff --git a/examples/flash_decoding/test_example_flash_decoding.py b/examples/flash_decoding/test_example_flash_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..c728dfe0e1f14712ba29363ca380fa425bd9d536 --- /dev/null +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -0,0 +1,19 @@ +import tilelang.testing + +import example_gqa_decode +import example_mha_inference + + +# TODO(lei): fix the correctness of gqa decode on sm90 +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_example_gqa_decode(): + example_gqa_decode.main() + + +def test_example_example_mha_inference(): + example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..36c6ef3dc20004e9ac0076d2f6bc7680e5c33371 --- /dev/null +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -0,0 +1,524 @@ +import math +import torch +import torch.nn as nn +from typing import Dict, Tuple, Optional +import tilelang +import tilelang.language as T +from tilelang.autotuner import * +from example_fusedmoe_torch import * + + +@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def moe_forward_tilelang_shared( + d_hidden, + d_expert, + n_shared_experts, + dtype, + num_tokens, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, +): + scale = 1.44269504 # log2(e) + + # Parameters + dhidden = d_hidden + dexpert = d_expert * n_shared_experts + + # Tensors: Note that input shape is reshape to (num_tokens, dhidden) + input_shape = (num_tokens, dhidden) + shared_W_gate_shape = (dexpert, dhidden) + shared_W_up_shape = (dexpert, dhidden) + shared_W_down_shape = (dhidden, dexpert) + + accum_type = T.float32 + + @T.prim_func + def kernel_shared( + input: T.Tensor(input_shape, dtype), # type: ignore + shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore + shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore + shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore + up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore + ): + # Step 1: Compute gate and up logits + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): + # Split the block to shared experts and routed experts + input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) + W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + W_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + # Shared experts: no need to check expert_indices + + gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type) + up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type) + + T.use_swizzle(10) + T.clear(gate_logits_local) + T.clear(up_logits_local) + + # Parallel for gate and up matmul + for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): + T.copy(input[bx * block_token, k * block_dhidden], input_shared) + T.copy(shared_W_gate[by * block_dexpert, k * block_dhidden], W_gate_shared) + T.copy(shared_W_up[by * block_dexpert, k * block_dhidden], W_up_shared) + T.gemm(input_shared, W_gate_shared, gate_logits_local, transpose_B=True) + T.gemm(input_shared, W_up_shared, up_logits_local, transpose_B=True) + + # Fuse with SiLU and element-wise product + for i, j in T.Parallel(block_token, block_dexpert): + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] + + T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) + + # Step 2: Compute down logits + with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): + up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) + W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) + output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) + + T.use_swizzle(10) + T.clear(output_local) + + for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): + T.copy(up_logits[bx * block_token, k * block_dexpert], up_logits_shared) + T.copy(shared_W_down[by * block_dhidden, k * block_dexpert], W_down_shared) + T.gemm(up_logits_shared, W_down_shared, output_local, transpose_B=True) + + T.copy(output_local, output[bx * block_token, by * block_dhidden]) + + return kernel_shared + + +@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def moe_forward_tilelang_routed( + d_hidden, + d_expert, + n_routed_experts, + dtype, + group_sum, + group_count, + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=None, +): + scale = 1.44269504 # log2(e) + + # Parameters + dhidden = d_hidden + dexpert = d_expert + n_routed_experts = n_routed_experts + + # Group info + # group_sum = sum(group_sizes_list) + # group_count = len(group_sizes_list) + # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list]) + M = math.ceil(group_sum / block_token) + group_count + accum_dtype = T.float32 + + # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm + input_shape = (group_sum, dhidden) + intermediate_shape = (group_sum, dexpert) + routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) + routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) + routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) + routed_expert_weights_shape = group_sum + group_sizes_shape = n_routed_experts + + @T.prim_func + def kernel( + input: T.Tensor(input_shape, dtype), # type: ignore + routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore + routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore + routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore + routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore + up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore + output: T.Tensor(input_shape, dtype), # type: ignore + ): + # Step 1: Compute gate and up logits + with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): + input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) + routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) + + gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) + up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) + + cur_group_idx = T.alloc_local([1], T.int32) + cur_group_size = T.alloc_local([1], T.int32) + + T.use_swizzle(10, enable=True) + + m_start_padded = bx * block_token + + cur_group_idx[0] = group_idx_for_bx[bx] + + cur_group_size[0] = group_sizes[cur_group_idx[0]] + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + + T.clear(gate_logits_local) + T.clear(up_logits_local) + + for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): + T.copy( + input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], + input_shared, + coalesced_width=coalesced_width, + ) + T.copy( + routed_expert_gate[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], + routed_expert_gate_shared, + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) + T.copy( + routed_expert_up[ + cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden + ], + routed_expert_up_shared, + coalesced_width=coalesced_width, + ) + T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True) + + for i, j in T.Parallel(block_token, block_dexpert): + gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) + up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] + + for i, j in T.Parallel(block_token, block_dexpert): + if i < actual_rows: + up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j] + + # Step 2: Compute down logits + with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by): + up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) + routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) + output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) + + cur_group_idx = T.alloc_local([1], T.int32) + cur_group_size = T.alloc_local([1], T.int32) + + T.use_swizzle(10, enable=True) + + m_start_padded = bx * block_token + + cur_group_idx[0] = group_idx_for_bx[bx] + + cur_group_size[0] = group_sizes[cur_group_idx[0]] + m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]] + actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]]))) + + T.clear(output_local) + + for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): + T.copy( + up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], + up_logits_shared, + coalesced_width=coalesced_width, + ) + T.copy( + routed_expert_down[ + cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert + ], + routed_expert_down_shared, + coalesced_width=coalesced_width, + ) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True) + + for i, j in T.Parallel(block_token, block_dhidden): + if i < actual_rows: + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i] + + return kernel + + +class Expert(nn.Module): + def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None): + super().__init__() + self.config = config + self.act_fn = nn.SiLU() + self.d_hidden: int = config["d_hidden"] + self.d_expert: int = config["d_expert"] if d_expert is None else d_expert + self.device = torch.device("cuda") + + self.W_gate_weight = gate.t().contiguous().to(self.device) + self.W_up_weight = up.t().contiguous().to(self.device) + self.W_down_weight = down.t().contiguous().to(self.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = self.act_fn(x @ self.W_gate_weight) + out = (gate * (x @ self.W_up_weight)) @ self.W_down_weight + return out + + +class MoEGate(nn.Module): + def __init__(self, config: Dict, weights: Dict): + super().__init__() + self.top_k: int = config["n_experts_per_token"] + self.num_experts: int = config["n_routed_experts"] + self.d_hidden: int = config["d_hidden"] + + self.W_g_weight = weights["router.weight"].t() + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + logits = x @ self.W_g_weight + scores = logits.softmax(dim=-1) + topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + return topk_indices, topk_scores + + +class MoE(nn.Module): + def __init__( + self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128 + ): + super().__init__() + self.config = config + self.shared_kernel = shared_kernel + self.routed_kernel = routed_kernel + self.padding_M = padding_M + self.experts = nn.ModuleList( + [ + Expert( + config, + gate=weights[f"experts.{i}.0.weight"], + up=weights[f"experts.{i}.1.weight"], + down=weights[f"experts.{i}.2.weight"], + ) + for i in range(config["n_routed_experts"]) + ] + ) + self.device = torch.device("cuda") + self.gating_network = MoEGate(config, weights).to(self.device) + shared_expert_dim = config["d_expert"] * config["n_shared_experts"] + self.shared_expert = Expert( + config=config, + gate=weights["shared_experts.0.weight"], + up=weights["shared_experts.1.weight"], + down=weights["shared_experts.2.weight"], + d_expert=shared_expert_dim, + ).to(self.device) + self.expert_cache = torch.zeros( + (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0) + self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0) + self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0) + self.stacked_expert_tokens = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), + dtype=torch.float16, + device=self.device, + ) + self.stacked_expert_weights = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device + ) + self.stacked_expert_tokens_idxs = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device + ) + + self.up_logits_shared = torch.empty( + (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device + ) + self.expert_output_shared = torch.empty( + (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device + ) + self.up_logits_routed = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]), + dtype=torch.float16, + device=self.device, + ) + self.expert_output_routed = torch.empty( + (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]), + dtype=torch.float16, + device=self.device, + ) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_shape = x.shape + batch_size, seq_len, hidden_dim = x.shape + expert_indices, expert_scores = self.gating_network(x) + flat_expert_indices = expert_indices.view(-1) + flat_expert_weights = expert_scores.view(-1) + x_flat = x.view(-1, hidden_dim) + + # Prepare for grouped GEMM + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + # counts = flat_expert_indices.bincount() + tokens_per_expert = counts.cumsum() + # tokens_per_expert = torch.cumsum(counts, dim=0) + num_per_tok = self.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + + # Get stacked expert tokens and expert weights + + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x_flat[exp_token_idxs] + + self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens + self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs + self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]] + + group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) + group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device) + + group_padded_offsets = [0 for _ in range(len(group_sizes))] + for i in range(1, len(group_sizes)): + group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M + + block_token = 128 + M = ( + math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token) + + self.config["n_routed_experts"] + ) + group_idx_for_bx = [0 for _ in range(M)] + + for bx in range(M): + m_start_padded = bx * block_token + for i in range(self.config["n_routed_experts"]): + if m_start_padded >= group_padded_offsets[i]: + group_idx_for_bx[bx] = i + + group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device) + group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) + + # Multi-stream execution + shared_stream = torch.cuda.Stream() + routed_stream = torch.cuda.default_stream() + torch.cuda.synchronize() + + with torch.cuda.stream(routed_stream): + # Tilelang version: Grouped GEMM + self.routed_kernel( + self.stacked_expert_tokens, + self.stacked_expert_w_gate, + self.stacked_expert_w_up, + self.stacked_expert_w_down, + self.stacked_expert_weights, + group_sizes, + group_offset, + group_padded_offsets, + group_idx_for_bx, + self.up_logits_routed, + self.expert_output_routed, + ) + + # Scatter reduce + self.expert_cache = torch.scatter_reduce( + self.expert_cache, + 0, + self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), + self.expert_output_routed, + reduce="sum", + ) + routed_output = self.expert_cache.view(*orig_shape) + + with torch.cuda.stream(shared_stream): + self.shared_kernel( + x_flat, + self.shared_expert.W_gate_weight, + self.shared_expert.W_up_weight, + self.shared_expert.W_down_weight, + self.up_logits_shared, + self.expert_output_shared, + ) + shared_output = self.expert_output_shared.view(*orig_shape) + + torch.cuda.synchronize() + + return shared_output + routed_output + + +def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: + """ + DeepSeek-style Mixture of Experts using Tilelang. + + Args: + data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict) + - input: Input tensor of shape [batch_size, seq_len, hidden_size] + - weights: Dictionary containing model weights + - config: Dictionary containing model configuration parameters + + Returns: + Tuple containing: + - output: Processed tensor [batch_size, seq_len, d_model] + """ + input_tensor, weights, config = data + + dtype_str = T.float16 + + shared_kernel = moe_forward_tilelang_shared( + config["d_hidden"], + config["d_expert"], + config["n_shared_experts"], + dtype=dtype_str, + num_tokens=config["batch_size"] * config["seq_len"], + ) + routed_kernel = moe_forward_tilelang_routed( + config["d_hidden"], + config["d_expert"], + config["n_routed_experts"], + dtype=dtype_str, + group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], + group_count=config["n_routed_experts"], + block_token=128, + block_dhidden=128, + block_dexpert=128, + threads=256, + num_stages=1, + k_pack=1, + coalesced_width=2, + ) + + moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) + + output = moe(input_tensor) + + return output + + +def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192): + config = { + "dhidden": d_hidden, + "dexpert": d_expert, + "nroutedexperts": n_routed_experts, + "nsharedexperts": n_shared_experts, + "nexpertspertoken": n_experts_per_token, + "bs": batch_size, + "seqlen": seq_len, + "seed": 81394, + } + + data = generate_input(**config) + + torch.cuda.synchronize() + ref_output = ref_kernel(clone_data(data)).to(torch.float32) + torch.cuda.synchronize() + tilelang_output = custom_kernel(clone_data(data)).to(torch.float32) + torch.cuda.synchronize() + + torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2) + print("โœ… Tilelang and Torch match") + + +if __name__ == "__main__": + main() diff --git a/examples/fusedmoe/example_fusedmoe_torch.py b/examples/fusedmoe/example_fusedmoe_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6322aff7dce196ce12f371d83f47e5c1fa82e4 --- /dev/null +++ b/examples/fusedmoe/example_fusedmoe_torch.py @@ -0,0 +1,210 @@ +import math +import torch +import torch.nn as nn +from typing import Dict, Tuple, Optional + + +# Reference code in PyTorch +class ExpertTorch(nn.Module): + def __init__(self, config: Dict, d_expert: Optional[int] = None): + super().__init__() + self.config = config + self.act_fn = nn.SiLU() + self.d_hidden: int = config["d_hidden"] + self.d_expert: int = config["d_expert"] if d_expert is None else d_expert + + self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False) + self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False) + self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = self.act_fn(self.W_gate(x)) + out = self.W_down(gate * self.W_up(x)) + return out + + +class MoEGateTorch(nn.Module): + def __init__(self, config: Dict): + super().__init__() + self.top_k: int = config["n_experts_per_token"] + self.num_experts: int = config["n_routed_experts"] + self.d_hidden: int = config["d_hidden"] + + self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + logits = self.W_g(x) + scores = logits.softmax(dim=-1) + topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + return topk_indices, topk_scores + + +class MoETorch(nn.Module): + def __init__(self, config: Dict): + super().__init__() + self.config = config + self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])]) + self.gating_network = MoEGateTorch(config) + shared_expert_dim = config["d_expert"] * config["n_shared_experts"] + self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shared_output = self.shared_expert(x) + expert_indices, expert_scores = self.gating_network(x) + batch_size, seq_len, hidden_dim = x.shape + orig_shape = x.shape + x_flat = x.view(-1, hidden_dim) + flat_expert_indices = expert_indices.view(-1) + flat_expert_weights = expert_scores.view(-1, 1) + routed_output_flat = self.moe_infer(x_flat, flat_expert_indices, flat_expert_weights) + + routed_output = routed_output_flat.view(*orig_shape) + return routed_output + shared_output + + @torch.no_grad() + def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor: + expert_cache = torch.zeros_like(x) + # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) + # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) + # test_expert_ups = torch.zeros((self.config["n_routed_experts"], self.config["d_hidden"], self.config["d_expert"])) + # test_expert_tokens_num = torch.zeros((self.config["n_routed_experts"])) + + idxs = flat_expert_indices.argsort() + counts = flat_expert_indices.bincount().cpu().numpy() + tokens_per_expert = counts.cumsum() + num_per_tok = self.config["n_experts_per_token"] + token_idxs = idxs // num_per_tok + for expert_id, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1] + if start_idx == end_idx: + continue + + expert = self.experts[expert_id] + exp_token_idxs = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idxs] + expert_out = expert(expert_tokens) + + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") + + return expert_cache + + +def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: + """ + Reference implementation of DeepSeek-style Mixture of Experts using PyTorch. + + Args: + data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict) + - input: Input tensor of shape [batch_size, seq_len, hidden_dim] + - weights: Dictionary containing model weights + - config: Dictionary containing model configuration parameters + + Returns: + Tuple containing: + - output: Processed tensor [batch_size, seq_len, d_model] + """ + input_tensor, weights, config = data + num_experts = config["n_routed_experts"] + moe = MoETorch(config) + + # Fill in the given weights of the model + moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"]) + + for i in range(num_experts): + gate_proj_weight = weights[f"experts.{i}.0.weight"] + up_proj_weight = weights[f"experts.{i}.1.weight"] + down_proj_weight = weights[f"experts.{i}.2.weight"] + + # Transpose weights to match expected shape for nn.Linear + moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) + moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) + moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) + + moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t()) + moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t()) + moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t()) + + output = moe(input_tensor) + + return output + + +# Input generation for the reference code + + +def generate_input( + dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int +) -> Tuple[torch.Tensor, Dict, Dict]: + # Really dumb but for now _ isn't parsing correctly. + d_hidden = dhidden + d_expert = dexpert + n_routed_experts = nroutedexperts + n_shared_experts = nsharedexperts + n_experts_per_token = nexpertspertoken + batch_size = bs + seq_len = seqlen + + config = { + "d_hidden": d_hidden, + "d_expert": d_expert, + "n_routed_experts": n_routed_experts, + "n_shared_experts": n_shared_experts, + "n_experts_per_token": n_experts_per_token, + "batch_size": batch_size, + "seq_len": seq_len, + } + + gen = torch.Generator(device="cuda") + gen.manual_seed(seed) + + num_experts = n_routed_experts + expert_dim = d_expert + weights = {} + + input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous() + + # Initialize router weights + weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden) + + for i in range(num_experts): + weights[f"experts.{i}.0.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.1.weight"] = torch.randn( + (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim) + + weights[f"experts.{i}.2.weight"] = torch.randn( + (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + weights["shared_experts.0.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.1.weight"] = torch.randn( + (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(expert_dim * n_shared_experts) + weights["shared_experts.2.weight"] = torch.randn( + (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen + ) / math.sqrt(d_hidden) + + return (input_tensor, weights, config) + + +def clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(clone_data(x) for x in data) + elif isinstance(data, list): + return [clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data diff --git a/examples/fusedmoe/test_example_fusedmoe.py b/examples/fusedmoe/test_example_fusedmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8415895d52f75dc8bf029b7a97e1cabc983b03 --- /dev/null +++ b/examples/fusedmoe/test_example_fusedmoe.py @@ -0,0 +1,12 @@ +import tilelang.testing +import example_fusedmoe_tilelang + + +def test_example_fusedmoe_tilelang(): + example_fusedmoe_tilelang.main( + d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024 + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gdn/README.md b/examples/gdn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..31dd2361e125595c469fef4b44c5e1128d6e96e4 --- /dev/null +++ b/examples/gdn/README.md @@ -0,0 +1,15 @@ +# Gated Delta Net (GDN) kernel implementation with TileLang + +## Requirement + +- TileLang: `0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1` +- Triton: `3.3.0` (used for comparison) +- FLA: commit `f03cb3ae` (used for comparison) + +## Get started + + The [chunk_delta_h](common/chunk_delta_h.py) implements the most critical forward kernel of GDN. It's a good start to understand the GDN logic and the TileLang optimization. + +## Acknowledgments + +This kernel was developed by Yu Cheng and Zhengju Tang following in-depth discussions with Xiaomi's LLM-Core Team (MiMo). diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..39450bc5fc5917e1958687111ba684aabecf800d --- /dev/null +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -0,0 +1,613 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 + +import tilelang +import tilelang.language as T + +print(tilelang.__file__, flush=True) + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__, flush=True) + from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(0) +# torch.set_printoptions(profile="full") + +from test_utils import assert_similar + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + # Note: G should be in logspace and do chunkwise cumsum + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + try: + from fla.ops.utils.cumsum import chunk_local_cumsum + + G = chunk_local_cumsum(G, chunk_size) + except ImportError: + print("fla not found, skip cumsum") + + h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_input_fake( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + G = torch.ones(B, S, H, dtype=gate_dtype).cuda() + h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda() + dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dh, dh0, dv2 + + +def torch_chunk_gated_delta_rule_bwd_dhu( + Q: torch.Tensor, + K: torch.Tensor, + W: torch.Tensor, + G: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + dO: torch.Tensor, + dv: torch.Tensor, + scale: float, + use_g: bool, + use_initial_state: bool, + use_final_state_gradient: bool, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + B, S, H, DK = Q.shape + DV = dv.shape[-1] + block_S = 64 + BS = S // block_S + dh, dh0, dv2 = ( + torch.empty((B, BS, H, DK, DV), dtype=output_dtype), + torch.empty((B, H, DK, DV), dtype=state_dtype), + torch.empty((B, S, H, DV), dtype=output_dtype), + ) + dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) + dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) + Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) + + if use_final_state_gradient: + dh_tmp = dht.clone().to(accum_dtype) + else: + dh_tmp = torch.zeros_like(dht).to(accum_dtype) + + for i_s in range(BS - 1, -1, -1): + dh[:, i_s, :, :, :] = dh_tmp + dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3) + if use_g: + for i_bh in range(B * H): + i_b, i_h = i_bh // H, i_bh % H + for i_s2 in range(block_S): + if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0: + dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h]) + else: + dv_tmp[i_b, i_s2, i_h, :] = 0 + dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp + + if use_g: + G_last = G[:, i_s * block_S + block_S - 1, :] + for i_bh in range(B * H): + i_b, i_h = i_bh // H, i_bh % H + dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) + Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :] + for i_s2 in range(block_S): + for i_k in range(DK): + Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) + Q_tmp *= scale + W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :] + dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :] + + torch.backends.cuda.matmul.allow_tf32 = True + dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) + dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3)) + torch.backends.cuda.matmul.allow_tf32 = False + + if use_initial_state: + dh0 = dh_tmp[:, :, :, :] + else: + dh0 = torch.zeros_like(dh_tmp[:, :, :, :]) + print(dh0.dtype) + + return dh, dh0, dv2 + + +@tilelang.jit(out_idx=[-3, -2, -1]) +def tilelang_chunk_gated_delta_rule_bwd_dhu( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_initial_state=True, + use_final_state_gradient=True, + # kernel config + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + # Should support cu_seqlen + BS = S // block_S + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + W_shape = (B, S, H, DK) + G_shape = (B, S, H) + h0_shape = (B, H, DK, DV) + dht_shape = (B, H, DK, DV) + dO_shape = (B, S, H, DV) + dv_shape = (B, S, H, DV) + + dh_shape = (B, BS, H, DK, DV) + dh0_shape = (B, H, DK, DV) + dv2_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) + b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype) + b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32) + dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32) + dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + G_last_local = T.alloc_local((1), dtype=gate_dtype) + G_last_local_exp = T.alloc_local((1), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared") + G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype) + G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype) + G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype) + Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype) + + T.use_swizzle(10) + + T.annotate_layout( + { + b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), + b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + } + ) + + if use_final_state_gradient: + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) + T.copy(b_dh_shared, b_dh_fragment) + else: + T.clear(b_dh_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # The gradient should be stored in the reverse order + i_s_inv = T.ceildiv(S, block_S) - i_s - 1 + + # Store the updated dh + T.copy(b_dh_fragment, b_dh_shared) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Update dv + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) + + if use_g: + T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True) + T.copy(G_shared, G_fragment) + G_last_local[0] = G_shared[block_S - 1] + G_last_local_exp[0] = T.exp(G_last_local[0]) + for i_s2 in T.Parallel(block_S): + G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2]) + for i_s2, i_v in T.Parallel(block_S, block_DV): + # with T.If(G_last_local[0] - G_shared[i_s2] <= 0): + with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): + with T.Then(): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2] + with T.Else(): + dv_fragment[i_s2, i_v] = 0 + + T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared) + T.copy(dv_shared, dv_fragment_2) + for i_s2, i_v in T.Parallel(block_S, block_DV): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] + + # Store the updated dv + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + # Update dh + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) + + T.clear(Q_fragment) + if use_g: + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] *= G_last_local_exp[0] + T.copy(Q_shared, Q_fragment) + for i_s2 in T.Parallel(block_S): + G_fragment_exp[i_s2] = T.exp(G_shared[i_s2]) + for i_s2, i_k in T.Parallel(block_S, DK): + # Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale + Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale + else: + T.copy(Q_shared, Q_fragment) + for i_s2, i_k in T.Parallel(block_S, DK): + Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale + # Get transpose of Q_fragment to meet tf32 gemm requirement + for i_s2, i_k in T.Parallel(block_S, DK): + Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] + + T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared) + T.copy(dO_shared, dO_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] + T.copy(dO_fragment_t, dO_shared_t) + + T.clear(b_dh_fragment_1) + T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True) + T.clear(b_dh_fragment_2) + T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True) + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] + + if use_initial_state: + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name): + try: + torch.testing.assert_close(dh_0, dh_1, rtol=1e-2, atol=1e-2, equal_nan=True) + print(f"{name} dh_0 and dh_1 passed for {name}") + except Exception as e: + print(f"{name} dh_0 and dh_1 are not close for {name}") + print(e, end="\n\n") + try: + torch.testing.assert_close(dh0_0, dh0_1, rtol=1e-2, atol=1e-2, equal_nan=True) + print(f"{name} dh0_0 and dh0_1 passed for {name}") + except Exception as e: + print(f"{name} dh0_0 and dh0_1 are not close for {name}") + print(e, end="\n\n") + try: + torch.testing.assert_close(dv2_0, dv2_1, rtol=1e-2, atol=1e-2, equal_nan=True) + print(f"{name} dv2_0 and dv2_1 passed for {name}") + except Exception as e: + print(f"{name} dv2_0 and dv2_1 are not close for {name}") + print(e, end="\n\n") + + close = torch.isclose(dh_0, dh_1, rtol=1e-2, atol=1e-2) + mismatch_indices = torch.nonzero(~close, as_tuple=True) + error_num = 0 + for indices in zip(*mismatch_indices): + if error_num < 100: + print( + f"{name} dh_0[{[idx.item() for idx in indices]}] = {dh_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}, dh_1[{[idx.item() for idx in indices]}] = {dh_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}" + ) + error_num += 1 + close = torch.isclose(dh0_0, dh0_1, rtol=1e-2, atol=1e-2) + mismatch_indices = torch.nonzero(~close, as_tuple=True) + error_num = 0 + for indices in zip(*mismatch_indices): + if error_num < 100: + print( + f"{name} dh0_0[{[idx.item() for idx in indices]}] = {dh0_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dh0_1[{[idx.item() for idx in indices]}] = {dh0_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}" + ) + error_num += 1 + close = torch.isclose(dv2_0, dv2_1, rtol=1e-2, atol=1e-2) + mismatch_indices = torch.nonzero(~close, as_tuple=True) + error_num = 0 + for indices in zip(*mismatch_indices): + if error_num < 100: + print( + f"{name} dv2_0[{[idx.item() for idx in indices]}] = {dv2_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dv2_1[{[idx.item() for idx in indices]}] = {dv2_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}" + ) + error_num += 1 + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=64, + threads=256, + num_stages=0, + use_torch=False, +): + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref, dh0_ref, dv2_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + + # fla ref + print("fla running...", flush=True) + if use_g: + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) + else: + G = G.fill_(0) + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale) + + # tilelang + print("tilelang running...", flush=True) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) + # kernel = tilelang.compile(program) + print(kernel.get_kernel_source()) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) + + fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size) + tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) + + print(f"fla time: {fla_time} ms") + print(f"tilelang time: {tilelang_time} ms") + + assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh") + assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0") + assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2") + + # torch ref + if use_torch: + print("torch running...", flush=True) + if use_g: + dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( + Q, + K, + W, + G, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref_torch = dh_ref_torch.cuda() + dh0_ref_torch = dh0_ref_torch.cuda() + dv2_ref_torch = dv2_ref_torch.cuda() + else: + dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( + Q, + K, + W, + None, + h0, + dht, + dO, + dv, + scale, + use_g, + use_initial_state, + use_final_state_gradient, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dh_ref_torch = dh_ref_torch.cuda() + dh0_ref_torch = dh0_ref_torch.cuda() + dv2_ref_torch = dv2_ref_torch.cuda() + + assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh") + assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0") + assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2") + assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh") + assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0") + assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2") + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. + """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def main(): + DK = 128 + run_test( + B=1, + S=32768, + H=8, + DK=DK, + DV=128, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + scale=DK**-0.5, + use_g=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=32, + threads=128, + num_stages=1, + use_torch=False, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py new file mode 100644 index 0000000000000000000000000000000000000000..d316a62116c8b2fda7ca24daddc838e0cda94144 --- /dev/null +++ b/examples/gdn/example_chunk_delta_h.py @@ -0,0 +1,408 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +import torch.nn.functional as F +from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 + +from test_utils import assert_similar + +# (zhengju) We can slightly modify the generated cuda code from tilelang lowering +# in the debug folder to make the performance better. To enable this callback, +# you can comment out the following function. +# @register_cuda_postproc_callback +# def tilelang_callback_cuda_postproc(code, _): +# cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read() +# code = cuda_code +# return code + +torch.random.manual_seed(0) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + W = F.normalize(W, dim=-1, p=2) + U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + U = F.normalize(U, dim=-1, p=2) + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + try: + from fla.ops.utils.cumsum import chunk_local_cumsum + + G = chunk_local_cumsum(G, chunk_size) + except ImportError: + print("fla not found, skip cumsum") + + initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + return K, W, U, G, initial_state + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + state_dtype, +): + BS = S // chunk_size + h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return h, final_state, V_new + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_gated_delta_rule_fwd_h( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + # kernel config + block_DK=64, + block_DV=32, + threads=128, + num_stages=1, +): + block_S = chunk_size + BS = S // block_S + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + U_shape = (B, S, H, DV) + G_shape = (B, S, H) + h_shape = (B, BS, H, DK, DV) + initial_state_shape = (B, H, DK, DV) + final_state_shape = (B, H, DK, DV) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype) + b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + + U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + G_last_local = T.alloc_local((1), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) + G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) + + T.annotate_layout( + { + b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + G_shared: tilelang.layout.make_swizzled_layout(G_shared), + } + ) + + T.use_swizzle(10) + + if use_initial_state: + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) + T.copy(b_h_shared, b_h_fragment) + else: + T.clear(b_h_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # Store previous result to the hidden tensor, like the epilogue + T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Recurrence + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared) + T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) + + # U - W * S + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) + T.copy(U_shared, U_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] + + # Save V_new + if save_new_value: + T.copy(V_new_fragment, dst=V_new_shared) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) + # use_g + if use_g: + G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] + for i_s2, i_v in T.Parallel(block_S, block_DV): + G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh] + T.copy(G_shared, G_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): + with T.Then(): + V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( + (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695 + ) + with T.Else(): + V_new_fragment[i_s2, i_v] = 0 + G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) + for i_k, i_v in T.Parallel(DK, block_DV): + b_h_fragment[i_k, i_v] *= G_last_local[0] + + # Update intermediate results + T.copy(V_new_fragment, V_new_shared) + T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True) + + T.copy(b_h_fragment, b_h_shared) + + # Save final state + if store_final_state: + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. + """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + + # fla ref + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) + + # tilelang + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) + # (zhengju) If you want to print the generated cuda code, you can uncomment the following line + # print("CUDA Code:\n", kernel.get_kernel_source()) + + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + g=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + ) + tilelang_time = do_bench(kernel, K, W, U, G, initial_state) + + # check correctness + try: + h_ref_fp32 = h_ref.to(torch.float32) + h_tilelang_fp32 = h_tilelang.to(torch.float32) + assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False) + print("tilelang chunk gated delta rule fwd h passed โˆš") + except Exception as e: + print("tilelang chunk gated delta rule fwd h failed โœ—") + print(e) + + try: + final_state_ref_fp32 = final_state_ref.to(torch.float32) + final_state_tilelang_fp32 = final_state_tilelang.to(torch.float32) + assert_similar( + final_state_ref_fp32, + final_state_tilelang_fp32, + eps=1e-5, + name="tilelang chunk gated delta rule fwd final_state", + raise_assert=False, + ) + print("tilelang chunk gated delta rule fwd final_state passed โˆš") + except Exception as e: + print("tilelang chunk gated delta rule fwd final_state failed โœ—") + print(e) + + try: + V_new_ref_fp32 = V_new_ref.to(torch.float32) + V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) + assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False) + print("tilelang chunk gated delta rule fwd V_new passed โˆš") + except Exception as e: + print("tilelang chunk gated delta rule fwd V_new failed โœ—") + print(e) + + print(f"tilelang time: {tilelang_time} ms") + print(f"fla time: {fla_time} ms") + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + DV=128, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + use_g=True, + use_initial_state=False, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py new file mode 100644 index 0000000000000000000000000000000000000000..81536815923b294c38a02540f436670fc71798a4 --- /dev/null +++ b/examples/gdn/example_chunk_o.py @@ -0,0 +1,246 @@ +# Reference: fla/ops/common/chunk_o.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_o import chunk_fwd_o +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + BS = chunk_size + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + return Q, K, V, HIDDEN, G + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, +): + O = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return O + + +@tilelang.jit(out_idx=[-1]) +def tilelang_chunk_fwd_o( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + # kernel config + block_S=64, + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + assert chunk_size == block_S, "chunk_size must be equal to block_S" + BS = chunk_size + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + H_shape = (B, S // BS, H, DK, DV) + G_shape = (B, S, H) + O_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + HIDDEN: T.Tensor(H_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + O: T.Tensor(O_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh): + bb, bh = bbh // H, bbh % H + Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + H_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + O_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + O_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") + G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) + + T.annotate_layout( + { + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + H_shared: tilelang.layout.make_swizzled_layout(H_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + T.clear(A_fragment) + T.clear(O_fragment) + T.disable_warp_group_reg_alloc() + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared) + T.gemm(Q_shared, H_shared, O_fragment) + T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) + + if use_g: + for i_s in T.Parallel(block_S): + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + # T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) + for i_s, i_v in T.Parallel(block_S, block_DV): + O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * T.exp(G_shared[i_s]) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(G_diff_local[i_s1, i_s2] <= 0): + with T.Then(): + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) + with T.Else(): + A_fragment[i_s1, i_s2] = 0 + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 < i_s2): # noqa: SIM117 + with T.Then(): + A_fragment[i_s1, i_s2] = 0 + + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) + T.copy(A_fragment, A_shared) + T.gemm(A_shared, V_shared, O_fragment) + + for i_s, i_v in T.Parallel(block_S, block_DV): + O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale + + T.copy(O_fragment, O_shared) + T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + use_g, + block_DK, + block_DV, + threads, + num_stages, +): + input_dtype_torch = getattr(torch, input_dtype) + output_dtype_torch = getattr(torch, output_dtype) + accum_dtype_torch = getattr(torch, accum_dtype) + gate_dtype_torch = getattr(torch, gate_dtype) + Q, K, V, HIDDEN, G = prepare_input( + B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch + ) + scale = 1.0 / DK**0.5 + + O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) + O_ref = chunk_fwd_o(Q, K, V, HIDDEN, G, scale, chunk_size=chunk_size) + + block_S = chunk_size + O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) + O_tilelang = kernel(Q, K, V, HIDDEN, G) + + try: + torch.testing.assert_close(O_tilelang, O_ref, rtol=1e-2, atol=1e-2) + print("tilelang chunk fwd o passed โˆš") + except Exception as e: + print("tilelang chunk fwd o failed โœ—") + print(e) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + DV=128, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + use_g=True, + block_DK=128, + block_DV=128, + threads=128, + num_stages=1, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..97e2f4f01e08f29fcb845df5af728c7bf00b2728 --- /dev/null +++ b/examples/gdn/example_chunk_o_bwd.py @@ -0,0 +1,526 @@ +# Reference: fla/ops/common/chunk_o.py + +import math +import sys # noqa: F401 + +import tilelang +import tilelang.language as T +from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_o import chunk_bwd_dqkwg +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +from test_utils import assert_similar + +torch.random.manual_seed(0) +# torch.set_printoptions(profile="full") + + +def prepare_input_fake( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + h = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda() + G = torch.ones(B, S, H, dtype=gate_dtype).cuda() + dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda() + W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + return Q, K, V, h, G, dO, dh, dv, W + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda() + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + return Q, K, V, h, G, dO, dh, dv, W + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, + block_DK, +): + assert DK == 32 and block_DK == 32 or DK > 32 and block_DK >= 64, "When DK > 32, block_DK must be >= 64" + NK = math.ceil(DK / block_DK) + dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dw = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dg = torch.empty(NK, B, S, H, dtype=gate_dtype).cuda() + return dq, dk, dw, dg + + +# @register_cuda_postproc_callback +# def tilelang_callback_cuda_postproc(code, _): +# cuda_code = open("../debug/chunk_o_bwd3.log", "r").read() +# code = cuda_code +# return code + + +@tilelang.jit( + out_idx=[-4, -3, -2, -1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def tilelang_chunk_o_bwd_dqkwg( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_dw=True, + # kernel config + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + BS = S // block_S + NK = math.ceil(DK / block_DK) + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + h_shape = (B, BS, H, DK, DV) + G_shape = (B, S, H) + dO_shape = (B, S, H, DV) + dh_shape = (B, BS, H, DK, DV) + dv_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + + dq_shape = (B, S, H, DK) + dk_shape = (B, S, H, DK) + dw_shape = (B, S, H, DK) + dg_shape = (NK, B, S, H) + + @T.prim_func + def kernel( + # input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dh: T.Tensor(dh_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + # output + dq: T.Tensor(dq_shape, dtype=output_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dw: T.Tensor(dw_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): + bb, bh = bbh // H, bbh % H + + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dh_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + k_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + ds_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype) + dg_shared_1 = T.alloc_shared((block_S,), dtype=gate_dtype) + dg_shared_2 = T.alloc_shared((block_S,), dtype=gate_dtype) + dk_shared = T.alloc_shared((block_S, block_DK), dtype=accum_dtype) + + ds_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + ds_fragment_positive = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + ds_fragment_positive_transpose = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment_2 = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + q_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype) + k_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype) + + dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype) + dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_last_local = T.alloc_local((2,), dtype=gate_dtype) + dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype) + dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype) + dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype) + dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared") + G_last_local = T.alloc_local((1,), dtype=gate_dtype) + + T.use_swizzle(10) + + T.annotate_layout( + { + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + h_shared: tilelang.layout.make_swizzled_layout(h_shared), + dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + } + ) + + T.clear(dg_last_local) + T.clear(G_last_local) + T.clear(G_shared) + T.clear(q_fragment) + T.clear(k_fragment) + T.clear(dg_last_fragment) + + T.clear(ds_fragment) + T.clear(dq_fragment) + T.clear(dk_fragment) + T.clear(dw_fragment) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared) + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) + + if use_g: + T.clear(dg_last_fragment_scalar) + # FIXME: The reduce operation of a whole buffer to a scalar is not supported and will cause incorrect result + # for i_kv in T.Parallel(block_DK * block_DV): + # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] + for i_kv in T.Parallel(block_DK * block_DV): + dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] + T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) + dg_last_local[0] += dg_last_fragment_scalar[0] + + T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True) + T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True) + T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) + + if use_dw: + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared) + T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) + + if use_dw: + for i_s, i_k in T.Parallel(block_S, block_DK): + dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared) + T.copy(q_shared, q_fragment) + T.copy(k_shared, k_fragment) + + if use_g: + T.clear(dg_fragment) + T.clear(dg_fragment_2) + for i_s, i_k in T.Parallel(block_S, block_DK): + G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh] + G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh] + # Use gmem directly instead of local register + dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) + + for i_s, i_k in T.Parallel(block_S, block_DK): + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale + T.clear(dg_fragment_reduce_tmp) + for i_s, i_k in T.Parallel(block_S, block_DK): + dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) + + for i_s, i_k in T.Parallel(block_S, block_DK): + with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): + with T.Then(): + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh]) + with T.Else(): + dk_fragment[i_s, i_k] = 0 + T.clear(dg_fragment_reduce_tmp) + for i_s, i_k in T.Parallel(block_S, block_DK): + dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k]) + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False) + + # FIXME: The reduce operation of a whole buffer to a scalar is not supported and will cause incorrect result + T.copy(dk_fragment, dk_shared) + T.clear(dg_last_fragment_scalar_2) + for i_sk in T.Parallel(block_S * block_DK): + i_s, i_k = i_sk // block_DK, i_sk % block_DK + dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k] + T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False) + dg_last_local[1] = dg_last_fragment_scalar_2[0] + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): + with T.Then(): + ds_fragment[i_s1, i_s2] = ( + ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale + ) + with T.Else(): + ds_fragment[i_s1, i_s2] = 0 + + T.clear(ds_fragment_positive) + T.clear(ds_fragment_positive_transpose) + T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2] + + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) + T.copy(dg_fragment, dg_shared_1) + + # We should transpose the matrix because the reduce_sum statement can only reduce along the last dimension + for i_s1, i_s2 in T.Parallel(block_S, block_S): + ds_fragment_positive_transpose[i_s2, i_s1] = ds_fragment_positive[i_s1, i_s2] + + # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass + T.reduce_sum(ds_fragment_positive_transpose, dg_fragment_2, dim=1, clear=False) + T.copy(dg_fragment_2, dg_shared_2) + + for i_s in T.Parallel(block_S): + dg_fragment_final[i_s] = dg_shared_1[i_s] - dg_shared_2[i_s] + + T.copy(ds_fragment, ds_shared) + T.gemm(ds_shared, k_shared, dq_fragment) + T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True) + + for i_s in T.Parallel(block_S): + with T.If(i_s >= block_S - 1): # noqa: SIM117 + with T.Then(): + dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1] + + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + for i_s in T.Parallel(block_S): + dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] + + else: + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 < i_s2): # noqa: SIM117 + with T.Then(): + ds_fragment[i_s1, i_s2] = 0 + T.clear(dk_fragment_2) + T.copy(ds_fragment, ds_shared) + T.gemm(ds_shared, k_shared, dq_fragment) + T.gemm(ds_shared, q_shared, dk_fragment_2, transpose_A=True) + for i_s, i_k in T.Parallel(block_S, block_DK): + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + return kernel + + +def do_bench(fn, *args, warmup=10, rep=10, **kwargs): + """ + Do benchmark for a function. + """ + start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] + for _ in range(warmup): + fn(*args, **kwargs) + + torch.cuda.synchronize() + for i in range(rep): + start_event[i].record() + fn(*args, **kwargs) + end_event[i].record() + torch.cuda.synchronize() + + # Record clocks + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + + return times.mean().item() + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g=True, + use_dw=True, + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dq_ref, dk_ref, dw_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK + ) + + # ref + if use_g: + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + else: + dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale) + + # tilelang + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_g, + use_dw, + block_DK, + block_DV, + threads, + num_stages, + ) + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) + + if use_g: + dg_tilelang = dg_tilelang.sum(dim=0) + + # check + try: + assert_similar(dq_ref, dq_tilelang, 1e-5, "tilelang chunk o bwd dq") + print("tilelang chunk o bwd dq passed โˆš") + except Exception as e: + print("tilelang chunk o bwd dq failed โœ—") + print(e) + + try: + assert_similar(dk_ref, dk_tilelang, 1e-5, "tilelang chunk o bwd dk") + print("tilelang chunk o bwd dk passed โˆš") + except Exception as e: + print("tilelang chunk o bwd dk failed โœ—") + print(e) + + if use_g: + try: + assert_similar(dg_ref, dg_tilelang, 1e-5, "tilelang chunk o bwd dg") + print("tilelang chunk o bwd dg passed โˆš") + except Exception as e: + print("tilelang chunk o bwd dg failed โœ—") + print(e) + + if use_dw: + try: + assert_similar(dw_ref, dw_tilelang, 1e-5, "tilelang chunk o bwd dw") + print("tilelang chunk o bwd dw passed โˆš") + except Exception as e: + print("tilelang chunk o bwd dw failed โœ—") + print(e) + + +def main(): + DK = 128 + DV = 128 + run_test( + B=1, + S=32768, + H=8, + DK=DK, + DV=DV, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + scale=DK**-0.5, + # scale=1, + use_g=True, + use_dw=True, + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ef17e3f4b50479958692460d4fc785b0e82a18 --- /dev/null +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -0,0 +1,197 @@ +# Reference: fla/ops/common/chunk_scaled_dot_kkt.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + +torch.set_printoptions(profile="full") +torch.random.manual_seed(0) + + +def prepare_input( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=accum_dtype).cuda() + return K, Beta, G + + +def prepare_output( + B, + S, + H, + chunk_size, + dtype, +): + BS = chunk_size + A = torch.empty(B, S, H, BS, dtype=dtype).cuda() + return A + + +@tilelang.jit(out_idx=[-1]) +def tilelang_chunk_scaled_dot_kkt_fwd( + # task config + B, + S, + H, + DK, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + use_g=True, + # kernel config + block_S=64, + block_DK=64, + threads=256, + num_stages=0, +): + K_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + G_shape = (B, S, H) + assert chunk_size == block_S, "chunk_size must be equal to block_S" + BS = chunk_size + output_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=accum_dtype), + A: T.Tensor(output_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + # !! Pay attention to the scope of the shared memory: may cause misaligned address when shape is one dimension or the buffer is too small + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared") + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype) + Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype) + A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + + # Tensor used for gated: + G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") + G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + } + ) + + T.fill(A_fragment, 0) + T.disable_warp_group_reg_alloc() + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] + T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) + + if use_g: + for i_s in T.Parallel(block_S): + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + for i_s1, i_s2 in T.Parallel(block_S, block_S): + G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): + with T.Then(): + A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) + with T.Else(): + A_fragment[i_s1, i_s2] = 0 + else: + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 <= i_s2): # noqa: SIM117 + with T.Then(): + A_fragment[i_s1, i_s2] = 0 + + T.copy(A_fragment, A_shared) + T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + use_g, + block_DK, + threads, + num_stages, +): + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) + A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) + A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) + + # reference + if use_g: + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + else: + A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype)) + + # tilelang + block_S = chunk_size + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) + A_tilelang = kernel(K, Beta, G) + + try: + torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2) + print("tilelang chunk scaled dot kkt fwd passed โˆš") + except Exception as e: + print("tilelang chunk scaled dot kkt fwd failed โœ—") + print(e) + print("reference cuda kernel:") + print(kernel.get_kernel_source()) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + use_g=True, + block_DK=64, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0760b496458ae334d1baf54934c0526962a529a6 --- /dev/null +++ b/examples/gdn/example_cumsum.py @@ -0,0 +1,165 @@ +# Util functions for flash linear attention cumsum +# Reference: fla/ops/utils/cumsum.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.utils.cumsum import chunk_local_cumsum_scalar +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + + +@tilelang.jit( + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} +) +def tilelang_chunk_local_cumsum_scalar( + # task config + B, + S, + H, + chunk_size=64, + is_varlen=False, + head_first=False, + reverse=False, + input_dtype=T.float16, + output_dtype=T.float32, + # kernel config + block_S=64, + threads=256, + use_fragment=False, +): + G_shape = (B, H, S) if head_first else (B, S, H) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + assert chunk_size == block_S, "chunk_size must be equal to block_S" + + @T.prim_func + def kernel( + G: T.Tensor(G_shape, dtype=input_dtype), + G_new: T.Tensor(G_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") + if head_first: + T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared) + else: + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared) + if use_fragment: + G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") + T.copy(G_shared, G_fragment) + T.cumsum(G_fragment, dim=1, reverse=reverse) + if head_first: + T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) + else: + T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) + else: + T.cumsum(G_shared, dim=1, reverse=reverse) + if head_first: + T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S]) + else: + T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh]) + + return kernel + + +def prepare_cumsum_input( + B, + S, + H, + dtype, +): + G = torch.randn(B, S, H, dtype=dtype).cuda() + return G + + +def prepare_cumsum_output( + B, + S, + H, + dtype, +): + G_new = torch.empty(B, S, H, dtype=dtype).cuda() + return G_new + + +def run_test( + B, + S, + H, + chunk_size, + reverse, + head_first, + input_dtype, + output_dtype, + threads, + use_fragment, +): + G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype)) + G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype)) + G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype)) + + # reference cumsum + G_new_ref = chunk_local_cumsum_scalar( + g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype) + ) + + # tilelang cumsum + block_S = chunk_size + kernel = tilelang_chunk_local_cumsum_scalar( + B=B, + S=S, + H=H, + chunk_size=chunk_size, + reverse=reverse, + head_first=head_first, + input_dtype=input_dtype, + output_dtype=output_dtype, + block_S=block_S, + threads=threads, + use_fragment=use_fragment, + ) + torch.cuda.profiler.start() + G_new_tilelang = kernel(G) + torch.cuda.profiler.stop() + try: + torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2) + print("tilelang cumsum passed โˆš") + except Exception as e: + print("tilelang cumsum failed โœ—") + print(e) + print("G:") + print(G.view(-1)) + print("G_new_tilelang:") + print(G_new_tilelang.view(-1)) + print("G_new_ref:") + print(G_new_ref.view(-1)) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + chunk_size=64, + reverse=True, + head_first=False, + input_dtype=T.float32, + output_dtype=T.float32, + threads=256, + use_fragment=False, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_wy_fast.py b/examples/gdn/example_wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac086ca76916afddbb671377b97809546528e40 --- /dev/null +++ b/examples/gdn/example_wy_fast.py @@ -0,0 +1,220 @@ +# Reference: fla/ops/gated_delta_rule/wy_fast.py + +import tilelang +import tilelang.language as T +import sys # noqa: F401 + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch + +torch.random.manual_seed(1) + + +def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32): + BS = chunk_size + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda() + return K, V, Beta, G, A + + +def prepare_output( + B, + S, + H, + DK, + DV, + output_dtype, +): + W = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + U = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return W, U + + +@tilelang.jit(out_idx=[-2, -1]) +def tilelang_recompute_w_u_fwd( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + chunk_size, + # kernel config + block_S=64, + block_DK=64, + block_DV=64, + threads=256, + num_stages=0, +): + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + Beta_shape = (B, S, H) + assert chunk_size == block_S, "chunk_size must be equal to block_S" + BS = chunk_size + G_shape = (B, S, H) + A_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=output_dtype), + W: T.Tensor(K_shape, dtype=output_dtype), + U: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared") + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") + A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype) + W_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, block_DK), dtype=output_dtype) + U_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + + T.annotate_layout( + { + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + V_shared: tilelang.layout.make_swizzled_layout(V_shared), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + W_shared: tilelang.layout.make_swizzled_layout(W_shared), + U_shared: tilelang.layout.make_swizzled_layout(U_shared), + W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), + U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), + } + ) + + T.disable_warp_group_reg_alloc() + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) + + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + for i_s, i_v2 in T.Parallel(block_S, block_DV): + U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] + T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) + # First copy to smem, then copy to gmem to reduce U2RU instructions + T.copy(U_fragment, U_shared) + T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s] + T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) + # First copy to smem, then copy to gmem to reduce U2RU instructions + T.copy(W_fragment, W_shared) + T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + block_DK, + block_DV, + threads, + num_stages, +): + K, V, Beta, G, A = prepare_input( + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) + W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) + W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) + + # reference + W_ref, U_ref = recompute_w_u_fwd(K, V, Beta, G, A, None) + + # tilelang + block_S = chunk_size + kernel = tilelang_recompute_w_u_fwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + chunk_size, + block_S=block_S, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, + ) + print(kernel.get_kernel_source()) + W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) + + try: + torch.testing.assert_close(W_tilelang, W_ref, rtol=1e-2, atol=1e-2) + print("tilelang recompute w passed โˆš") + except Exception as e: + print("tilelang recompute w failed โœ—") + print(e) + try: + torch.testing.assert_close(U_tilelang, U_ref, rtol=1e-2, atol=1e-2) + print("tilelang recompute u passed โˆš") + except Exception as e: + print("tilelang recompute u failed โœ—") + print(e) + + +def main(): + run_test( + B=1, + S=32768, + H=32, + DK=128, + DV=128, + chunk_size=64, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + gate_dtype=T.float32, + accum_dtype=T.float32, + block_DK=64, + block_DV=32, + threads=128, + num_stages=3, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py new file mode 100644 index 0000000000000000000000000000000000000000..de8afc2b7770432db2535c05c69a852b1af802da --- /dev/null +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -0,0 +1,535 @@ +# Reference: fla/ops/gated_delta_rule/wy_fast.py + +import sys # noqa: F401 + +import tilelang +import tilelang.language as T + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id 00000000 +# sys.path.insert(0, "/home/tzj/flash-linear-attention") +try: + import fla + + print(fla.__file__) + from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr +except ImportError: + print("fla not found, using tilelang implementation") + fla = None + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(0) +torch.set_printoptions(profile="full") + + +def prepare_input_fake( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = chunk_size + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + Beta = torch.ones(B, S, H, dtype=input_dtype).cuda() + G = torch.ones(B, S, H, dtype=gate_dtype).cuda() + A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda() + dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + return K, V, Beta, G, A, dw, du + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BS = chunk_size + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + V = F.normalize(V, dim=-1, p=2) + Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + G = torch.randn(B, S, H, dtype=gate_dtype).cuda() + A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda() + dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return K, V, Beta, G, A, dw, du + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda() + dg = torch.empty(B, S, H, dtype=gate_dtype).cuda() + return dk, dv, dbeta, dg + + +@tilelang.jit( + out_idx=[-5, -4, -3, -2, -1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def tilelang_wy_fast_bwd( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, +): + block_S = chunk_size + BS = block_S + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + Beta_shape = (B, S, H) + G_shape = (B, S, H) + A_shape = (B, S, H, BS) + dw_shape = (B, S, H, DK) + du_shape = (B, S, H, DV) + + dk_shape = (B, S, H, DK) + dv_shape = (B, S, H, DV) + dbeta_shape = (B, S, H) + dg_shape = (B, S, H) + dA_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + # output + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared_beta_g = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + V_shared_beta = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype) + G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype) + dw_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + du_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + + dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment_beta_g = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_beta = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype) + dbeta_fragment_v = T.alloc_fragment((block_S,), dtype=accum_dtype) + dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dbeta_fragment_reduce_tmpv = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype) + dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype) + + T.use_swizzle(10) + + T.clear(dA_fragment) + T.clear(dk_fragment) + T.clear(dk_fragment_beta_g) + T.clear(dv_fragment) + T.clear(dv_fragment_beta) + T.clear(dbeta_fragment_k) + T.clear(dbeta_fragment_v) + T.clear(dg_fragment) + + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + G_shared_exp[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) + + # Update dk + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared) + T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) + T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s] + # for i_s, i_k2 in T.Parallel(block_S, block_DK): + # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] + T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) + + # for i_s, i_k2 in T.Parallel(block_S, block_DK): + # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dg_fragment_reduce_tmp[i_s, i_k2] = ( + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] + ) + T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) + + # correct dk + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + # Update dv + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + for i_s, i_v2 in T.Parallel(block_S, block_DV): + V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] + T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared) + T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) + T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) + for i_s, i_v2 in T.Parallel(block_S, block_DV): + dv_fragment[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * Beta_shared[i_s] + # for i_s, i_v2 in T.Parallel(block_S, block_DV): + # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] + for i_s, i_v2 in T.Parallel(block_S, block_DV): + dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] + T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) + + T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + # Temporary store dbeta, dg and dA + for i_s in T.Parallel(block_S): + dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] + dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] + # correct dA + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :]) + + return kernel + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) +def tilelang_wy_fast_bwd_split( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, +): + block_S = chunk_size + BS = block_S + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + Beta_shape = (B, S, H) + G_shape = (B, S, H) + A_shape = (B, S, H, BS) + dw_shape = (B, S, H, DK) + du_shape = (B, S, H, DV) + + dk_shape = (B, S, H, DK) + dv_shape = (B, S, H, DV) + dbeta_shape = (B, S, H) + dA_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + # input + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + G: T.Tensor(G_shape, dtype=gate_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dw: T.Tensor(dw_shape, dtype=input_dtype), + du: T.Tensor(du_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=output_dtype), + dv: T.Tensor(dv_shape, dtype=output_dtype), + dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), + dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), + dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + dA_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype) + dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dA_A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dA_A_fragment_1 = T.alloc_fragment((block_S,), dtype=accum_dtype) + dA_A_fragment_2 = T.alloc_fragment((block_S,), dtype=accum_dtype) + dk_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + dk_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dk_fragment_beta = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype) + dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype) + dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype) + G_shared = T.alloc_shared((block_S,), dtype=gate_dtype) + G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype) + + T.clear(dbeta_fragment_reduce_tmpk) + T.clear(dbeta_fragment_k) + T.clear(dA_A_fragment_1) + T.clear(dA_A_fragment_2) + + T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared) + for i_s in T.Parallel(block_S): + Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] + G_shared[i_s] = G[bb, bs * block_S + i_s, bh] + for i_s in T.Parallel(block_S): + G_shared_exp[i_s] = T.exp(G_shared[i_s]) + + # Load intermediate results + # for i_s in T.Parallel(block_S): + # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] + # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] + T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared) + # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + + # Update dA + T.copy(dA_shared, dA_fragment) + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 <= i_s2): # noqa: SIM117 + with T.Then(): + dA_fragment[i_s1, i_s2] = 0 + T.copy(dA_fragment, dA_shared) + T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True) + T.copy(dA_fragment, dA_shared) + T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(i_s1 <= i_s2): + with T.Then(): + dA_fragment[i_s1, i_s2] = 0 + with T.Else(): + dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2] + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): + with T.Then(): + dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) + with T.Else(): + dA_fragment[i_s1, i_s2] = 0 + T.copy(dA_fragment, dA_shared) + + # acceptable dA diff + # T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) + + # Update dk using previous dk + T.clear(A_fragment) + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared) + T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) + T.copy(dk_shared, dk_fragment) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] + T.gemm(K_shared_beta, K_shared, A_fragment, transpose_B=True) + T.gemm(dA_shared, K_shared, dk_fragment_beta, clear_accum=True) + # for i_s, i_k2 in T.Parallel(block_S, block_DK): + # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] + T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) + T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] + for i_s, i_k2 in T.Parallel(block_S, block_DK): + dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + # Update dg and dbeta + T.copy(A_fragment, A_shared) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + dA_A_fragment[i_s1, i_s2] = dA_fragment[i_s1, i_s2] * A_fragment[i_s1, i_s2] + # Note: Reduce operation now not supported in shared memory + # FIXME: reduce will cause incorrect result when dim != -1 + T.reduce_sum(dA_A_fragment, dA_A_fragment_1, dim=1) + T.reduce_sum(dA_A_fragment, dA_A_fragment_2, dim=0) + + for i_s1, i_s2 in T.Parallel(block_S, block_S): + dg_A_positive[bb, bs * block_S + i_s1, bh, i_s2] = dA_A_fragment[i_s1, i_s2] + dg_A_negative[bb, bs * block_S + i_s2, bh, i_s1] = dA_A_fragment[i_s1, i_s2] + + for i_s in T.Parallel(block_S): + dbeta_k[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK=64, + block_DV=64, + threads=128, + num_stages=0, +): + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + BS = chunk_size + dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() + dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() + dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + + # ref + dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None) + + # tilelang + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) + torch.cuda.synchronize() + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) + torch.cuda.synchronize() + + dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) + + from test_utils import assert_similar + + assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) + assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) + assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) + assert_similar(dg_ref, dg_tilelang, eps=1e-5, name="dg", raise_assert=False) + + +def main(): + DK = 128 + DV = 128 + run_test( + B=1, + S=32768, + H=8, + DK=DK, + DV=DV, + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + block_DK=32, + block_DV=32, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py new file mode 100644 index 0000000000000000000000000000000000000000..e749fa0874ee5dbddecbabe23ce1c72f74662647 --- /dev/null +++ b/examples/gdn/test_example_gdn_compilation.py @@ -0,0 +1,320 @@ +import torch +import tilelang.testing +from tilelang import language as T + +B = 1 +S = 1024 # small but for test only. +H = 32 +DK = 128 +DV = 128 +input_dtype = T.bfloat16 +output_dtype = T.bfloat16 +accum_dtype = T.float32 +gate_dtype = T.float32 +state_dtype = T.float32 +chunk_size = 64 +use_g = True +use_initial_state = True +store_final_state = True +use_final_state_gradient = True +save_new_value = True +block_DK = 64 +block_DV = 32 +threads = 128 +num_stages = 1 + + +def test_example_wy_fast_compilation(): + from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input + + K, V, Beta, G, A = prepare_input( + B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype) + ) + # tilelang + block_S = chunk_size + kernel = tilelang_recompute_w_u_fwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + gate_dtype, + accum_dtype, + chunk_size, + block_S=block_S, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, + ) + print(kernel.get_kernel_source()) + W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) + + +def test_example_wy_fast_bwd_split_compilation(): + from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output + + K, V, Beta, G, A, dw, du = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + BS = chunk_size + dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() + dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() + dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() + + # tilelang + kernel = tilelang_wy_fast_bwd( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du) + torch.cuda.synchronize() + kernel_split = tilelang_wy_fast_bwd_split( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + block_DK, + block_DV, + threads, + num_stages, + ) + kernel_split( + K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative + ) + torch.cuda.synchronize() + + dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang + dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) + + +def test_example_chunk_o_compilation(): + from example_chunk_o import tilelang_chunk_fwd_o, prepare_input + + Q, K, V, HIDDEN, G = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + scale = 1.0 / DK**0.5 + block_S = chunk_size + kernel = tilelang_chunk_fwd_o( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + scale, + use_g, + block_S, + block_DK, + block_DV, + threads, + num_stages, + ) + O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 + + +def test_example_chunk_o_bwd_compilation(): + from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input + + Q, K, V, h, G, dO, dh, dv, W = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_o_bwd_dqkwg( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + True, + block_DK, + block_DV, + threads, + num_stages, + ) + + dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841 + if use_g: + dg_tilelang = dg_tilelang.sum(dim=0) + + +def test_example_chunk_scaled_dot_kkt_compilation(): + from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input + + K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype)) + block_S = chunk_size + kernel = tilelang_chunk_scaled_dot_kkt_fwd( + B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages + ) + A_tilelang = kernel(K, Beta, G) # noqa: F841 + + +def test_example_cumsum_compilation(): + from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output + + G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) + G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) + block_S = chunk_size + kernel = tilelang_chunk_local_cumsum_scalar( + B=B, + S=S, + H=H, + chunk_size=chunk_size, + reverse=False, + head_first=False, + input_dtype=gate_dtype, + output_dtype=gate_dtype, + block_S=block_S, + threads=threads, + use_fragment=False, + ) + G_new_tilelang = kernel(G) # noqa: F841 + + +def test_example_chunk_delta_h_compilation(): + from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input + + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_g, + use_initial_state, + store_final_state, + save_new_value, + block_DK, + block_DV, + threads, + num_stages, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 + + +def test_example_chunk_delta_bwd_compilation(): + from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input + + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + 1.0, + use_g, + use_initial_state, + use_final_state_gradient, + block_DV, + threads, + num_stages, + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gdn/test_utils.py b/examples/gdn/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3588551ce39ad1bee4267b727979336d73341561 --- /dev/null +++ b/examples/gdn/test_utils.py @@ -0,0 +1,38 @@ +import torch + + +def print_red_warning(message): + print(f"\033[31mWARNING: {message}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): + x_mask = torch.isfinite(x) + y_mask = torch.isfinite(y) + if not torch.all(x_mask == y_mask): + print_red_warning(f"{name} Error: isfinite mask mismatch") + if raise_assert: + raise AssertionError + if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") + if raise_assert: + raise AssertionError + x = x.masked_fill(~x_mask, 0) + y = y.masked_fill(~y_mask, 0) + sim = calc_sim(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff}") + if raise_assert: + raise AssertionError + else: + print(f"{name} {data} passed") diff --git a/examples/gemm/README.md b/examples/gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9ab7fb6614654e4594484d2e72ba9a7703b2706b --- /dev/null +++ b/examples/gemm/README.md @@ -0,0 +1,452 @@ +# TileLang GEMM (Matrix Multiplication) Examples + +TileLang is a domain-specific language designed to simplify the process of writing GPU kernels. It provides high-level abstractions for memory allocation, scheduling, and tiling, which are critical for achieving maximum performance on modern hardware architectures like NVIDIA GPUs. This README demonstrates how to write and optimize a matrix multiplication (GEMM) kernel using TileLang. + +## Table of Contents + +- [Table of Contents](#table-of-contents) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Installation](#installation) +- [Simple GEMM Example](#simple-gemm-example) + - [Code Walkthrough](#code-walkthrough) + - [Compiling and Profiling](#compiling-and-profiling) +- [Advanced GEMM Features](#advanced-gemm-features) + - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling) + - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining) + - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality) +- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) +- [Verifying Correctness](#verifying-correctness) +- [Fine-grained MMA Computations](#fine-grained-mma-computations) + - [Example Workflow](#example-workflow) + - [Summary](#summary) +- [References](#references) + +--- + +## Getting Started + +### Prerequisites + +- **Python 3.8+** +- **NVIDIA GPU** with a recent CUDA toolkit installed +- **PyTorch** (optional, for easy correctness verification) +- **tilelang** +- **bitblas** (optional; used for swizzle layout utilities in the advanced examples) + +### Installation + +```bash +pip install tilelang bitblas +``` + +*(Adjust accordingly if you are installing from source or using a different environment.)* + +--- + +## Simple GEMM Example + +Below is a basic matrix multiplication (GEMM) example demonstrating how TileLang handles buffer allocation, tiling, and kernel dispatch. For simplicity, we'll multiply two 1024ร—1024 matrices using 128 threads/block. + +```python +import tilelang +from tilelang import Profiler +import tilelang.language as T + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Define a grid with enough blocks to cover Mร—N + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + + # Allocate shared memory for the current tile of A and B + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + + # Allocate a local (register) fragment for partial accumulations + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Initialize the local accumulation buffer to zero + T.clear(C_local) + + # Loop over the K dimension in block_K chunks, using a 3-stage pipeline + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy from global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + + # Perform a matrix multiply-accumulate on the tile + T.gemm(A_shared, B_shared, C_local) + + # Copy the accumulated result from local memory (C_local) to global memory (C) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main +``` + +### Code Walkthrough + +1. **Define the Kernel Launch Configuration:** + ```python + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + ``` + This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads. + +2. **Shared Memory Allocation:** + ```python + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + ``` + Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access. + +3. **Local Fragment Accumulation:** + ```python + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + ``` + Partial results are stored in registers (or local memory) to reduce writes to global memory. + +4. **Pipelined Loading and GEMM:** + ```python + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(...) + T.gemm(...) + ``` + Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation. + +5. **Copy Out the Results:** + ```python + T.copy(C_local, C[by * block_M, bx * block_N]) + ``` + Writes the final computed tile from registers/shared memory to global memory. + +### Compiling and Profiling + +```python +func = matmul(1024, 1024, 1024, 128, 128, 32) +print(func) # Prints an IR-like representation of the TileLang kernel + +artifact = tilelang.lower(func) + +profiler = Profiler(artifact.rt_mod, artifact.params, result_idx=[2]) + +import torch +a = torch.randn(1024, 1024).cuda().half() +b = torch.randn(1024, 1024).cuda().half() + +c = profiler(a, b) +ref_c = a @ b + +# Validate results +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +# Get CUDA Kernel Source +print(artifact.kernel_source) +``` + +--- + +## Advanced GEMM Features + +### Custom Memory Layout / Swizzling + +**Swizzling** rearranges data in shared memory or global memory to mitigate bank conflicts, improve cache utilization, and better match the GPUโ€™s warp execution pattern. TileLang provides helper functions like `make_swizzle_layout` to annotate how buffers should be laid out in memory. + +### Parallel Copy and Auto-Pipelining + +- **Parallel Copy** allows you to distribute the copy of a block tile across all threads in a block, speeding up the transfer from global memory to shared memory. +- **Auto-Pipelining** uses multiple stages to overlap copying with computation, reducing idle cycles. + +### Rasterization for L2 Cache Locality + +Enabling **swizzle (rasterization)** at the kernel level can improve data reuse and reduce cache thrashing in L2. This is especially important when matrices are large. + +--- + +## Enhanced GEMM Example with Annotations + +Below is a more advanced snippet that showcases how to apply memory layouts, enable swizzling, and parallelize the copy operations to maximize performance: + +```python +import tilelang.language as T +# `make_mma_swizzle_layout` is a python-defined layout function +# that helps align data for MMA (Matrix Multiply-Accumulate) operations. +from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout + +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + # Allocate shared and local fragments + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Annotate memory layout + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Enable swizzle-based rasterization for better L2 locality + T.use_swizzle(panel_size=10, enable=True) + + # Clear the local accumulation buffer + T.clear(C_local) + + # Pipelined iteration over K dimension + for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + T.copy(A[by * block_M, idx * block_K], A_shared) + + # Parallel copy tile of B + for ko, j in T.Parallel(block_K, block_N): + B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j] + + # Perform local GEMM on the shared-memory tiles + T.gemm(A_shared, B_shared, C_local) + + # Copy the result tile back + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main +``` + +**Key Differences vs. Basic Example** +1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). +2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. +3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. + +--- + +## Verifying Correctness + +Once you compile and load your kernel into a runtime module (`rt_mod`), you can use tools like **PyTorch** to easily create random matrices on the GPU, run your TileLang kernel, and compare the results to a reference implementation (e.g., `torch.matmul` or `@` operator). + +```python +import torch + +# Suppose your compiled kernel is in rt_mod +profiler = Profiler(rt_mod, params, result_idx=[2]) + +A = torch.randn(1024, 1024).cuda().half() +B = torch.randn(1024, 1024).cuda().half() + +C_tilelang = profiler(A, B) +C_ref = A @ B + +torch.testing.assert_close(C_tilelang, C_ref, rtol=1e-2, atol=1e-2) +print("Results match!") +``` + +--- + +## Fine-grained MMA Computations + +For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. + +### Example Workflow + +```python +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] +``` + +1. **Set Up Tile Sizes and Thread Bindings** + Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...),` which ensure that the correct number of threads are active, and each thread is bound to a specific role (e.g., warp ID or lane ID). + +2. **Allocate Warp-local Fragments** + Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like: + ```python + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + ``` + Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warpโ€™s register tiles. + +3. **Load Data via `ldmatrix`** + Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well: + ```python + for ki in T.serial(0, (block_K // micro_size_k)): + # Warp-synchronous load for A + mma_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Warp-synchronous load for B + mma_emitter.ldmatrix_b(B_local, B_shared, ki) + ``` + Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers. + +4. **Perform the MMA Instruction** + After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially: + \[ + C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}} + \] + where each thread in the warp calculates a small portion of the final tile. For instance: + ```python + mma_emitter.mma(A_local, B_local, C_local) + ``` + Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel. + +5. **Store Results via `stmatrix`** + Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet: + ```python + mma_emitter.stmatrix(C_local, C_shared) + ``` + orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer. + +### Summary + +By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with manual thread bindings and memory allocations, you can replicate the control and performance of raw CUDA at the TileLang level. This approach is best suited for expert users who are comfortable with GPU warp-level programming, since it does require a deep understanding of hardware concurrency, memory hierarchies, and scheduling. However, the payoff can be significant for performance-critical paths, where every byte of bandwidth and every cycle of latency must be carefully orchestrated. + +--- + +## References + +- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. +- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. +- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul. diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..906a55d5d00fb516c00a10bacdf42c916a23bdb3 --- /dev/null +++ b/examples/gemm/example_gemm.py @@ -0,0 +1,61 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + print("c:") + print(c) + print("ref_c:") + print(ref_c) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # Get CUDA Source + print("CUDA Source:") + print(kernel.get_kernel_source()) + + # benchmark + profiler = kernel.get_profiler() + latency = profiler.do_bench(backend="cupti") + # latency = profiler.do_bench() + print(f"tilelang Latency: {latency}ms") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..ca322217341f5745193372751ad00a009d9f42c2 --- /dev/null +++ b/examples/gemm/example_gemm_autotune.py @@ -0,0 +1,239 @@ +import argparse +import itertools +import tilelang as tl +import tilelang.language as T +from tilelang.autotuner import AutoTuner +from tilelang.carver.template import MatmulTemplate +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +from tilelang.carver.roller.rasterization import NoRasterization +import torch + + +def ref_program(A, B): + """ + Compute the matrix product of A and the transpose of B. + + A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes. + """ + return A @ B.T + + +def get_configs(M, N, K, with_roller=False, topk=20): + """ + Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. + + When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended + configurations (device-specific TensorCore-friendly tilings). Each returned dict contains: + - block_M, block_N, block_K: tile sizes + - num_stages: pipeline staging (0 means no explicit staging) + - thread_num: total threads used for the block + - enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling) + + When with_roller is False this returns the Cartesian product of a fixed set of candidate + parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag. + + Parameters: + M, N, K (int): GEMM dimensions used to generate valid tile sizes. + with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints; + otherwise use a predefined candidate grid. + topk (int): Maximum number of roller hints to request when with_roller is True. + + Returns: + List[dict]: A list of configuration dictionaries as described above. + + Raises: + ValueError: if with_roller is True but the roller returns no hints. + """ + if with_roller: + arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + roller_hints = carve_template.recommend_hints(topk=topk) + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + # block_rows, block_cols represents warp partitioning + block_rows, block_cols = block_m // warp_m, block_n // warp_n + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0 + config["thread_num"] = block_rows * block_cols * 32 + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + else: + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def get_best_config(M, N, K, with_roller=False): + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, + ): + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( + out_idx=[-1], + target="auto", + ) + .set_profile_args( + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=False, + ) + ) + return autotuner.run(warmup=3, rep=20) + + +def get_heuristic_config() -> dict: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version in {80}: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} + elif sm_version in {90}: + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} + else: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} + + +@tl.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_autotune( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_autotune + + +def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): + use_autotune = True + if use_autotune: + result = get_best_config(M, N, K, with_roller) + print(result.config) + kernel = result.kernel + else: + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + + # benchmark + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench() + ref_latency = profiler.do_bench(ref_program) + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print(f"TileLang latency: {tilelang_latency}") + print(f"Ref latency: {ref_latency}") + print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}") + print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") + args = parser.parse_args() + main(args.m, args.n, args.k, args.use_autotune, args.with_roller) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..746e6ec011d8e44830431198dc03060ba4e5af91 --- /dev/null +++ b/examples/gemm/example_gemm_intrinsics.py @@ -0,0 +1,185 @@ +from tilelang import tvm as tvm +from tvm import DataType +import tilelang +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@tilelang.jit(out_idx=[2]) +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == T.int32: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + # chunk = 32 if in_dtype == T.float16 else 64 + chunk = 32 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def gemm_intrinsics( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Load B into fragment + mma_emitter.ldmatrix_b(B_local, B_shared, ki) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix(C_local, C_shared) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return gemm_intrinsics + + +def ref_program(A, B): + return A @ B.T + + +def main(M=4096, N=4096, K=4096): + in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32 + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + profiler = kernel.get_profiler() + + latency = profiler.do_bench(profiler.func, warmup=25) + + print(latency) + + # Ensure that the latency is not None + assert latency is not None + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main(M=4096, N=4096, K=4096) diff --git a/examples/gemm/example_gemm_intrinsics_dcu.py b/examples/gemm/example_gemm_intrinsics_dcu.py new file mode 100644 index 0000000000000000000000000000000000000000..e43bef16d7c3f64044a4a338c48313be5e25fb2e --- /dev/null +++ b/examples/gemm/example_gemm_intrinsics_dcu.py @@ -0,0 +1,189 @@ +from tilelang import tvm as tvm +from tvm import DataType +import tilelang +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mmac_macro_generator import ( + MatrixCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang import disable_cache + +disable_cache() + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@tilelang.jit(out_idx=[2]) +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + # chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 64 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMAC Wrapper to Auto Generate Code for MMAC + mmac_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def gemm_intrinsics( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mmac_emitter.ldmatrix_a(A_local, A_shared, ki) + + # Load B into fragment + mmac_emitter.ldmatrix_b(B_local, B_shared, ki) + + # Perform Matrix Multiplication + mmac_emitter.mmac(A_local, B_local, C_local) + + # Perform STMatrix + mmac_emitter.stmatrix(C_local, C_shared) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + j // micro_size_y, + i // micro_size_x, + i % micro_size_x, + j % micro_size_y, + ] + + return gemm_intrinsics + + +def ref_program(A, B): + return A @ B.T + + +def main(): + M, N, K = 16384, 16384, 16384 + in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + + profiler = kernel.get_profiler() + + latency = profiler.do_bench(profiler.func, warmup=25) + + print(latency) + print(kernel.get_kernel_source()) + # Ensure that the latency is not None + assert latency is not None + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..30f55de6a06d06eadabd9461ee0eba1169521764 --- /dev/null +++ b/examples/gemm/example_gemm_persistent.py @@ -0,0 +1,136 @@ +import tilelang +import tilelang.language as T +from tilelang.carver.arch import driver +import argparse + + +@tilelang.jit(out_idx=[-1]) +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + T.use_swizzle(10) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return main + + +@tilelang.jit(out_idx=[-1]) +def matmul_persistent( + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True +): + sm_num = driver.get_num_sms() + m_blocks = T.ceildiv(M, block_M) + n_blocks = T.ceildiv(N, block_N) + waves = T.ceildiv(m_blocks * n_blocks, sm_num) + group_size = 8 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(sm_num, threads=threads) as (block_id): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + for w in T.serial(waves): + tile_id = sm_num * w + block_id + bx = (tile_id // group_size) % m_blocks + by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size + + if bx * block_M < M and by * block_N < N: + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + @T.prim_func + def main_persistent_primitive( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(sm_num, threads=threads) as (block_id): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id): + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[bx * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return main_persistent_primitive if use_persistent_primitive else main + + +def ref_program(A, B): + return A @ B + + +def main(M=4096, N=4096, K=4096): + total_flops = 2 * M * N * K + + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 64 + threads = 256 + num_stages = 3 + + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("Persistent GEMM: All check passed.") + persistent_latency = persistent_profiler.do_bench(warmup=500) + print(f"Persistent GEMM Latency: {persistent_latency} ms") + print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") + + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("Non-Persistent GEMM: All check passed.") + non_persistent_latency = non_persistent_profiler.do_bench(warmup=500) + print(f"Non-Persistent GEMM Latency: {non_persistent_latency} ms") + print(f"Non-Persistent GEMM TFlops: {total_flops / non_persistent_latency * 1e-9} TFlops") + + print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + parser.add_argument("--K", type=int, default=8192, help="K dimension") + args = parser.parse_args() + M, N, K = args.M, args.N, args.K + main(M, N, K) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..8663c878d043c86a765be08ccf99ae87ce9d1bee --- /dev/null +++ b/examples/gemm/example_gemm_schedule.py @@ -0,0 +1,68 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_schedule( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 Cache Locality + T.use_swizzle(panel_size=10) + + # Clear the local buffer + T.clear(C_local) + + # Auto pipeline the computation + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Instead of using + # T.copy(B[k * block_K, bx * block_N], B_shared) + # we can also use Parallel to auto map the thread + # bindings and vectorize the copy operation. + for k, j in T.Parallel(block_K, block_N): + B_shared[k, j] = B[ko * block_K + k, bx * block_N + j] + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_schedule + + +def main(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + print("c:") + print(c) + print("ref_c:") + print(ref_c) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # Get CUDA Source + print("CUDA Source:") + print(kernel.get_kernel_source()) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm/test_example_gemm.py b/examples/gemm/test_example_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f69364be64fd95f1aa59b3e8a5f69e1a4c57dcf --- /dev/null +++ b/examples/gemm/test_example_gemm.py @@ -0,0 +1,26 @@ +import tilelang.testing +import example_gemm_autotune +import example_gemm_intrinsics +import example_gemm_schedule +import example_gemm + + +def test_example_gemm_autotune(): + # enable roller for fast tuning + example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True) + + +def test_example_gemm_intrinsics(): + example_gemm_intrinsics.main(M=1024, N=1024, K=1024) + + +def test_example_gemm_schedule(): + example_gemm_schedule.main() + + +def test_example_gemm(): + example_gemm.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemm_fp8/README.md b/examples/gemm_fp8/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9d7011a064fb97fb4c78e8f64267d1a4a7f1e35f --- /dev/null +++ b/examples/gemm_fp8/README.md @@ -0,0 +1 @@ +**Notes**: Now we only support fp8 with mma instructions instead of `T.gemm`, because the cutlass version of tilelang is too old, we should update the cutlass version in future. \ No newline at end of file diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py new file mode 100644 index 0000000000000000000000000000000000000000..93f8c4980c36409e38afb1439b244918eee31748 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -0,0 +1,116 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import torch_assert_close +import itertools + + +def ref_program(A, B): + return (A.half() @ B.half().T).to(dtype=torch.float32) + + +def manual_check_prog(C, C_ref): + torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1) + + +def supply_prog(args): + a_param, b_param = args + M, K = a_param.shape + N, _ = b_param.shape + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + return [a, b] + + +def get_configs(): + block_Ms = [32, 64, 128] + block_Ns = [32, 64, 128] + block_Ks = [64, 128] + num_stages = [0] + num_threads = [256] + k_packs = [1, 2] + gemm_types = ["ss", "rs"] + + valid_configs = [] + + for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + "num_threads": t, + "k_pack": kp, + "gemm_type": gemm_type, + } + ) + return valid_configs + + +@tilelang.autotune( + configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog +) +@tilelang.jit(out_idx=[-1]) +def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): + dtype = T.float8_e4m3fnuz + accum_dtype = T.float32 + + @T.prim_func + def gemm_fp8_rs( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + A_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_local) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + @T.prim_func + def gemm_fp8_ss( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + if gemm_type == "ss": + return gemm_fp8_ss + elif gemm_type == "rs": + return gemm_fp8_rs + else: + raise ValueError(f"Invalid gemm_type: {gemm_type}") + + +def test_gemm_fp8(M, N, K): + kernel = fp8_matmul(M, N, K) + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + c = kernel(a, b) + ref_c = ref_program(a, b) + torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("passed~") + + +if __name__ == "__main__": + test_gemm_fp8(512, 512, 512) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..1b440a7952a211b61f20e8c7c849d474a89cb0c1 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -0,0 +1,63 @@ +import torch +import tilelang +import tilelang.language as T + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): + @T.prim_func + def gemm_fp8( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_fp8 + + +def test_gemm_fp8(M, N, K, dtype): + torch_dtype = T.dtype(dtype).as_torch() + + kernel = matmul(M, N, K, 128, 128, 64, dtype) + + a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype) + + c = kernel(a, b) + + ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype) + + print(c) + print(ref_c) + + diff = calc_diff(c, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + +def main(): + test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5d84d72f16cd7af7e1304bfbacd28c80795035 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -0,0 +1,81 @@ +import torch +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): + # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128. + # if block_K < 128, promote after 128/block_K iters. + # if block_K > 128, promote after every iter. + update_interval = 128 // block_K if block_K < 128 else 1 + + @T.prim_func + def gemm_fp8_2xAcc( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + if (k + 1) % update_interval == 0: + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] + T.clear(C_local) + # Tail processing + if K_iters % update_interval != 0: + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_fp8_2xAcc + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def test_gemm_fp8(M, N, K, dtype): + torch_dtype = T.dtype(dtype).as_torch() + + kernel = matmul(M, N, K, 128, 128, 64, dtype) + + a = torch.rand(M, K, dtype=torch.float16, device="cuda") + a = (100 * (2 * a - 1)).to(dtype=torch_dtype) + b = torch.rand(N, K, dtype=torch.float16, device="cuda") + b = (100 * (2 * b - 1)).to(dtype=torch_dtype) + + c = kernel(a, b) + + ref_c = a.float() @ b.float().T + + diff = calc_diff(c, ref_c) + print(f"diff: {diff}") + assert diff < 1e-3 + + +def main(): + test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py new file mode 100644 index 0000000000000000000000000000000000000000..7ecde7c1b4e03cf7cc6f3f0284f4062d30e25cb5 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -0,0 +1,228 @@ +import torch +from tilelang import tvm as tvm +import tilelang.testing +from tvm import DataType +import tilelang.language as T +from tilelang.intrinsics import get_swizzle_layout +from tilelang.intrinsics.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) +from tilelang.transform import simplify_prim_func +from tilelang.utils.tensor import map_torch_type + +tilelang.testing.set_random_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@tilelang.jit(out_idx=[2]) +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + T.float16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + T.float16, + T.float32, + T.int32, + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + is_float8 = in_dtype in [ + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, + ] + if out_dtype == T.int32 or is_float8: + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 if in_dtype == T.float16 else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def gemm_fp8_intrinsic( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + } + ) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return gemm_fp8_intrinsic + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + src_code = kernel.get_kernel_source() + print(src_code) + # src_code is the generated cuda source + assert src_code is not None + + in_dtype = map_torch_type(in_dtype) + out_dtype = map_torch_type(out_dtype) + accum_dtype = map_torch_type(accum_dtype) + + if in_dtype in {torch.int8, torch.int32}: + A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() + B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + A = torch.randn(M, K).to(in_dtype).cuda() + B = torch.randn(N, K).to(in_dtype).cuda() + else: + A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 + B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 + + C = torch.zeros(M, N, device="cuda", dtype=accum_dtype) + + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + + C = profiler(A, B) + + latency = profiler.do_bench(warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def main(): + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e8b360805459bda83b73dd7334b1bd923a201 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,124 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True +num_stages = 2 +threads = 256 +for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: + for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_fp8/test_example_gemm_fp8.py b/examples/gemm_fp8/test_example_gemm_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..19a9ee00a7cac624526585fc6caba191037ed46d --- /dev/null +++ b/examples/gemm_fp8/test_example_gemm_fp8.py @@ -0,0 +1,20 @@ +import tilelang.testing +import example_tilelang_gemm_fp8_2xAcc +import example_tilelang_gemm_fp8_intrinsic +import example_tilelang_gemm_fp8 + + +def test_example_tilelang_gemm_fp8_2xAcc(): + example_tilelang_gemm_fp8_2xAcc.main() + + +def test_example_tilelang_gemm_fp8_intrinsic(): + example_tilelang_gemm_fp8_intrinsic.main() + + +def test_example_tilelang_gemm_fp8(): + example_tilelang_gemm_fp8.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28bb611bff167fdf7ba2291833edb91fd3d17beb --- /dev/null +++ b/examples/gemm_sm100/README.md @@ -0,0 +1,106 @@ +# TileLang SM100 Support (Preview) + +This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality. + +## Current Limitations (Manual Implementation Required) + +### 1. Manual TCGEN5.MMA Management +Users must manually handle TCGEN5MMA operations using: +- `T.alloc_tmem()` - Allocate Tensor Memory +- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting +- Manual synchronization with mbarrier + +### 2. Manual mbarrier Synchronization +TCGEN5MMA is asynchronous and requires explicit synchronization: +```python +mbar = T.alloc_barrier(1) # expect-arrive-count = 1 +T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0) +T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required +``` + +## Examples + +### TCGEN5MMA Example (`gemm_tcgen5mma.py`) +Demonstrates TCGEN5MMA operations with: +- Tensor Memory allocation +- Manual mbarrier synchronization +- TCGEN5MMA gemm operations + +### Traditional MMA Example (`gemm_mma.py`) +Shows standard MMA operations that work across architectures for comparison. + +## Code Example + +The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication: + +```python +import torch +import tilelang +import tilelang.language as T + +@T.prim_func +def main( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), +): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + # 1. Allocate memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory + B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory + C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory + mbar = T.alloc_barrier(1) # mbarrier synchronization primitive + + C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage + C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory + + # 2. Main computation loop + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + # Data loading: global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + + # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory + T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True, + mbar=mbar, wg_wait=-1, clear_accum=k==0) + + # Critical: wait for TCGEN5MMA completion + T.mbarrier_wait_parity(mbar, k%2) + + # 3. Output processing (only subset of threads) + T.copy(C_tmem, C_local) # Tensor Memory โ†’ registers + T.copy(C_local, C_shared) # registers โ†’ shared memory + + # 4. Write back to global memory + T.copy(C_shared, C[by * block_M, bx * block_N]) +``` + +### Compilation and Usage + +```python +# Parameter setup +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 128, 256, 128 + +# Compile kernel +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required +}) + +# Run test +a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) +b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) +c = jit_kernel(a, b) + +# Verify correctness +ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +# Performance benchmark +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS") +``` + diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py new file mode 100644 index 0000000000000000000000000000000000000000..226e33c01e474ec646ba7f7e7ac39c86a2497c6a --- /dev/null +++ b/examples/gemm_sm100/gemm_mma.py @@ -0,0 +1,94 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + # Copy tile of A + # This is a sugar syntax for parallelized copy + # for i, k in T.Parallel(M, block_K): + # A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[bx * block_N, ko * block_K], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +M = 128 # M = T.dynamic("m") if you want to use dynamic shape +N = 128 +K = 32 +block_M = 128 +block_N = 128 +block_K = 32 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +func = matmul(M, N, K, block_M, block_N, block_K) + +# 2. Compile the kernel into a torch function +# out_idx specifies the index of the output buffer in the argument list +# if out_idx is specified, the tensor will be created during runtime +# target currently can be "cuda" or "hip" or "cpu". +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +print(jit_kernel.get_kernel_source()) +# 3. Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(N, K, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +c = jit_kernel(a, b) + +print(c) +# Reference multiplication using PyTorch +ref_c = a @ b.T + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py new file mode 100644 index 0000000000000000000000000000000000000000..523a94fea6737bcd33f879ee8b49aaecaa3740af --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -0,0 +1,83 @@ +import torch +import tilelang +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 128, 256, 128 +trans_A, trans_B = False, True +in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float +num_stages = 2 +threads = 256 + +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) + +print(jit_kernel.get_kernel_source()) + +a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) +b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) +c = jit_kernel(a, b) +ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +print(f"Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS") diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py new file mode 100644 index 0000000000000000000000000000000000000000..7b93f2a779e2e56dd496712abf9fa16363feafa8 --- /dev/null +++ b/examples/gemm_sp/example_custom_compress.py @@ -0,0 +1,336 @@ +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import randn_semi_sparse +from tilelang.utils.tensor import torch_assert_close + +from triton.testing import do_bench + +import torch + +torch.manual_seed(42) + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, + "h20": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16_custom_compress( + M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout +): + e_factor, e_dtype = (16, T.int16) + + @T.prim_func + def gemm_sp_fp16_custom_compress( + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), T.float16) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16_custom_compress + + +def torch_compress(dense): + """ + A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. + """ + if dense.dim() != 2: + raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") + + m, k = dense.shape + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + return (sparse, meta) + + +def decode_metadata(meta: torch.Tensor) -> torch.Tensor: + assert meta.dtype is torch.int16 + groups_per_meta = 16 // 4 # 4 groups per uint16 + out = [] + for g in range(groups_per_meta): + group_bits = (meta >> (g * 4)) & 0xF + idx0 = group_bits & 0x3 + idx1 = (group_bits >> 2) & 0x3 + out.append(torch.stack([idx0, idx1], dim=-1)) + return torch.concat(out, dim=-1).view(meta.shape[0], -1) + + +@tilelang.jit( + out_idx=[1, 2], + pass_configs={ + tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, + }, +) +def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): + e_factor, e_dtype = ARCH_INFO["8.0"] + e_K = K // e_factor + elem, group = 2, 4 + + assert M % block_M == 0, "M must be divisible by block_M" + assert K % block_K == 0, "K must be divisible by block_K" + assert K % e_factor == 0, "K must be divisible by e_factor" + assert block_K % e_factor == 0, "block_K must be divisible by e_factor" + + @T.prim_func + def kernel( + A: T.Tensor((M, K), dtype), + A_sp: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, e_K), e_dtype), + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + if use_cutlass_layout: + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), + } + ) + T.clear(A_sp_shared) + T.clear(E_shared) + # TODO: alloc_var seems buggy here + non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) + non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) + T.copy(A[bx * block_M, by * block_K], A_shared) + for tm in T.Parallel(block_M): + for g_i in range(0, block_K // group): + a_k = g_i * group + non_zero_cnt[0] = 0 + for i in range(elem): + non_zero_elt_log_idx[i] = 0 + for i in range(group): + val = A_shared[tm, a_k + i] + if val != 0.0: + non_zero_elt_log_idx[non_zero_cnt[0]] = i + A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val + non_zero_cnt[0] += 1 + # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main + if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: + non_zero_elt_log_idx[0] = 0 + non_zero_elt_log_idx[1] = 3 + A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] + A_sp_shared[tm, a_k // 2] = 0.0 + elif non_zero_cnt[0] == 1: + A_sp_shared[tm, a_k // 2 + 1] = 0 + non_zero_elt_log_idx[1] = 3 + for i in T.serial(elem): + val = non_zero_elt_log_idx[i] + E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) + T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) + T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) + + return kernel + + +def main(): + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") + parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") + args = parser.parse_args() + kernel = matmul_sp_fp16_custom_compress( + args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype], use_cutlass_layout=args.use_cutlass_layout + ) + + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) + + if args.use_torch_compressor: + assert not args.use_cutlass_layout, "torch sparse must be used with naive layout" + a_sparse, e = torch_compress(a) + else: + a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a) + + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) + print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * args.m * args.n * args.k + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..10f524adbc791867f94a8af203ce787217a13183 --- /dev/null +++ b/examples/gemm_sp/example_gemm_sp.py @@ -0,0 +1,133 @@ +import argparse + +import tilelang +import tilelang.language as T + +from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import compress, randn_semi_sparse +from tilelang.contrib import nvcc +from triton.testing import do_bench + +import torch + +arch = nvcc.get_target_compute_version() + +DEFAULT_CONFIG = { # take best config from autotune script + "4090": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 64, + "num_stages": 1, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 256, + "block_N": 128, + "block_K": 64, + "num_stages": 2, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, + "h20": { + T.float: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + T.float16: { + "block_M": 128, + "block_N": 64, + "block_K": 128, + "num_stages": 3, + "thread_num": 128, + "policy": T.GemmWarpPolicy.Square, + "enable_rasterization": True, + }, + }, +} + +ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): + e_factor, e_dtype = ARCH_INFO[arch] + + @T.prim_func + def gemm_sp_fp16( + A_sparse: T.Tensor((M, K // 2), T.float16), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), T.float16), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + B_shared = T.alloc_shared((block_K, block_N), T.float16) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + T.disable_warp_group_reg_alloc() + T.use_swizzle(panel_size=10, enable=enable_rasterization) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch), + } + ) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_sp_fp16 + + +def main(): + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") + args = parser.parse_args() + kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) + + a = randn_semi_sparse(args.m, args.k, device="cuda", dtype=torch.half) + b = torch.randn(args.k, args.n, device="cuda", dtype=torch.half) + + a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]["block_K"], arch=arch) + c = kernel(a_sparse, e, b) + + ref_c = a @ b + + assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" + torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) + print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}") + + latency = do_bench(lambda: kernel(a_sparse, e, b)) + ref_latency = do_bench(lambda: a @ b) + + total_flops = 2 * args.m * args.n * args.k + tflops = total_flops / latency / 1e9 + ref_tflops = total_flops / ref_latency / 1e9 + print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") + print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..fe26df14497e392bd0cdc03b9a3352d4fbd2f24b --- /dev/null +++ b/examples/gemm_sp/test_example_gemm_sp.py @@ -0,0 +1,16 @@ +import tilelang.testing + +import example_custom_compress +import example_gemm_sp + + +def test_example_custom_compress(): + example_custom_compress.main() + + +def test_example_gemm_sp(): + example_gemm_sp.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..62073c5bddfa93959ad21b98c7b7cda3ffcb1e76 --- /dev/null +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -0,0 +1,60 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): + splitK = K // split_k + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0): + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j]) + + return main + + +def main(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + kernel(a, b, c) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py new file mode 100644 index 0000000000000000000000000000000000000000..83e83b5d2a7fff6eddbed10ddf63a19f8d18d425 --- /dev/null +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -0,0 +1,59 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): + splitK = K // split_k + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0): + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + + T.atomic_add(C[by * block_M, bx * block_N], C_shared) + + return main + + +def main(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + + import torch + + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + kernel(a, b, c) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_splitk/test_example_gemm_splitk.py b/examples/gemm_splitk/test_example_gemm_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..055b09162767d4a208bdec0d7b5ca8ccefec772c --- /dev/null +++ b/examples/gemm_splitk/test_example_gemm_splitk.py @@ -0,0 +1,16 @@ +import tilelang.testing + +import example_tilelang_gemm_splitk +import example_tilelang_gemm_splitk_vectorize_atomicadd + + +def test_example_tilelang_gemm_splitk(): + example_tilelang_gemm_splitk.main() + + +def test_example_tilelang_gemm_splitk_vectorize_atomicadd(): + example_tilelang_gemm_splitk_vectorize_atomicadd.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec1541ea4aac095dc34b69ea55bdd73d66f4db7 --- /dev/null +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -0,0 +1,203 @@ +import torch +import torch.backends +import tilelang +from tilelang import language as T +import math + + +def cdiv(a, b): + return math.ceil(a / b) + + +# disable tf32 +torch.backends.cuda.matmul.allow_tf32 = False + +m = 256 +n = 1024 +k = 512 + +total_sm = 108 + +torch.random.manual_seed(0) +# uniform distribution from -1 to 1 +A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1 +B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1 + +streamk_programs = total_sm +BLOCK_SIZE_M = 16 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 32 +two_tiles = False +M, K = A.shape +N, K = B.shape +# accumulator types +# compute grid (work to do per SM on the first wave) +num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M) +num_block_n = tilelang.cdiv(N, BLOCK_SIZE_N) +iters_per_tile = tilelang.cdiv(K, BLOCK_SIZE_K) +total_tiles = num_block_m * num_block_n + +# Two-tile SK + DP +streamk_tiles = total_tiles % streamk_programs +if total_tiles - streamk_tiles > streamk_programs: # (total_tiles // total_programs > 1) + streamk_tiles += streamk_programs + +blocking_tiles = total_tiles - streamk_tiles +streamk_iters = streamk_tiles * iters_per_tile + +streamk_full_tiles = streamk_iters // streamk_programs +streamk_partial_tiles = streamk_iters % streamk_programs + +print(f"{total_tiles=} ") +print(f"{iters_per_tile=} ") + +sm_patition_factor = max(blocking_tiles // total_sm, 1) + + +@tilelang.jit +def tl_matmul_streamk( + M, + N, + K, + streamk_tiles, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, +): + assert not trans_A + A_shape = (M, K) if not trans_A else (K, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + @T.macro + def compute_first_wave( + pid: T.int32, + A_buf: T.Tensor, + A_buf_shared: T.SharedBuffer, + B_buf: T.Tensor, + B_buf_shared: T.SharedBuffer, + C: T.Tensor, + C_local: T.LocalBuffer, + ): + start_iter = T.alloc_fragment((1,), T.int32, "local") + end_iter = T.alloc_fragment((1,), T.int32, "local") + + start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) + last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) + + while start_iter[0] < last_iter: + end_iter[0] = T.min( + start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), + last_iter, + ) + + tile_id = start_iter[0] // iters_per_tile + remain_iters = start_iter[0] % iters_per_tile + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + + T.clear(C_local) + for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): + T.copy( + A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], + A_buf_shared, + ) + T.copy( + B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], + B_buf_shared, + ) + T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B) + + # last iteration of the tile always happens before its start on another SM + if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) + + start_iter[0] = end_iter[0] + + @T.macro + def compute_full_tiles( + pid: T.int32, + A_buf: T.Tensor, + A_shared: T.SharedBuffer, + B_buf: T.Tensor, + B_shared: T.SharedBuffer, + C: T.Tensor, + C_local: T.LocalBuffer, + ): + for p in T.serial(sm_patition_factor): + tile_id = pid + streamk_tiles + p * total_sm + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A_buf[pid_m * block_M, k * block_K], A_shared) + T.copy(B_buf[pid_n * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + + @T.prim_func + def main( + A: T.Tensor(A_shape, dtypeAB), + B: T.Tensor(B_shape, dtypeAB), + C: T.Tensor((M, N), dtypeC), + ): + with T.Kernel(streamk_programs, threads=threads) as pid: + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) + + if sm_patition_factor > 0: + compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) + + return main + + +def main(): + kernel = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + T.float16, + T.float16, + T.float32, + 2, + 64, + ) + + print(kernel.get_kernel_source()) + + b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + + kernel(A, B, b_c) + + C = torch.matmul(A, B.T) + + print(b_c) + print(C) + torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py b/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py new file mode 100644 index 0000000000000000000000000000000000000000..a26ba74aede947a589923f7d1a57de3a14435de2 --- /dev/null +++ b/examples/gemm_streamk/test_example_tilelang_gemm_splitk.py @@ -0,0 +1,14 @@ +import tilelang.testing + +from example_tilelang_gemm_streamk import main + + +# not fully supported on sm90 +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_le(8, 9) +def test_example_tilelang_gemm_streamk(): + main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd0e4dd9f88cc4d74500e9b773cfcdb38999704 --- /dev/null +++ b/examples/gemv/example_gemv.py @@ -0,0 +1,368 @@ +import argparse +import itertools +import tilelang as tl +import tilelang.language as T +from tvm import DataType +from tilelang.autotuner import autotune +from tilelang import jit + + +def ref_program(A, B): + return A @ B.T + + +@tl.jit(out_idx=[-1]) +def naive_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn: + tn = T.get_thread_binding(0) # tn = threadIdx.x + A_shared = T.alloc_shared((BLOCK_K,), dtype) + B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype) + C_reg = T.alloc_local((1,), accum_dtype) + T.clear(C_reg) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for tk in T.serial(BLOCK_K): + A_shared[tk] = A[bk * BLOCK_K + tk] + B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + for tk in T.serial(BLOCK_K): + C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn, tk].astype(accum_dtype) + C[bn * BLOCK_N + tn] = C_reg[0] + + return main + + +@tl.jit(out_idx=[-1]) +def naive_splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((1,), dtype) + B_local = T.alloc_local((1,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + A_local[0] = A[bk * BLOCK_K + tk] + B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk] + C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main + + +@tl.jit(out_idx=[-1]) +def splitk_gemv( + N: int, + K: int, + BLOCK_N: int, + BLOCK_K: int, + reduce_threads: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + TILE_K = T.ceildiv(BLOCK_K, reduce_threads) + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.serial(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main + + +@tl.jit(out_idx=[-1]) +def splitk_gemv_vectorized( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_shared = T.alloc_shared((BLOCK_N,), accum_dtype) + C_accum = T.alloc_local((1,), accum_dtype) + if tk == 0: + C_shared[tn] = 0 + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + T.atomic_add(C_shared[tn], C_accum[0]) + C[bn * BLOCK_N + tn] = C_shared[tn] + + return main + + +@tl.jit(out_idx=[-1]) +def splitk_gemv_vectorized_tvm( + N: int, + K: int, + BLOCK_N: int, + reduce_threads: int, + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, +): + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + ) + ) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main + + +def get_block_template_configs(): + iter_params = dict( + block_M=[2, 4, 8, 32, 64, 128], block_N=[2, 4, 8, 32, 64, 128], num_stages=[0, 1, 2, 3, 4], threads=[32, 64, 128, 256] + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@tl.autotune( + configs=get_block_template_configs(), + warmup=3, + rep=20, +) +@tl.jit( + pass_configs={ + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + out_idx=[2], +) +def gemv_alloc_reducer( + M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float +): + @T.prim_func + def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore + with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: + o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all") + T.clear(o_reducer) + for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages): + a_smem = T.alloc_shared((block_M, block_N), dtype) + T.copy(a[i0_m * block_M, i0_n * block_N], a_smem) + a_frag = T.alloc_fragment((block_M, block_N), dtype) + T.copy(a_smem, a_frag) + x_frag = T.alloc_fragment(block_N, dtype) + T.copy(x[i0_n * block_N], x_frag) + for i1_m, i1_n in T.Parallel(block_M, block_N): + o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n] + T.finalize_reducer(o_reducer) + T.copy(o_reducer, o[i0_m * block_M]) + + return main + + +def get_thread_template_configs(): + iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_thread_template_configs(), + warmup=3, + rep=20, +) +@jit( + out_idx=[-1], + target="auto", +) +def get_autotuned_kernel( + N, + K, + BLOCK_N=None, + reduce_threads=None, +): + dtype = T.float16 + accum_dtype = T.float32 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits + BLOCK_K = reduce_threads * TILE_K + + @T.prim_func + def main( + A: T.Tensor((K,), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn: + tn = T.get_thread_binding(0) + tk = T.get_thread_binding(1) + A_local = T.alloc_local((TILE_K,), dtype) + B_local = T.alloc_local((TILE_K,), dtype) + C_accum = T.alloc_local((1,), accum_dtype) + + T.clear(C_accum) + for bk in T.serial(T.ceildiv(K, BLOCK_K)): + for k in T.vectorized(TILE_K): + A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k] + B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k] + for k in T.serial(TILE_K): + C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) + C_reduced = T.alloc_local((1,), accum_dtype) + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ): + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_accum[0], + True, + C_reduced[0], + tk, + dtype="handle", + ) + ) + + C[bn * BLOCK_N + tn] = C_reduced[0] + + return main + + +def check_correctness_and_bench(kernel, N, K, do_bench=True): + profiler = kernel.get_profiler() + profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) + if do_bench: + latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=50) + print(f"Torch Latency: {latency} ms") + latency = profiler.do_bench(kernel, warmup=50) + print(f"TileLang Latency: {latency} ms\n") + + +def main(do_bench: bool = True): + parser = argparse.ArgumentParser(description="GEMV Example") + parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") + args, _ = parser.parse_known_args() + N, K = args.n, args.k + check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K, do_bench=do_bench) + check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K, do_bench=do_bench) + check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K, do_bench=do_bench) + + print("Test passed!") + + if do_bench: + best_result = get_autotuned_kernel(N, K) + best_config = best_result.config + kernel = splitk_gemv_vectorized_tvm(N, K, **best_config) + profiler = kernel.get_profiler() + latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500) + print(f"Torch Latency: {latency} ms") + tilelang_thread_latency = profiler.do_bench(kernel, warmup=500) + print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n") + kernel = gemv_alloc_reducer(N, K) + profiler = kernel.get_profiler() + tilelang_tile_latency = profiler.do_bench(kernel, warmup=500) + print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n") + + +if __name__ == "__main__": + main() diff --git a/examples/gemv/test_example_gemv.py b/examples/gemv/test_example_gemv.py new file mode 100644 index 0000000000000000000000000000000000000000..323337a7a6a0f21f79ed4455d8243e3561f3847a --- /dev/null +++ b/examples/gemv/test_example_gemv.py @@ -0,0 +1,9 @@ +import example_gemv + + +def test_example_gemv(): + example_gemv.main(do_bench=False) + + +if __name__ == "__main__": + test_example_gemv() diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..bb57c60731a4b495028612546def8b24324940a8 --- /dev/null +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -0,0 +1,239 @@ +import torch +import math +import argparse +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + """ + args: + a (torch.Tensor): Input tensor of shape (M, K). + b (torch.Tensor): Input tensor of shape (G, K, N). + """ + accum_dtype = T.float32 + + @T.prim_func + def kernel( + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore + ): + with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared([block_M, block_K], dtype) + B_shared = T.alloc_shared([block_K, block_N], dtype) + C_local = T.alloc_fragment([block_M, block_N], accum_dtype) + cur_batch_idx = T.alloc_local([1], T.int32) + cur_batch_size = T.alloc_local([1], T.int32) + + m_start_padded = bx * block_M + + for i in range(batch_count): + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + + cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + for i, j in T.Parallel(block_M, block_N): + with T.If(i < actual_rows), T.Then(): + C[m_start + i, by * block_N + j] = C_local[i, j] + + return kernel + + +class _GroupedGEMM(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b, batch_sizes): + block_M = 64 + block_N = 64 + block_K = 64 + padding_M = block_M + num_stages = 2 + threads = 128 + batch_sum = a.shape[0] + batch_count = b.shape[0] + K = a.shape[1] + N = b.shape[2] + + assert a.shape[1] == b.shape[1] + assert batch_sizes.shape[0] == batch_count + assert batch_sizes.sum() == batch_sum + + batch_offsets_list = [0] + batch_padded_offsets_list = [0] + for i in range(batch_count - 1): + batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i]) + for i in range(batch_count - 1): + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes[i] + 1) / padding_M) * padding_M) + batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=a.device, dtype=torch.int32) + + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages, threads) + + o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) + ctx.save_for_backward(a, b, batch_sizes, batch_offsets) + ctx.batch_sum = batch_sum + ctx.batch_count = batch_count + ctx.K = K + return o + + @staticmethod + def backward(ctx, grad_output): + block_M = 64 + block_N = 64 + block_K = 64 + num_stages = 2 + threads = 128 + + M = ctx.K + N = grad_output.shape[1] + + A, B, batch_sizes, batch_offsets = ctx.saved_tensors + + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, num_stages, threads) + + dB = kernel(A, grad_output, batch_sizes, batch_offsets) + return None, dB, None + + +def ref_program(a, b, batch_sizes): + assert a.shape[0] == sum(batch_sizes) + assert b.shape[0] == len(batch_sizes) + + output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) + + start = 0 + a_list = [] + b_list = [] + for i, size in enumerate(batch_sizes): + end = start + size + part_a = a[start:end] + part_b = b[i] + output[start:end] = torch.mm(part_a, part_b) + + a_list.append(part_a) + b_list.append(part_b) + start = end + + return output + + +def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): + batch_sum = sum(batch_sizes_list) + batch_count = len(batch_sizes_list) + batch_offsets_list = [0] + batch_padded_offsets_list = [0] + for i in range(batch_count - 1): + batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) + for i in range(batch_count - 1): + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i] + 1) / padding_M) * padding_M) + A = torch.randn(batch_sum, K, device=device, dtype=dtype) + B = torch.randn(batch_count, K, M, device=device, dtype=dtype) + C = torch.empty(batch_sum, M, device=device, dtype=dtype) + batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32) + batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32) + # print(batch_sizes_tensor) + # print(batch_offsets_tensor) + # print(batch_padded_offsets_tensor) + return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets + + +@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + """ + args: + a (torch.Tensor): Input tensor of shape (M, K). + b (torch.Tensor): Input tensor of shape (G, K, N). + """ + accum_dtype = T.float32 + + @T.prim_func + def kernel( + A: T.Tensor([batch_sum, M], dtype), # type: ignore + B: T.Tensor([batch_sum, N], dtype), # type: ignore + C: T.Tensor([batch_count, M, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + ): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): + A_shared = T.alloc_shared([block_K, block_M], dtype) + B_shared = T.alloc_shared([block_K, block_N], dtype) + C_local = T.alloc_fragment([block_M, block_N], accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages): + for i, j in T.Parallel(block_K, block_M): + A_shared[i, j] = T.if_then_else(i < batch_sizes[bz], A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0) + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = T.if_then_else(i < batch_sizes[bz], B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0) + T.gemm(A_shared, B_shared, C_local, transpose_A=True) + + T.copy(C_local, C[bz, bx * block_M, by * block_N]) + + return kernel + + +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): + padding_M = block_M + device = torch.device("cuda") + dtype = torch.float16 + + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, False, padding_M, device, dtype) + + A.requires_grad_(False) + B.requires_grad_(True) + O_ref = ref_program(A, B, batch_sizes) + dO = torch.randn_like(O_ref) + + O_ref.backward(dO, retain_graph=True) + dB_ref, B.grad = B.grad.clone(), None + + GroupedGEMM = _GroupedGEMM.apply + O = GroupedGEMM(A, B, batch_sizes) + O.backward(dO, retain_graph=True) + dB, B.grad = B.grad.clone(), None + + if torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2): + print("โœ… Tilelang and Torch match") + else: + print("โŒ Tilelang and Torch mismatch") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") + args = parser.parse_args() + + batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] + K, M, trans_b = args.K, args.M, args.trans_b + + block_M = 64 + block_N = 128 + block_K = 64 + num_stages = 2 + threads = 256 + + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..48d91605145405a94c6e697207bb63ee49bc6a66 --- /dev/null +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -0,0 +1,163 @@ +import torch +import argparse +import tilelang +import tilelang.language as T +import math + + +def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): + """ + Perform grouped matrix multiplication using PyTorch. + + Args: + a (torch.Tensor): Input tensor of shape (N, K). + b (torch.Tensor): Input tensor of shape (G, K, M). + batch_sizes (torch.Tensor): 1D tensor containing the sizes of each group. + + Returns: + torch.Tensor: Resulting tensor after grouped matrix multiplication. + """ + assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a" + assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes" + + # Initialize output tensor + output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype) + + # Perform grouped GEMM + start = 0 + for i, size in enumerate(batch_sizes): + end = start + size + part_a = a[start:end] + part_b = b[i].transpose(0, 1) if trans_b else b[i] + part_out = torch.mm(part_a, part_b) + output[start:end] = part_out + start = end + + return output + + +@tilelang.jit(out_idx=[2]) +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + """ + args: + a (torch.Tensor): Input tensor of shape (M, K). + b (torch.Tensor): Input tensor of shape (G, K, N). + """ + batch_sum = sum(batch_sizes_list) + batch_count = len(batch_sizes_list) + accum_dtype = T.float32 + total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list) + + @T.prim_func + def kernel( + A: T.Tensor([batch_sum, K], dtype), # type: ignore + B: T.Tensor([batch_count, K, N], dtype), # type: ignore + C: T.Tensor([batch_sum, N], dtype), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore + ): + with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared([block_M, block_K], dtype) + B_shared = T.alloc_shared([block_K, block_N], dtype) + C_local = T.alloc_fragment([block_M, block_N], accum_dtype) + cur_batch_idx = T.alloc_local([1], T.int32) + cur_batch_size = T.alloc_local([1], T.int32) + + m_start_padded = bx * block_M + + for i in range(batch_count): + in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i] + cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0]) + + cur_batch_size[0] = batch_sizes[cur_batch_idx[0]] + m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]] + actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded)) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared) + T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + for i, j in T.Parallel(block_M, block_N): + with T.If(i < actual_rows), T.Then(): + C[m_start + i, by * block_N + j] = C_local[i, j] + + return kernel + + +def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): + batch_sum = sum(batch_sizes_list) + batch_count = len(batch_sizes_list) + batch_offsets_list = [0] + batch_padded_offsets_list = [0] + for i in range(batch_count - 1): + batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i]) + for i in range(batch_count - 1): + batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M) + A = torch.randn(batch_sum, K, device=device, dtype=dtype) + B = torch.randn(batch_count, K, M, device=device, dtype=dtype) + C = torch.empty(batch_sum, M, device=device, dtype=dtype) + batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32) + batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32) + batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32) + # print(batch_sizes_tensor) + # print(batch_offsets_tensor) + # print(batch_padded_offsets_tensor) + return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets + + +def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False): + padding_M = block_M + batch_sum = sum(batch_sizes_list) + kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads) + # print(kernel.get_kernel_source()) + + device = torch.device("cuda") + dtype = torch.float16 + + A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype) + out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets) + ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b) + # print(out) + # print(ref_output) + if torch.allclose(out, ref_output, rtol=0.01, atol=0.01): + print("โœ… Tilelang and Torch match") + else: + print("โŒ Tilelang and Torch mismatch") + + if profile: + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets]) + print(f"Latency: {latency} ms") + print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops") + + +def test_grouped_gemm(): + run_tilelang_grouped_gemm([64], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([64, 128, 256], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([63], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([100, 200, 300, 400], 8192, 8192, 64, 64, 64, False) + run_tilelang_grouped_gemm([63, 77, 111, 280], 8192, 8192, 64, 64, 64, False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes") + parser.add_argument("--K", type=int, default=8192, help="reduce dim") + parser.add_argument("--M", type=int, default=8192, help="output dim") + parser.add_argument("--trans_b", action="store_true", help="transpose B") + parser.add_argument("--profile", action="store_true", help="profile") + args = parser.parse_args() + + batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")] + K, M, trans_b = args.K, args.M, args.trans_b + + block_M = 64 + block_N = 128 + block_K = 64 + num_stages = 2 + threads = 256 + + run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile) diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..65f463b71bb06c53ab579a1c0389e71b0b1c387e --- /dev/null +++ b/examples/hadamard_transform/example_hadamard.py @@ -0,0 +1,153 @@ +import tilelang +import tilelang.language as T +from tilelang.intrinsics import make_mma_swizzle_layout + +import math +import argparse +import torch +from torch.nn import functional as F +import scipy + + +def is_pow_of_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + +@tilelang.jit(out_idx=[1]) +def hadamard(b, n, dtype): + assert is_pow_of_2(n), "n must be a power of 2" + assert 2 <= n <= 32768, "n must be in [2, 32768]" + elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype] + + logN = int(math.log2(n)) + threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] + thread_elem = n // threads # Each thread is responsible for a chunk of elements + thread_round = int(math.log2(thread_elem)) + + warps = 1 if threads <= 32 else threads // 32 + warp_round = int(math.log2(threads / warps)) + warp_size = threads // warps + + block_round = int(math.log2(warps)) + + exchange_round = n * elem_size // 32768 if n * elem_size > 32768 else 1 # Suppose we use 32KB shared memory at most + thread_elem_in_smem = thread_elem // exchange_round if exchange_round > 1 else thread_elem + + # debug log + # print(f'{threads=}, {thread_round=}') + # print(f'{warps=}, {warp_round=}, {warp_size=}') + # print(f'{block_round=}') + # print(f'{exchange_round=}') + + @T.macro + def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int): + tx = T.get_thread_binding(0) + for i in T.serial(round): + tx_stride = 1 << i + another_tx = tx ^ tx_stride + sign = (tx >> i) & 1 # get i-th lowest bit of tx, which determines the operation type for shared[tx, :] + + for j in T.Pipelined(thread_elem, num_stages=1): + buf[j] = T.tvm_warp_shuffle( + 0xFFFFFFFF, # mask of all threads + local[j], + another_tx % warp_size, + warp_size, + warp_size, + ) + local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j]) + + @T.prim_func + def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)): + with T.Kernel(b, threads=threads) as bx: + local = T.alloc_local((thread_elem,), dtype) + shared = T.alloc_shared((threads, thread_elem_in_smem), dtype) + T.annotate_layout({shared: make_mma_swizzle_layout(shared)}) + tx = T.get_thread_binding(0) + + # 1. Load from HBM to register + for i in T.vectorized(thread_elem): + local[i] = A[bx, tx * thread_elem + i] + + # 2. Hadamard inside thread, n<=8 + for i in T.serial(thread_round): + chunksize = 1 << (i + 1) + chunknum = thread_elem // chunksize + for j in T.serial(chunknum): + chunkbase = j * chunksize + for k in T.serial(chunksize // 2): + local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2] + local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2] + + # 3. Hadamard inside warp, n<=512 + # In warp level, we rely on warp shuffle to exchange data inside each warp, without using shared memory + another_val = T.alloc_local((thread_elem,), dtype) + + warp_shfl(local, another_val, warp_round) + + # 4. Hadamard inside block, n<=32768 + # Only exchange once for n<=8192, since shared mem can hold all elems + if block_round > 0: + warp_id = tx // warp_size + lane_id = tx % warp_size + src_tx = warp_id * warp_size + lane_id + tgt_warp_id = tx % warps + tgt_lane_id = tx // warps + tgt_tx = tgt_warp_id * warp_size + tgt_lane_id + + # 4.1 Write to smem, swap, read from smem + for cur_round in T.serial(exchange_round): + exchange_base = thread_elem_in_smem * cur_round + for j in T.vectorized(thread_elem_in_smem): + shared[src_tx, j] = local[exchange_base + j] + + for j in T.vectorized(thread_elem_in_smem): + local[exchange_base + j] = shared[tgt_tx, j] + + # 4.2 Warp shuffle + warp_shfl(local, another_val, block_round) + + # 4.3 Write to smem, swap, read from smem + for cur_round in T.serial(exchange_round): + exchange_base = thread_elem_in_smem * cur_round + for j in T.vectorized(thread_elem_in_smem): + shared[tgt_tx, j] = local[exchange_base + j] + + for j in T.vectorized(thread_elem_in_smem): + local[exchange_base + j] = shared[src_tx, j] + + # 5. Write back to HBM + for i in T.vectorized(thread_elem): + B[bx, tx * thread_elem + i] = local[i] + + return main + + +def ref_program(x: torch.Tensor): + assert x.ndim == 2 + dim = x.shape[-1] + assert is_pow_of_2(dim) + return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--dim", type=int, default=32768, help="Dimension") + args = parser.parse_args() + + B, D = args.batch, args.dim + x = torch.randn((B, D), device="cuda") + kernel = hadamard(B, D, T.float32) + y = kernel(x) + y_ref = ref_program(x) + torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) + print("All tests passed.") + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency = profiler.do_bench(warmup=100) + print("Tile-lang: {:.2f} ms".format(latency)) + + +if __name__ == "__main__": + main() diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..99cb977f0066b97da63368e5f550ef160bd52f1d --- /dev/null +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -0,0 +1,789 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT combines the jit generation and invocation logic.\n", + "\n", + "The function signature syntax is similar to triton but with significant enhancements, most notably allowing Tensor annotations:\n", + "\n", + "For example, the code below annotates a 2D Tensor with T.Tensor[[int, int], T.float16]\n", + "1. Each dimension is a compile-time constant; changing it triggers recompilation\n", + "2. Its dtype must be T.float16\n", + "\n", + "DType can also be Any or None in addition to a concrete type\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm(\n", + " A: T.Tensor[[int, int], T.float16],\n", + " B: T.Tensor[[int, int], T.float16],\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), out_dtype)\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "Call the Tensor directly as an argument to trigger the full jit compile-and-run workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "Change the call-site arguments; if the compiler parameters differ, it recompiles:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "You can also manually call compile helpers to build a kernel\n", + "\n", + "1. `ker.compile` compiles the kernel\n", + "2. `ker.get_tir` retrieves the tir\n", + "3. `ker.par_compile` compiles in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n" + ] + } + ], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### Separate the implementation with macros" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "Next we'll implement a simple gemm in several ways. For convenience, first write a macro that captures the main gemm logic:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### Mark dynamic shapes with T.dyn\n", + "\n", + "When some dimensions are dynamic, mark them with T.dyn. T.dyn can take a string argument to name the variable" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_dyn_K(\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "c60fd346", + "metadata": {}, + "source": [ + "Inspect the lazy_jit function signature: parameters with a `$` suffix are compile-time constants that may vary, and those with `$dyn` are runtime variables" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c6992eb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n", + " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gemm_dyn_K.func.annot" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### Use T.StridedTensor to annotate tensors with strides\n", + "\n", + "Annotation format: T.StridedTensor[Shape, Stride, DType]. Each Shape or Stride entry can be\n", + "* int: compile-time constant\n", + "* T.dyn: runtime value\n", + "\n", + "DType can be None or Any" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", + " M, N = A.shape\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A[::2, ::2])\n", + "B_ref = A[::2, ::2].contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### Annotate tensors with T.ptr\n", + "lazy_jit lets you declare a handle with T.ptr, but you must define its shape inside the function via T.match_buffer" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: int,\n", + " N: int,\n", + " K: int,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16)\n", + " B = T.match_buffer(B, (K, N), T.float16)\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### Use T.int32 to annotate runtime variables\n", + "\n", + "lazy_jit lets you define runtime variables with T.int32 or other types, enabling a fully dynamic gemm similar to triton" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr_dyn(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: T.int32,\n", + " N: T.int32,\n", + " K: T.int32,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n", + " B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## Compilation and parallel compilation" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "lazyjit and the original jit both support parallel compilation\n", + "\n", + "To avoid wasting memory with torch.tensor placeholders, use T.Tensor to create placeholders" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6d7f05cdfff412e9a527332438f7aa2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## More convenient macros" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang macros are now upgraded:\n", + "\n", + "1. Allow `T.Ref` as an annotation, similar to C++ pass-by-reference\n", + "2. Allow returning multiple values\n", + "3. Allow nesting and recursion" + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### Passing references with T.Ref\n", + "\n", + "The reference via T.Ref can target a var or a buffer element" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # Supports constant indices\n", + " macro_with_ref(x[1])\n", + "\n", + " # Also supports variable indices\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### Pass as arguments\n", + "\n", + "You can pass macros as parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def element_wise(\n", + " A: T.Tensor[[T.dyn], Any],\n", + " fn,\n", + "):\n", + " (N,) = A.shape\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Macro recursion\n", + "\n", + "Macro can be recursive, even if it's rarely needed, as long as the termination condition is known at compile time" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro returning multiple values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..601c5c5d2fe3610adfed8edd890d6a701f5c49f8 --- /dev/null +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -0,0 +1,789 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Lazy JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Lazy JIT ๅฐ† jit ็”Ÿๆˆๅ’Œ่ฐƒ็”จ็š„้€ป่พ‘ๅˆๅนถๅˆฐไธ€่ตท\n", + "\n", + "ๅ‡ฝๆ•ฐ็ญพๅ็š„ๅ†™ๆณ•ไธŽ triton ็›ธไผผ๏ผŒไฝ†ๅšไบ†ๅคง้‡ๅขžๅผบ๏ผŒๆœ€ไธป่ฆ็š„ๅขžๅผบๆ˜ฏๅ…่ฎธๅฏน Tensor ็š„ๆ ‡ๆณจ๏ผš\n", + "\n", + "ไพ‹ๅฆ‚๏ผŒไธ‹้ข็š„ไปฃ็ ็”จ T.Tensor[[int, int], T.float16] ๆฅๆ ‡ๆณจไบ†ไธ€ไธชไบŒ็ปด Tensor\n", + "1. ๅฎƒ็š„ๆฏไธช็ปดๅบฆ้ƒฝๆ˜ฏ็ผ–่ฏ‘ๆœŸๅธธ้‡๏ผŒๅฆ‚ๆžœๆ”นๅ˜๏ผŒไผš่งฆๅ‘้‡ๆ–ฐ็ผ–่ฏ‘\n", + "2. ๅฎƒ็š„็ฑปๅž‹ๅฟ…้กปๆ˜ฏ T.float16\n", + "\n", + "DType ้™คไบ†ๅ†™็กฎๅฎš็š„ๅค–๏ผŒ่ฟ˜ๅฏไปฅๅ†™ Any ๆˆ–่€… None" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm(\n", + " A: T.Tensor[[int, int], T.float16],\n", + " B: T.Tensor[[int, int], T.float16],\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), out_dtype)\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "็›ดๆŽฅๅฐ† Tensor ไฝœไธบๅ‚ๆ•ฐ่ฐƒ็”จ๏ผŒๅณๅฏ่งฆๅ‘ๅฎŒๆ•ด็š„ jit ็ผ–่ฏ‘่ฟ่กŒๆต็จ‹๏ผš" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "ๆ›ดๆ”น่ฐƒ็”จ็š„ๅ‚ๆ•ฐ๏ผŒๅฆ‚ๆžœ็ผ–่ฏ‘ๅ™จๅ‚ๆ•ฐๅ‘็”Ÿไบ†ๅ˜ๅŒ–๏ผŒไผš่งฆๅ‘้‡ๆ–ฐ็ผ–่ฏ‘๏ผš" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "ไฝ ไนŸๅฏไปฅๆ‰‹ๅŠจ่ฐƒ็”จ compile ๅ‡ฝๆ•ฐ็ผ–่ฏ‘ kernel\n", + "\n", + "1. `ker.compile` ็ผ–่ฏ‘ kernel\n", + "2. `ker.get_tir` ่Žทๅ– tir\n", + "3. `ker.par_compile` ๅนถ่กŒ็ผ–่ฏ‘" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n" + ] + } + ], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### ็”จ macro ๆฅๅˆ†็ฆปๅฎž็Žฐ" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "ๆŽฅไธ‹ๆฅ๏ผŒๆˆ‘ไปฌไผš็”จๅ„็งๆ–นๅผๆฅๅฎž็Žฐไธ€ไธช็ฎ€ๅ•็š„ gemm๏ผŒไธบไบ†ๆ–นไพฟ๏ผŒๆˆ‘ไปฌๅ…ˆๅ†™ไธ€ไธช macro ๆŠŠ gemm ็š„ไธป่ฆ้€ป่พ‘ๅ†™ๅ‡บๆฅ๏ผš" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### ็”จ T.dyn ๆ ‡่ฎฐๅŠจๆ€ Shape\n", + "\n", + "ๅฝ“ๆŸไบ›็ปดๅบฆๆ˜ฏๅŠจๆ€็š„็š„ๆ—ถๅ€™๏ผŒๅฏไปฅ็”จ T.dyn ๆฅๆ ‡่ฎฐใ€‚T.dyn ๅฏไปฅๆŽฅๅ—ไธ€ไธชๅญ—็ฌฆไธฒๅ‚ๆ•ฐ๏ผŒ่กจ็คบๅ˜้‡็š„ๅๅญ—" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_dyn_K(\n", + " A: T.Tensor[[int, T.dyn[\"K\"]], T.float16], # noqa: F821\n", + " B: T.Tensor[[T.dyn[\"K\"], int], T.float16], # noqa: F821\n", + "):\n", + " M, K = A.shape\n", + " K, N = B.shape\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "c60fd346", + "metadata": {}, + "source": [ + "ๆŸฅ็œ‹ lazy_jit ็š„ๅ‡ฝๆ•ฐ็ญพๅ๏ผŒๅ…ถไธญๅธฆๆœ‰ๅŽ็ผ€`$` ็š„ๆ˜ฏไธ็กฎๅฎš็š„็ผ–่ฏ‘ๆœŸๅธธ้‡๏ผŒๅธฆๆœ‰ `$dyn` ็š„ๆ˜ฏ่ฟ่กŒๆ—ถ็š„ๅ˜้‡" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c6992eb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n", + " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gemm_dyn_K.func.annot" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### ็”จ T.StridedTensor ๆ ‡่ฎฐๅธฆ stride ็š„ Tensor\n", + "\n", + "ๆ ‡่ฎฐๆ–นๆณ•๏ผšT.StridedTensor[Shape, Stride, DType]๏ผŒๆฏไธช Shape ๆˆ– Stride ๅฏไปฅๅ†™\n", + "* int: ่กจ็คบ็ผ–่ฏ‘ๆœŸๅธธ้‡\n", + "* T.dyn๏ผš่กจ็คบ่ฟ่กŒๆ—ถๅธธ้‡\n", + "\n", + "DType ๅฏไปฅๅ†™ None ๆˆ– Any" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def as_contingious(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", + " M, N = A.shape\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A[::2, ::2])\n", + "B_ref = A[::2, ::2].contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### ็”จ T.ptr ๆ ‡ๆณจ Tensor\n", + "lazy_jit ๅ…่ฎธไฝ ็”จ T.ptr ๆฅๅฃฐๆ˜Žไธ€ไธช handle๏ผŒไฝ†ๅฟ…้กปๅœจๅ‡ฝๆ•ฐๅ†…็”จ T.match_buffer ็ป™ๅฎƒๅฎšไน‰ shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: int,\n", + " N: int,\n", + " K: int,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16)\n", + " B = T.match_buffer(B, (K, N), T.float16)\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### ็”จ T.int32 ๆ ‡ๆณจ่ฟ่กŒๆ—ถๅ˜้‡\n", + "\n", + "lazy_jit ๅ…่ฎธไฝ ็”จ T.int32 ๆˆ–ๅ…ถไป–็ฑปๅž‹ๆฅๅฎšไน‰่ฟ่กŒๆ—ถๅ˜้‡๏ผŒ่ฟ™ๆ ท๏ผŒไฝ ๅฏไปฅๅ†™ไธ€ไธชๅฎŒๅ…จๅŠจๆ€็š„ gemm๏ผŒ่ฟ™ๅ’Œ triton ้žๅธธ็›ธไผผ" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def gemm_ptr_dyn(\n", + " A: T.ptr,\n", + " B: T.ptr,\n", + " M: T.int32,\n", + " N: T.int32,\n", + " K: T.int32,\n", + "):\n", + " A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n", + " B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## ็ผ–่ฏ‘ไธŽๅนถ่กŒ็ผ–่ฏ‘" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "lazyjit ๅ’ŒๅŽŸๆฅ็š„ jit ้ƒฝๆ”ฏๆŒๅนถ่กŒ็ผ–่ฏ‘\n", + "\n", + "ไธบไบ†้˜ฒๆญข torch.tensor ็™ฝ็™ฝๆตช่ดนๅ†…ๅญ˜๏ผŒๅฏไปฅไฝฟ็”จ T.Tensor ๆฅๅˆ›ๅปบ placeholder" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6d7f05cdfff412e9a527332438f7aa2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## ๆ›ดไพฟๅˆฉ็š„ Macro" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang ็š„ macro ็Žฐๅœจๅทฒ็ปๅ‡็บง๏ผš\n", + "\n", + "1. ๅ…่ฎธ็”จ `T.Ref` ไฝœไธบ annotation๏ผŒ่ฟ™็ฑปไผผไธŽ C++ ็š„ๅผ•็”จไผ ้€’\n", + "2. ๅ…่ฎธ่ฟ”ๅ›žๅคšไธชๅ€ผ\n", + "3. ๅ…่ฎธๅตŒๅฅ—๏ผŒ้€’ๅฝ’" + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### T.Ref ไผ ้€’ๅผ•็”จ\n", + "\n", + "T.Ref ไผ ้€’็š„ๅผ•็”จๅฏไปฅ var ไนŸๅฏไปฅๆ˜ฏ Buffer ็š„็ดขๅผ•" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # ๆ”ฏๆŒๅธธ้‡ index\n", + " macro_with_ref(x[1])\n", + "\n", + " # ไนŸๆ”ฏๆŒๅ˜้‡ index\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### ๅฝ“ไฝœๅ‚ๆ•ฐไผ ้€’\n", + "\n", + "ไฝ ๅฏไปฅๆŠŠ macro ๅฝ“ๅšๅ‚ๆ•ฐไผ ้€’" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.lazy_jit\n", + "def element_wise(\n", + " A: T.Tensor[[T.dyn], Any],\n", + " fn,\n", + "):\n", + " (N,) = A.shape\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Macro ้€’ๅฝ’\n", + "\n", + "่™ฝ็„ถไธ็Ÿฅ้“ๆœ‰ๆฒกๆœ‰่ฟ™็ง้œ€ๆฑ‚๏ผŒไฝ† macro ๆ˜ฏๅฏไปฅ้€’ๅฝ’็š„๏ผŒไฝ†่ฆๆฑ‚็ปˆๆญขๆกไปถ็ผ–่ฏ‘ๆœŸ้—ด็กฎๅฎš" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7703cab5", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def n31(x, var: T.Ref):\n", + " if x == 1:\n", + " pass\n", + " elif x % 2 == 0:\n", + " var = var // 2\n", + " n31(x // 2, var)\n", + " else:\n", + " var = var * 3 + 1\n", + " n31(x * 3 + 1, var)\n", + "\n", + "\n", + "@tilelang.lazy_jit\n", + "def foo(A: T.Tensor[[1], T.int32], n: int):\n", + " with T.Kernel(1) as _:\n", + " n31(n, A[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "542ddd4e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([18], device='cuda:0', dtype=torch.int32)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro ่ฟ”ๅ›žๅคšไธชๅ€ผ" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# import tilelang.language as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/examples/linear_attention/README.md b/examples/linear_attention/README.md new file mode 100644 index 0000000000000000000000000000000000000000..92b10692b32a8c9f1aa5ed979510acd5321f84e4 --- /dev/null +++ b/examples/linear_attention/README.md @@ -0,0 +1 @@ +# Linear Attention diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..397ec7bdf6fe6d27a47e3d29c82d1b331ebb277d --- /dev/null +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -0,0 +1,203 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA +from fla.modules.l2norm import l2norm_fwd +from einops import rearrange +from typing import Optional, Tuple + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) +def tl_fused_chunk_bwd_kernel( + B, + S, + H, + DK, + DV, + dtype: T.dtype = T.float16, + scale: float = None, +) -> torch.Tensor: + if scale is None: + scale = DK**-0.5 + accum_dtype = T.float32 + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tilelang.cdiv(DK, BK) + NV = tilelang.cdiv(DV, BV) + NT = tilelang.cdiv(S, chunk_size) + + @T.prim_func + def fused_chunk_linear_attn_bwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dK: T.Tensor([B, S, H, DK], accum_dtype), # type: ignore + dV: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + ): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + + ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + dq = T.alloc_fragment([chunk_size, BK], accum_dtype) + dq_shared = T.alloc_shared([chunk_size, BK], accum_dtype) + dk = T.alloc_fragment([chunk_size, BK], accum_dtype) + dk_shared = T.alloc_shared([chunk_size, BK], accum_dtype) + dv = T.alloc_fragment([chunk_size, BV], accum_dtype) + dv_shared = T.alloc_shared([chunk_size, BV], accum_dtype) + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + do = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BV, BK], accum_dtype) + h_shared = T.alloc_shared([BV, BK], dtype) + dh = T.alloc_fragment([BK, BV], accum_dtype) + dh_shared = T.alloc_shared([BK, BV], dtype) + + T.annotate_layout( + { + dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), + } + ) + T.use_swizzle(10) + + T.clear(h) + T.clear(dh) + + # Calculate dQ + for i in T.Pipelined(0, NT): + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) + + T.gemm(do, v, ds, transpose_B=True, clear_accum=True) + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else(row >= col, ds[row, col], 0) + + T.gemm(ds_shared, k, dq, clear_accum=True) + T.copy(h, h_shared) + T.gemm(do, h_shared, dq) + T.gemm(v, k, h, transpose_A=True) + for row, col in T.Parallel(chunk_size, BK): + dq[row, col] *= scale + T.copy(dq, dq_shared) + T.atomic_add(dQ[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dq_shared) + + # Calculate dK, dV (reversely) + for i in T.Pipelined(1, NT + 1): + start = NT - i + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + T.copy(dO[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], do) + + # Calculate dk + T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) + T.gemm(ds_shared, q, dk, clear_accum=True) + T.copy(dh, dh_shared) + T.gemm(v, dh_shared, dk, transpose_B=True) + + # Calculate dv + T.gemm(k, q, ds, transpose_B=True, clear_accum=True) + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) + T.gemm(ds_shared, do, dv, clear_accum=True) + T.gemm(k, dh_shared, dv) + + # Update dh + T.gemm(q, do, dh, transpose_A=True) + + T.copy(dk, dk_shared) + T.atomic_add(dK[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], dk_shared) + T.copy(dv, dv_shared) + T.atomic_add(dV[i_b, start * chunk_size : (start + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], dv_shared) + + return fused_chunk_linear_attn_bwd + + +def tl_fused_chunk_bwd(Q, K, V, dO): + B, S, H, D = Q.shape + kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros_like(K, dtype=torch.float32) + dV = torch.zeros_like(V, dtype=torch.float32) + kernel(Q, K, V, dO, dQ, dK, dV) + return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16) + + +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: + q, k, v = q.float(), k.float(), v.float() + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + h = kv[:, :, -1, :, :] + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v + o = inter + intra + return rearrange(o, "b h n c d -> b (n c) h d"), h + + +def main(B=1, S=1024, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + + # qk norm is necessary for linear attn + q = l2norm_fwd(q)[0].requires_grad_(True) + k = l2norm_fwd(k)[0].requires_grad_(True) + + dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do) + q.grad = k.grad = v.grad = None + o_ref, _ = ref_program(q, k, v) + o_ref.backward(do, retain_graph=True) + + assert torch.allclose(dq, q.grad, atol=1e-2, rtol=1e-2), f"dq max err: {(dq - q.grad).abs().max()}" + assert torch.allclose(dk, k.grad, atol=1e-2, rtol=1e-2), f"dk max err: {(dk - k.grad).abs().max()}" + assert torch.allclose(dv, v.grad, atol=1e-2, rtol=1e-2), f"dv max err: {(dv - v.grad).abs().max()}" + print("Passed all tests!โœ…") + + # Benchmark + q.grad = k.grad = v.grad = None + o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") + args = parser.parse_args() + + main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..849841e5179e3a3baccf1bb7794e425cd5fee990 --- /dev/null +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -0,0 +1,149 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA +from fla.modules.l2norm import l2norm_fwd +from einops import rearrange +from typing import Optional, Tuple + + +@tilelang.jit( + out_idx=[4], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def tl_fused_chunk_fwd_kernel( + B, + S, + H, + DK, + DV, + dtype: T.dtype = T.float16, + scale: float = None, +) -> torch.Tensor: + if scale is None: + scale = DK**-0.5 + accum_dtype = T.float32 + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tilelang.cdiv(DK, BK) + NV = tilelang.cdiv(DV, BV) + NT = tilelang.cdiv(S, chunk_size) + + @T.prim_func + def fused_chunk_linear_attn_fwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([B, S, H, DV], accum_dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype), + ): # type: ignore + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) + h_shared = T.alloc_shared([BK, BV], dtype) + s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + o = T.alloc_fragment([chunk_size, BV], accum_dtype) + o_shared = T.alloc_shared([chunk_size, BV], accum_dtype) + + T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)}) + T.use_swizzle(10) + + T.clear(h) + + for i in T.Pipelined(0, NT): + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + + T.gemm(q, k, s, clear_accum=True, transpose_B=True) + for row, col in T.Parallel(chunk_size, chunk_size): + s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0) + + T.gemm(s_shared, v, o, clear_accum=True) + T.copy(h, h_shared) + T.gemm(k, v, h, transpose_A=True) + T.gemm(q, h_shared, o) + T.copy(o, o_shared) + T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared) + + # Output final state + T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV]) + + return fused_chunk_linear_attn_fwd + + +def tl_fused_chunk_fwd(q, k, v): + B, S, H, D = q.shape + kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) + print(kernel.get_kernel_source()) + o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32) + h = kernel(q, k, v, o) + return o, h + + +def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]: + q, k, v = q.float(), k.float(), v.float() + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale + k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size) + v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + h = kv[:, :, -1, :, :] + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = ( + (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + ) @ v + o = inter + intra + return rearrange(o, "b h n c d -> b (n c) h d"), h + + +def main(B=1, S=512, H=16, D=128): + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + + # qk norm is necessary for linear attn + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + + o, h = tl_fused_chunk_fwd(q, k, v) + o_ref, h_ref = ref_program(q, k, v) + + assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}" + assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}" + print("Passed all tests!โœ…") + + t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti") + t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti") + print(f"Triton latency: {t1:.3f} ms") + print(f"TileLang latency: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.3f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=1024, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") + args = parser.parse_args() + + main(args.B, args.S, args.H, args.D) diff --git a/examples/linear_attention/example_mamba_chunk_scan.py b/examples/linear_attention/example_mamba_chunk_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..1958dfb5aa95a64fd38e40f6632b787b39150c19 --- /dev/null +++ b/examples/linear_attention/example_mamba_chunk_scan.py @@ -0,0 +1,285 @@ +import argparse +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, repeat +import itertools + + +def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): + from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd + + out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D) + return out + + +def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): + """ + Argument: + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + C: (batch, seqlen, ngroups, dstate) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + _, _, ngroups, _, _ = cb.shape + batch, seqlen, nheads, headdim = x.shape + # _, _, ngroups, dstate = B.shape + # assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + # assert C.shape == B.shape + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups) + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum( + "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks) + ) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = ( + torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out + ) + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out + + +def get_configs(): + iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[7], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def chunk_scan_fwd( + batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, +): + dtype = T.float16 + accum_dtype = T.float32 + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + + @T.prim_func + def main( + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore + D: T.Tensor((nheads), dtype), # type: ignore + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore + ): + with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as ( + bz, + bx, + by, + ): + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_local = T.alloc_fragment((block_M, block_K), dtype) + dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_local = T.alloc_fragment((block_K), accum_dtype) + x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") + dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + scale_m_local = T.alloc_fragment((block_M), accum_dtype) + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) + D_local = T.alloc_fragment((1), accum_dtype) + x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + batch_idx = by % batch + chunk_idx = by // batch + # m: chunk_size + # n : headdim + m_idx = bx // T.ceildiv(headdim, block_N) + n_idx = bx % T.ceildiv(headdim, block_N) + + T.annotate_layout( + { + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared), + } + ) + + T.no_set_max_nreg() + + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared) + T.copy(dA_cs_m_shared, dA_cs_m_local) + T.clear(acc_o) + + for i in T.Parallel(block_M): + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) + T.copy( + C[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz // (nheads // ngroups), + 0:block_Dstate, + ], + C_shared, + ) + T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared) + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] *= scale_m_local[i] + + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + cb[ + batch_idx, + chunk_idx, + bz // (nheads // ngroups), + m_idx * block_M : (m_idx + 1) * block_M, + k * block_K : (k + 1) * block_K, + ], + cb_shared, + ) + T.copy(cb_shared, cb_local) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared) + T.copy(dA_cs_k_shared, dA_cs_k_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] *= dt_local[j] + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0) + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_shared, + ) + T.gemm(cb_local, x_shared, acc_o) + + D_local[0] = D[bz] + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + x_residual_shared, + ) + T.copy(x_residual_shared, x_residual_local) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] += x_residual_local[i, j] * D_local[0] + + T.copy(acc_o, acc_o_shared) + T.copy( + acc_o_shared, + Output[ + batch_idx, + chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M, + bz, + n_idx * block_N : (n_idx + 1) * block_N, + ], + ) + + return main + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) + total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate + + if not args.tune: + kernel = chunk_scan_fwd( + batch, + seq_len, + chunk_size, + groups, + heads, + dim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128, + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") diff --git a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py new file mode 100644 index 0000000000000000000000000000000000000000..fb766d5e9c9c7d4b7e4de1f1f15f801f84b9de03 --- /dev/null +++ b/examples/linear_attention/example_mamba_chunk_state.py @@ -0,0 +1,178 @@ +import argparse +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, repeat +import itertools + + +def chunk_state_triton(B, x, dt, dA_cumsum): + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd + + return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False) + + +def ref_program(B, x, dt, dA_cumsum): + """ + Argument: + B: (batch, seqlen, ngroups, headdim) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + Return: + states: (batch, nchunks, nheads, headdim, dstate) + """ + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + dstate = B.shape[-1] + _, _, nchunks, chunk_size = dt.shape + assert seqlen <= nchunks * chunk_size + assert x.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + ngroups = B.shape[2] + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seqlen < nchunks * chunk_size: + x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) + x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) + B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) + decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) + return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) + + +def get_configs(): + iter_params = dict(block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[4]) +def chunk_state_fwd( + batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128 +): + dtype = T.float16 + accum_dtype = T.float32 + nchunks = T.ceildiv(seqlen, chunk_size) + p = 1.44269504 + + @T.prim_func + def main( + B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), + Output: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), + ): + with T.Kernel(nheads, T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), batch * nchunks, threads=threads) as (bz, bx, by): + x_shared = T.alloc_shared((block_K, block_M), dtype) + x_local = T.alloc_fragment((block_K, block_M), dtype) + xt_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + dt_shared = T.alloc_shared((block_K), dtype) + dA_cumsum_shared = T.alloc_shared((block_K), dtype) + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + scale = T.alloc_fragment((block_K), accum_dtype) + dA_cs_last = T.alloc_fragment((1), accum_dtype) + dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype) + dt_local = T.alloc_fragment((block_K), accum_dtype) + + loop_range = T.ceildiv(chunk_size, block_K) + + batch_idx = by % batch + chunk_idx = by // batch + m_idx = bx // T.ceildiv(dstate, block_N) + n_idx = bx % T.ceildiv(dstate, block_N) + + T.annotate_layout( + {x_shared: tilelang.layout.make_swizzled_layout(x_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)} + ) + + dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] + T.clear(acc_o) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + x[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz, + m_idx * block_M : (m_idx + 1) * block_M, + ], + x_shared, + ) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cumsum_shared) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared) + T.copy(dA_cumsum_shared, dA_cumsum_local) + T.copy(dt_shared, dt_local) + for i in T.Parallel(block_K): + scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i] + T.copy(x_shared, x_local) + for i, j in T.Parallel(block_M, block_K): + xt_local[i, j] = x_local[j, i] * scale[j] + T.copy( + B[ + batch_idx, + chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K, + bz // (nheads // ngroups), + n_idx * block_N : (n_idx + 1) * block_N, + ], + B_shared, + ) + T.gemm(xt_local, B_shared, acc_o) + T.copy(acc_o, acc_o_shared) + T.copy( + acc_o_shared, + Output[batch_idx, chunk_idx, bz, m_idx * block_M : (m_idx + 1) * block_M, n_idx * block_N : (n_idx + 1) * block_N], + ) + + return main + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=80, help="heads") + parser.add_argument("--groups", type=int, default=1, help="groups") + parser.add_argument("--seq_len", type=int, default=4096, help="sequence length") + parser.add_argument("--chunk_size", type=int, default=256, help="chunk size") + parser.add_argument("--dim", type=int, default=64, help="dim") + parser.add_argument("--dstate", type=int, default=128, help="dstate") + parser.add_argument("--tune", action="store_true", help="tune configs") + args = parser.parse_args() + batch, heads, groups, seq_len, chunk_size, dim, dstate = ( + args.batch, + args.heads, + args.groups, + args.seq_len, + args.chunk_size, + args.dim, + args.dstate, + ) + total_flops = 2 * batch * seq_len * heads * dim * dstate + + if not args.tune: + kernel = chunk_state_fwd( + batch, seq_len, chunk_size, groups, heads, dim, dstate, block_M=64, block_N=128, block_K=64, num_stages=4, threads=128 + ) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..f45e383889bd7ef2d93e1a00539e72110465ea43 --- /dev/null +++ b/examples/linear_attention/example_retention_fwd.py @@ -0,0 +1,107 @@ +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench + +import argparse + + +@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def chunk_retention_fwd_kernel( + B, + S, + H, + DK, + DV, + dtype: T.dtype = T.float16, + scale: float = None, +) -> torch.Tensor: + if scale is None: + scale = DK**-0.5 + accum_dtype = T.float32 + + chunk_size = 64 + BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tl.cdiv(DK, BK) + NV = tl.cdiv(DV, BV) + NT = tl.cdiv(S, chunk_size) + + @T.prim_func + def chunk_retention_fwd( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + ): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + log_decay = T.alloc_var(T.float32) + log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay + + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) + h_shared = T.alloc_shared([BK, BV], dtype) + s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + o = T.alloc_fragment([chunk_size, BV], accum_dtype) + T.clear(h) + + T.use_swizzle(10) + + for i in T.Pipelined(0, NT): + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v) + + T.gemm(q, k, s, clear_accum=True, transpose_B=True) + for row, col in T.Parallel(chunk_size, chunk_size): + s_shared[row, col] = T.if_then_else(row >= col, s[row, col] * T.exp2((row - col) * log_decay), 0) + + T.copy(h, h_shared) + T.gemm(q, h_shared, o, clear_accum=True) + for row, col in T.Parallel(chunk_size, BV): + o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col] + T.gemm(s_shared, v, o) + + for row, col in T.Parallel(chunk_size, BV): + v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay) + for row, col in T.Parallel(BK, BV): + h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col] + T.copy(o, O[i_k, i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV]) + T.gemm(k, v, h, transpose_A=True) + + return chunk_retention_fwd + + +def postprocess(o): + return o if o.size(0) == 1 else o.sum(0) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--B", type=int, default=8, help="Batch size") + parser.add_argument("--S", type=int, default=4096, help="Seq len") + parser.add_argument("--H", type=int, default=32, help="Num heads") + parser.add_argument("--D", type=int, default=128, help="Head dim") + args = parser.parse_args() + B, S, H, D = args.B, args.S, args.H, args.D + total_flops = 2.0 * B * S * S * H * D # causal + + q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16) + + kernel = chunk_retention_fwd_kernel(B, S, H, D, D) + + t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100) + print(f"Tilelang latency: {t:.3f} ms") + print(f"Tilelang TFLOPs: {total_flops / t * 1e-9}") + + +if __name__ == "__main__": + main() diff --git a/examples/linear_attention/test_linear_attn.py b/examples/linear_attention/test_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..346fa8e96ed983bb12f8bb0845611203d536392a --- /dev/null +++ b/examples/linear_attention/test_linear_attn.py @@ -0,0 +1,18 @@ +import tilelang.testing + +import example_linear_attn_fwd +import example_linear_attn_bwd + + +@tilelang.testing.requires_cuda +def test_example_linear_attn_fwd(): + example_linear_attn_fwd.main() + + +@tilelang.testing.requires_cuda +def test_example_linear_attn_bwd(): + example_linear_attn_bwd.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/minference/README.md b/examples/minference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8cba732609e59295453a32a846a21fbd73bcb3cd --- /dev/null +++ b/examples/minference/README.md @@ -0,0 +1,28 @@ +# Performance Benchmark + +## Hardware & Environment +- **Hardware**: NVIDIA H100 PCIe +- **CUDA version**: 12.8.1 +- **PyTorch Version**: 2.7.1+cu128 +- **Triton Version**: 3.3.1 + +## Performance Results +BATCH_SIZE=1, HEAD=1, DIM=64 + +| SEQ_LEN | VS_LIST | Triton Time | TileLang Time | Speedup | +|---------|--------------|-------------|---------------|---------| +| 8192 | [1000, 200] | 0.168 ms | 0.105 ms | 1.60x | +| 8192 | [1000, 600] | 0.207 ms | 0.119 ms | 1.74x | +| 8192 | [800, 600] | 0.207 ms | 0.122 ms | 1.70x | +| | | | | | +| 16384 | [1000, 200] | 0.261 ms | 0.167 ms | 1.56x | +| 16384 | [1000, 600] | 0.419 ms | 0.258 ms | 1.62x | +| 16384 | [800, 600] | 0.422 ms | 0.255 ms | 1.65x | +| | | | | | +| 32768 | [1000, 200] | 0.374 ms | 0.248 ms | 1.51x | +| 32768 | [1000, 600] | 0.823 ms | 0.554 ms | 1.49x | +| 32768 | [800, 600] | 0.826 ms | 0.558 ms | 1.48x | +| | | | | | +| 65536 | [1000, 200] | 0.637 ms | 0.524 ms | 1.22x | +| 65536 | [1000, 600] | 1.758 ms | 1.501 ms | 1.17x | +| 65536 | [800, 600] | 1.783 ms | 1.489 ms | 1.20x | diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f96e73ae511d300e3fa3569ef6910805ea19bca6 --- /dev/null +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -0,0 +1,623 @@ +# Copyright (c) 2024-2025 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math +import argparse + +import torch +import triton +import triton.language as tl + +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench + + +@tilelang.jit(out_idx=[3]) +def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size): + block_M = 64 + block_N = 64 + num_stages = 2 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 + shape = [batch, heads, seq_len, dim] + + seq_blocks = (seq_len + block_M - 1) // block_M + + count_shape = [batch, heads, seq_blocks] + + offset_shape = count_shape + [slash_size] + index_shape = count_shape + [vertical_size] + + vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size) + + dtype = T.float16 + accum_dtype = T.float32 + int_dtype = T.int32 + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def Prefetch( + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + column_index: T.SharedBuffer([vertical_size_round], int_dtype), + column_count: T.int32, + k: T.int32, + bz: T.int32, + by: T.int32, + ): + with T.attr("default", "async_scope", 1): + for i, j in T.Parallel(block_N, dim): + K_shared[i, j] = T.if_then_else(k + i < column_count, K[bz, by, column_index[k + i], j], 0) + + with T.attr("default", "async_scope", 1): + for i, j in T.Parallel(block_N, dim): + V_shared[i, j] = T.if_then_else(k + i < column_count, V[bz, by, column_index[k + i], j], 0) + + T.ptx_commit_group() + + @T.macro + def Compute( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + k: T.int32, + column_count: T.int32, + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + count: T.int32, + ): + T.ptx_wait_group(count) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k + j < column_count, 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] = acc_o[i, j] * scores_scale[i] + + T.copy(acc_s, acc_s_cast) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + @T.prim_func + def vs_sparse_flashattn_ws( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + BlockCount: T.Tensor(count_shape, int_dtype), + BlockOffset: T.Tensor(offset_shape, int_dtype), + ColumnCount: T.Tensor(count_shape, int_dtype), + ColumnIndex: T.Tensor(index_shape, int_dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz): + bx = T.ceildiv(seq_len, block_M) - 1 - bc + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([2, block_N, dim], dtype) + V_shared = T.alloc_shared([2, block_N, dim], dtype) + K_shared_1 = T.alloc_shared([block_N, dim], dtype) + V_shared_1 = T.alloc_shared([block_N, dim], dtype) + K_shared_2 = T.alloc_shared([block_N, dim], dtype) + V_shared_2 = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_count = T.alloc_local([1], int_dtype) + block_offset = T.alloc_shared([slash_size_round], int_dtype, scope="shared") + column_count = T.alloc_local([1], int_dtype) + column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared") + + T.create_list_of_mbarrier([128] * 9) + + T.annotate_layout( + { + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + block_count[0] = BlockCount[bz, by, bx] + column_count[0] = ColumnCount[bz, by, bx] + + for vi in T.Parallel(slash_size_round): + if vi < slash_size: + block_offset[vi] = BlockOffset[bz, by, bx, vi] + + for vi in T.Parallel(vertical_size_round): + if vi < vertical_size: + column_index[vi] = ColumnIndex[bz, by, bx, vi] + + tid = T.get_thread_binding() + + if tid >= 128: + T.annotate_producer_reg_dealloc() + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.mbarrier_arrive(mbarrier=8) + for bi in T.serial(block_count[0]): + k = block_offset[bi] + T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1)) + T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :]) + T.mbarrier_arrive(mbarrier=bi % 2) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1)) + T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :]) + T.mbarrier_arrive(mbarrier=bi % 2 + 2) + else: + T.annotate_consumer_reg_alloc() + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.mbarrier_wait_parity(mbarrier=8, parity=0) + for bi in T.serial(block_count[0]): + k = block_offset[bi] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype)) + + T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1)) + T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.mbarrier_arrive(mbarrier=bi % 2 + 4) + + T.copy(scores_max, scores_max_prev) + + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] = acc_o[i, j] * scores_scale[i] + + T.copy(acc_s, acc_s_cast) + T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1)) + T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.mbarrier_arrive(mbarrier=bi % 2 + 6) + + T.reduce_sum(acc_s, scores_sum, dim=1) + + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + if column_count[0] != 0: + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by) + for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1): + k = bi * block_N + if bi % 2 == 0: + Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 1, + ) + else: + Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by) + + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + k, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 1, + ) + if T.ceildiv(column_count[0], block_N) % 2 == 0: + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_2, + V_shared_2, + scores_scale, + scores_sum, + logsum, + 0, + ) + else: + Compute( + acc_s, + acc_s_cast, + acc_o, + scores_max, + scores_max_prev, + T.ceildiv(column_count[0], block_N) * block_N - block_N, + column_count[0], + Q_shared, + K_shared_1, + V_shared_1, + scores_scale, + scores_sum, + logsum, + 0, + ) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return vs_sparse_flashattn_ws + + return kernel_func(block_M, block_N, num_stages, threads) + + +@triton.jit +def _triton_mixed_sparse_attn_fwd_kernel( + Q, + K, + V, + seqlens, + sm_scale, + block_count, + block_offset, + column_count, + column_index, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_oz, + stride_oh, + stride_om, + stride_ok, + Z, + H, + N_CTX, + NUM_ROWS, + NNZ_S, + NNZ_V, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, +): + start_m = tl.program_id(0) # bx + off_hz = tl.program_id(1) # by + + seqlen = tl.load(seqlens + off_hz // H) + if start_m * BLOCK_M >= seqlen: + return + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh + kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh + + q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk + v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk + o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok + + num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m) + blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S + num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m) + cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + q = (q * qk_scale).to(dtype) + + # loop over k, v and update accumulator + m_mask = offs_m[:, None] < seqlen + + for block_index in range(num_blks): + start_n = tl.load(blks_ptr + block_index) + cols = start_n + offs_n + n_mask = cols < seqlen + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0) + v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0) + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + causal_mask = cols[None, :] <= offs_m[:, None] + qk = tl.where(m_mask & causal_mask, qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(dtype), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + for start_n in range(0, num_cols, BLOCK_N): # + # bi * BLOCK_N: bi * BLOCK_N + BLOCK_N + n_mask = start_n + offs_n < num_cols + cols = tl.load(cols_ptr + start_n + offs_n, mask=n_mask, other=0) + # -- load k, v -- + k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0) + v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0) + # -- compute qk -- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.where(m_mask & n_mask, qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant -- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(dtype), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + + # write back O + acc /= l_i[:, None] + # acc = tl.where(m_mask, acc / l_i[:, None], 0.0) + tl.store(o_ptrs, acc.to(dtype), mask=m_mask) + + +def _triton_mixed_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlens: torch.Tensor, + block_count: torch.Tensor, + block_offset: torch.Tensor, + column_count: torch.Tensor, + column_index: torch.Tensor, + sm_scale: float, + block_size_M: int = 64, + block_size_N: int = 64, +) -> torch.Tensor: + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.zeros_like(q) + grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1) + dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16 + _triton_mixed_sparse_attn_fwd_kernel[grid]( + q, + k, + v, + seqlens, + sm_scale, + block_count, + block_offset, + column_count, + column_index, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + block_count.shape[-1], + block_offset.shape[-1], + column_index.shape[-1], + BLOCK_M=block_size_M, + BLOCK_N=block_size_N, + BLOCK_DMODEL=Lk, + dtype=dtype, + num_warps=4, + num_stages=2, + ) + + return o + + +def vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + block_size_M: int = 64, + block_size_N: int = 64, +): + from torch.utils.cpp_extension import load + import os + + current_dir = os.path.dirname(os.path.abspath(__file__)) + sources = [os.path.join(current_dir, "ops", "kernels.cpp"), os.path.join(current_dir, "ops", "vertical_slash_index.cu")] + ops = load(name="convert", sources=sources, verbose=False) + convert_vertical_slash_indexes = ops.convert_vertical_slash_indexes + batch_size, num_heads, context_size, head_dim = query.shape + pad = (block_size_M - context_size) & (block_size_M - 1) + if pad == block_size_M: + pad = 0 + query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim + query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + + seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device) + sm_scale = head_dim**-0.5 + block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes( + seqlens, + v_idx, + s_idx, + context_size, + block_size_M, + block_size_N, + ) + + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, v_idx.shape[2], s_idx.shape[2]) + + def run(is_triton: bool = True): + if is_triton: + out = _triton_mixed_sparse_attention( + query, + key, + value, + seqlens, + block_count, + block_offset, + column_count, + column_index, + sm_scale, + block_size_M, + block_size_N, + ) + else: + out = tl_kernel(query, key, value, block_count, block_offset, column_count, column_index) + return out[..., :context_size, :head_dim] + + return run + + +def sum_all_diagonal_matrix(mat: torch.tensor): + b, h, n, m = mat.shape + zero_mat = torch.zeros((b, h, n, n)).to(mat.device) # Zero matrix used for padding + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) # pads the matrix on left and right + mat_strided = mat_padded.as_strided((1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)) # Change the strides + sum_diags = torch.sum(mat_strided, 2) # Sums the resulting matrix's columns + return sum_diags[:, :, 1:] + + +def main(argv=None): + parser = argparse.ArgumentParser() + + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=16384) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--vertical_size", type=int, default=1000) + parser.add_argument("--slash_size", type=int, default=200) + + args = parser.parse_args(argv) + + BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim + + vertical_size, slash_size = args.vertical_size, args.slash_size + + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + q_len = SEQ_LEN + + vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) + last_q = 64 + qk = torch.einsum("bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) + arange = torch.arange(last_q, device="cuda") + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], qk[:, :, :, -last_q:], -torch.inf) + qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + vertical_topk = torch.topk(vertical, vertical_size, -1).indices + + slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1] + slash[..., -30:] = torch.inf + + slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices + + _attn = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash) + + tilelang_out = _attn(False) + triton_out = _attn(True) + + torch.testing.assert_close(triton_out, tilelang_out, atol=1e-2, rtol=1e-2) + + triton_time = do_bench(lambda: _attn(True)) + tilelang_time = do_bench(lambda: _attn(False)) + + print(f"triton_time: {triton_time:.3f}ms") + print(f"tilelang_time: {tilelang_time:.3f}ms") + print(f"speedup: {triton_time / tilelang_time:.2f}x") + + +if __name__ == "__main__": + main() diff --git a/examples/minference/ops/kernels.cpp b/examples/minference/ops/kernels.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f1e33976447b6dbc8ae1f591af2dc27851f0e0d --- /dev/null +++ b/examples/minference/ops/kernels.cpp @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "torch/extension.h" +#include + +std::vector convert_vertical_slash_indexes( + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, int block_size_M, int block_size_N); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("convert_vertical_slash_indexes", &convert_vertical_slash_indexes, + "dynamic sparse index function"); +} diff --git a/examples/minference/ops/vertical_slash_index.cu b/examples/minference/ops/vertical_slash_index.cu new file mode 100644 index 0000000000000000000000000000000000000000..ae01f331b1ab284e6d646aa072e7ab61bb5b3d0a --- /dev/null +++ b/examples/minference/ops/vertical_slash_index.cu @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include +#include +#include +#include + +#include + +__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[block_count++] = idx; + } +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int N_HEADS, + int N_ROWS, + int BLOCK_SIZE_M, + int BLOCK_SIZE_N, + int NNZ_V, + int NNZ_S +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int seqlen = seqlens[batch_idx]; + int block_idx_m = group_idx * blockDim.x + threadIdx.x; + int start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= seqlen) { + return; + } + int end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + int tmp_col_cnt = 0, tmp_blk_cnt = 0; + int s = 0, v = 0; + int v_idx = vertical_indexes[v++]; + int s_idx = slash_indexes[s++]; + while (s_idx >= end_m) { + s_idx = slash_indexes[s++]; + } + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + v_idx = end_m + BLOCK_SIZE_M; + } + } else { + if (s < NNZ_S) { + s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + break; + } + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int BATCH_SIZE, + int N_HEADS, + int N_ROWS, + int NNZ_V, + int NNZ_S +) { + const int BLOCK_SIZE_M = 64; + const int BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + seqlens, vertical_indexes, slash_indexes, + block_count, block_offset, column_count, column_index, + N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S + ); +} + +std::vector convert_vertical_slash_indexes( + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, + int block_size_M, + int block_size_N +) { + assert(block_size_M == 64); + assert(block_size_N == 64); + + cudaSetDevice(seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); + torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); + + convert_vertical_slash_indexes_64x64( + seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + nnz_vertical, + nnz_slash + ); + + return { block_count, block_offset, column_count, column_index }; +} diff --git a/examples/minference/ops/vertical_slash_index.hip b/examples/minference/ops/vertical_slash_index.hip new file mode 100644 index 0000000000000000000000000000000000000000..f01fd421125d6ccde89bb402c2cd9a30cb1cec20 --- /dev/null +++ b/examples/minference/ops/vertical_slash_index.hip @@ -0,0 +1,161 @@ +// !!! This is a file automatically generated by hipify!!! +#include +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include +#include +#include +#include + +#include + +__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[block_count++] = idx; + } +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int N_HEADS, + int N_ROWS, + int BLOCK_SIZE_M, + int BLOCK_SIZE_N, + int NNZ_V, + int NNZ_S +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int seqlen = seqlens[batch_idx]; + int block_idx_m = group_idx * blockDim.x + threadIdx.x; + int start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= seqlen) { + return; + } + int end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + int tmp_col_cnt = 0, tmp_blk_cnt = 0; + int s = 0, v = 0; + int v_idx = vertical_indexes[v++]; + int s_idx = slash_indexes[s++]; + while (s_idx >= end_m) { + s_idx = slash_indexes[s++]; + } + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + v_idx = end_m + BLOCK_SIZE_M; + } + } else { + if (s < NNZ_S) { + s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + break; + } + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int BATCH_SIZE, + int N_HEADS, + int N_ROWS, + int NNZ_V, + int NNZ_S +) { + const int BLOCK_SIZE_M = 64; + const int BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0, + seqlens, vertical_indexes, slash_indexes, + block_count, block_offset, column_count, column_index, + N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S + ); +} + +std::vector convert_vertical_slash_indexes( + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, + int block_size_M, + int block_size_N +) { + assert(block_size_M == 64); + assert(block_size_N == 64); + + hipSetDevice(seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); + torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); + + convert_vertical_slash_indexes_64x64( + seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + nnz_vertical, + nnz_slash + ); + + return { block_count, block_offset, column_count, column_index }; +} diff --git a/examples/minference/test_vs_sparse_attn.py b/examples/minference/test_vs_sparse_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..f01df3808f86fb17d41a7b3617f76c949534fa45 --- /dev/null +++ b/examples/minference/test_vs_sparse_attn.py @@ -0,0 +1,12 @@ +import tilelang.testing + +import example_vertical_slash_sparse_attn + + +@tilelang.testing.requires_cuda +def test_vs_sparse_attn(): + example_vertical_slash_sparse_attn.main(argv=[]) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..57bccc1a0f901ddb2ac84b0b0e2ac8f92c95480a --- /dev/null +++ b/examples/norm/rms_norm.py @@ -0,0 +1,76 @@ +import torch +import tilelang +import tilelang.language as T + + +def rms_norm_splitk(M, N, blk_m, blk_k): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, blk_k), dtype) + A_local = T.alloc_fragment((blk_m, blk_k), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + num_k_step = T.ceildiv(N, blk_k) + T.clear(A_local) + for k in range(num_k_step): + T.copy(A[bx * blk_m, k * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_local[i, j] += A_shared[i, j] * A_shared[i, j] + T.reduce_sum(A_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + + for k in range(num_k_step): + # reverse, better cache hit rate + T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_shared[i, j] *= A_powsum[i] + T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k]) + + return main + + +@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) +def rms_norm(M, N, blk_m): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, N), dtype) + A_pow_local = T.alloc_fragment((blk_m, N), dtype) + A_local = T.alloc_fragment((blk_m, N), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) + T.copy(A_shared, A_local) + for i, j in T.Parallel(blk_m, N): + A_pow_local[i, j] = A_local[i, j] * A_local[i, j] + T.reduce_sum(A_pow_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + for i, j in T.Parallel(blk_m, N): + A_local[i, j] *= A_powsum[i] + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) + + return main + + +def ref_program(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12) + + +if __name__ == "__main__": + M, N, blk_m, blk_k = 8192, 8192, 1, 512 + kernel = rms_norm(M, N, blk_m) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..53db03d98ce32ee034a1d36b9ee590c0992267ce --- /dev/null +++ b/examples/norm/test_rms_norm.py @@ -0,0 +1,74 @@ +import torch +import tilelang +import tilelang.testing +import tilelang.language as T + + +def rms_norm_splitk(M, N, blk_m, blk_k): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, blk_k), dtype) + A_local = T.alloc_fragment((blk_m, blk_k), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + num_k_step = T.ceildiv(N, blk_k) + T.clear(A_local) + for k in range(num_k_step): + T.copy(A[bx * blk_m, k * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_local[i, j] += A_shared[i, j] * A_shared[i, j] + T.reduce_sum(A_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + + for k in range(num_k_step): + # reverse, better cache hit rate + T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_shared[i, j] *= A_powsum[i] + T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k]) + + return main + + +def rms_norm(M, N, blk_m): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, N), dtype) + A_pow_local = T.alloc_fragment((blk_m, N), dtype) + A_local = T.alloc_fragment((blk_m, N), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) + T.copy(A_shared, A_local) + for i, j in T.Parallel(blk_m, N): + A_pow_local[i, j] = A_local[i, j] * A_local[i, j] + T.reduce_sum(A_pow_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + for i, j in T.Parallel(blk_m, N): + A_local[i, j] *= A_powsum[i] + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) + + return main + + +def ref_program(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12) + + +def test_rms_norm(M=1024, N=1024, blk_m=1): + program = rms_norm(M, N, blk_m) + kernel = tilelang.compile(program, out_idx=-1, pass_configs={"tl.disable_tma_lower": True}) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/online_softmax/online_softmax.py b/examples/online_softmax/online_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..811870e441b3f297c711a04179bde5cce925f126 --- /dev/null +++ b/examples/online_softmax/online_softmax.py @@ -0,0 +1,72 @@ +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench +from typing import Callable + + +@tl.jit(out_idx=[1]) +def softmax_kernel( + M, + N, + dtype: T.dtype = T.float16, +) -> "Callable": + BN = min(tl.next_power_of_2(N), 8192) + NN = tl.cdiv(N, BN) + + accum_dtype = T.float32 + + scale = 1.44269504 # log2(e) + + @T.prim_func + def main( + X: T.Tensor([M, N], dtype), + Y: T.Tensor([M, N], dtype), + ): + with T.Kernel(M, threads=128) as (i_m): + x = T.alloc_fragment([BN], dtype) + y = T.alloc_fragment([BN], dtype) + lse = T.alloc_fragment([1], accum_dtype) + max_x = T.alloc_fragment([1], dtype) + exp_x = T.alloc_fragment([BN], accum_dtype) + sum_exp_x = T.alloc_fragment([1], accum_dtype) + T.fill(lse, -T.infinity(accum_dtype)) + + for i_n in T.Pipelined(0, NN): + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) + + T.reduce_max(x, max_x, dim=0, clear=True) + + for j in T.Parallel(BN): + exp_x[j] = T.exp2(x[j] * scale - max_x[0] * scale) + + T.reduce_sum(exp_x, sum_exp_x, dim=0, clear=True) + + lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0]) + + for i_n in T.Pipelined(0, NN): + T.copy(X[i_m, i_n * BN : (i_n + 1) * BN], x) + + for j in T.Parallel(BN): + y[j] = T.exp2(x[j] * scale - lse[0]) + + T.copy(y, Y[i_m, i_n * BN : (i_n + 1) * BN]) + + return main + + +M = 8192 +N = 8192 +kernel = softmax_kernel(M, N) +dtype = torch.float16 +X = torch.randn(M, N, dtype=dtype, device="cuda") +Y = kernel(X) +Y_ref = X.softmax(dim=1) + +torch.testing.assert_close(Y, Y_ref, rtol=1e-2, atol=1e-2) + +t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100) +t2 = do_bench(lambda: kernel(X), warmup=25, rep=100) +print(f"torch latency: {t1:.3f} ms") +print(f"TileLang latency: {t2:.3f} ms") +print(f"Speedup: {t1 / t2:.3f}x") diff --git a/examples/plot_layout/README.md b/examples/plot_layout/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8204e93d804edde4a1d9bbb00366f7c7be39dae1 --- /dev/null +++ b/examples/plot_layout/README.md @@ -0,0 +1,108 @@ +The following example demonstrates how to generate and visualize a memory layout using tilelang tools `plot_layout`. + +Example Code + +```python +import tilelang.language as T +from tvm import DataType +from tvm.tir import IndexMap +from typing import Literal, Callable +from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.tools import plot_layout + +def make_mma_load_base_layout(dtype: str = T.float16, + matrix: Literal["A", "B"] = "A", + transposed: bool = False) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.intrinsics.mma_layout import ( + shared_16x16_to_mma_32x8_layout_sr, + shared_16x16_to_mma_32x8_layout_rs, + shared_16x32_to_mma_32x16_layout, + shared_32x16_to_mma_32x16_layout, + ) + assert matrix in ["A", "B"], "matrix should be either A or B" + dtype_bits = DataType(dtype).bits + assert transposed is False, "transposed is not supported yet" + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr: Callable = None + transform_func_rs: Callable = None + if dtype_bits == 16: + transform_func_sr = shared_16x16_to_mma_32x8_layout_sr + transform_func_rs = shared_16x16_to_mma_32x8_layout_rs + elif dtype_bits == 8: + transform_func_sr = shared_16x32_to_mma_32x16_layout + transform_func_rs = shared_32x16_to_mma_32x16_layout + else: + raise ValueError(f"Unsupported dtype {dtype}") + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs + + micro_size_s, _, micro_size_r = get_mma_micro_size(dtype) + + transform_func = transform_func + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +# Create a 16ร—16 matrix layout for ldmatrix operations +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) + +# Print the layout structure (optional for debugging) +print(base_layout) + +# Plot and save the layout visualization +plot_layout(base_layout, name="base_layout") +``` + +Output + +![base_layout](./images/base_layout.png) diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py new file mode 100644 index 0000000000000000000000000000000000000000..d45cc227bc2d0fcef5f1d034c0ed51f62f4c571e --- /dev/null +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -0,0 +1,127 @@ +import tilelang.language as T +from typing import Literal, Callable +from tvm.tir import IndexMap +from tilelang.intrinsics.utils import get_mma_micro_size + +from tilelang.intrinsics.mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_16x16_to_local_64x4_layout_A, + shared_16x32_to_local_64x8_layout_A, + shared_16x64_to_local_64x16_layout_A, +) + + +def make_mfma_load_base_layout( + dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False +) -> T.Fragment: + """ + Create a layout function for storing MFMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mfma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + matrix : Literal["A", "B"] + The mfma operand to be loaded. + k_dim : int + The k dimension of the mfma. + transposed : bool + Whether the matrix is transposed, by default False. + + Returns + ------- + T.Fragment + Describes how threads and indices in fragment are laid out. + + """ + + assert matrix in ["A", "B"], "matrix should be either A or B" + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + + if k_dim == 4: + transform_func_sr_a = shared_16x4_to_local_64x1_layout_A + transform_func_sr_b = shared_16x4_to_local_64x1_layout_A + elif k_dim == 16: + transform_func_sr_a = shared_16x16_to_local_64x4_layout_A + transform_func_sr_b = shared_16x16_to_local_64x4_layout_A + elif k_dim == 32: + transform_func_sr_a = shared_16x32_to_local_64x8_layout_A + transform_func_sr_b = shared_16x32_to_local_64x8_layout_A + elif k_dim == 64: + transform_func_sr_a = shared_16x64_to_local_64x16_layout_A + transform_func_sr_b = shared_16x64_to_local_64x16_layout_A + else: + raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +block_rows = 2 +block_cols = 2 +warp_rows = 2 +warp_cols = 2 +chunk = 2 + +from tilelang.tools import plot_layout + +# ldmatrix layout 16x16 +base_layout = make_mfma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) +print(base_layout) +plot_layout(base_layout, name="base_layout") + +# warp layout 32x32 +warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False) +print(warp_layout) +plot_layout(warp_layout, name="warp_layout") + +# block layout 64x32 +block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True, lower_dim_first=True).replicate(block_cols) +print(block_layout) +plot_layout(block_layout, name="block_layout") diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py new file mode 100644 index 0000000000000000000000000000000000000000..df4a0b88701192c44e9743360ad7ead14d4f0dbd --- /dev/null +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -0,0 +1,122 @@ +import tilelang.language as T +from typing import Literal, Callable +from tvm import DataType +from tvm.tir import IndexMap +from tilelang.intrinsics.utils import get_mma_micro_size + + +def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + dtype : str + The data type of the matrix. + matrix : Literal["A", "B"] + The mma operand to be loaded. + transposed : bool + Whether the matrix is transposed, by default False. + + Returns + ------- + T.Fragment + Describes how threads and indices in fragment are laid out. + + """ + from tilelang.intrinsics.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, + shared_16x8_to_mma_32x4_layout_sr_b, + shared_16x16_to_mma_32x8_layout_sr_b, + shared_16x32_to_mma_32x16_layout_sr_b, + ) + + assert matrix in ["A", "B"], "matrix should be either A or B" + dtype_bits = DataType(dtype).bits + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b + elif dtype_bits == 8: + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(matrix == "A" and not transposed) + is_sr_conditions.append(matrix == "B" and transposed) + is_sr_axis_order = any(is_sr_conditions) + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + return base_fragment + + +block_rows = 2 +block_cols = 2 +warp_rows = 4 +warp_cols = 4 +chunk = 2 + +from tilelang.tools import plot_layout + +# ldmatrix layout 16x16 +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) +print(base_layout) +plot_layout(base_layout, name="base_layout") + +# warp layout 32x16 +warp_layout = base_layout.repeat([block_rows, 1], repeat_on_thread=True).replicate(block_cols) +print(warp_layout) +plot_layout(warp_layout, name="warp_layout") + +# block layout 128x32 +block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False) +print(block_layout) +plot_layout(block_layout, name="block_layout") diff --git a/examples/plot_layout/images/base_layout.png b/examples/plot_layout/images/base_layout.png new file mode 100644 index 0000000000000000000000000000000000000000..e8ebcf8b6971170b7dc2dfd5e66168bb487b7794 Binary files /dev/null and b/examples/plot_layout/images/base_layout.png differ diff --git a/examples/pytest.ini b/examples/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..5f820048e6dfb0c195518233427628c5f4da027a --- /dev/null +++ b/examples/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +norecursedirs = bitnet-1.58b diff --git a/examples/quickstart.py b/examples/quickstart.py new file mode 100644 index 0000000000000000000000000000000000000000..e99fc0dbceff115a0569495b563764170f05fa89 --- /dev/null +++ b/examples/quickstart.py @@ -0,0 +1,87 @@ +import tilelang +import tilelang.language as T + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 1024 # M = T.dynamic("m") if you want to use dynamic shape +N = 1024 +K = 1024 +block_M = 128 +block_N = 128 +block_K = 32 + +# Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) +# Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = matmul_relu_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/examples/rand/rand_uint.py b/examples/rand/rand_uint.py new file mode 100644 index 0000000000000000000000000000000000000000..466a51b7a312643c66d02c7069db021f6f3cb036 --- /dev/null +++ b/examples/rand/rand_uint.py @@ -0,0 +1,57 @@ +import tilelang +import tilelang.language as T +import torch +import triton +import triton.language as tl + + +@tilelang.jit +def tilelang_rand_1d(M=1024, seed=42): + num_per_thread = 128 + threads = 1 + blk_M = num_per_thread * threads + + @T.prim_func + def rand_kernel(A: T.Tensor((M,), "uint32")): + with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx: + tx = T.get_thread_binding() + T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread) + for i, j in T.Parallel(threads, num_per_thread): + offsets = (bx * threads + i) * num_per_thread + idx = offsets + j + if idx < M: + A[idx] = T.rng_rand() + + return rand_kernel + + +@triton.jit +def triton_rand_1d(X, M, elements_per_thread, seed): + pid = tl.program_id(0) + offset = pid * elements_per_thread + tl.arange(0, elements_per_thread) + + r0, r1, r2, r3 = tl.randint4x(seed, offset) + + base_idx = offset * 4 + tl.store(X + base_idx, r0, mask=base_idx < M) + tl.store(X + base_idx + 1, r1, mask=(base_idx + 1) < M) + tl.store(X + base_idx + 2, r2, mask=(base_idx + 2) < M) + tl.store(X + base_idx + 3, r3, mask=(base_idx + 3) < M) + + +def test_rand_1d(M, seed): + kernel = tilelang_rand_1d(M, seed) + tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda") + kernel(tilelang_result) + + triton_result = torch.empty(M, dtype=torch.uint32, device="cuda") + grid = (triton.cdiv(M, 128),) + triton_rand_1d[grid](triton_result, tl.constexpr(M), tl.constexpr(128 // 4), seed) + + torch.testing.assert_close(tilelang_result, triton_result) + + +if __name__ == "__main__": + test_rand_1d(1024, 42) + test_rand_1d(512, 123) + test_rand_1d(128, 0) diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..25741f97cce73d6e8d9c06c9928590db5a53d68c --- /dev/null +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -0,0 +1,254 @@ +import math +import torch + +import tilelang +import tilelang.language as T +import torch.nn.functional as F + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@tilelang.jit( + out_idx=[4], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 0 + threads = 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.int8 + + def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = T.ceildiv(seq_kv, block_N) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + past_len = seq_kv - seq_q + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i + past_len >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + # Run tilelang kernel + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + # Verify accuracy + assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), "TileLang output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + +def test_topk_sparse_attention_qlen_lt_klen(): + # Config + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) + + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.float16) + sm_scale = 1.0 / (D_HEAD**0.5) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.float16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn(BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + print(kernel.get_kernel_source()) + tilelang_output = kernel(q, k, v, block_mask.to(torch.int8)) + + past_len = K_LEN - Q_LEN + + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) + + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + + attn = attn.masked_fill(~final_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + print("ref_output", ref_output) + print("tilelang_output", tilelang_output) + + # Verify accuracy. + torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2) + + print("Pass topk sparse attention test with qlen < klen") + + +def main(): + test_topk_sparse_attention() + test_topk_sparse_attention_qlen_lt_klen() + + +if __name__ == "__main__": + main() diff --git a/examples/seer_attention/block_sparse_attn_triton.py b/examples/seer_attention/block_sparse_attn_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..b4cc3cd00c854cdde0757af38dcf6302976cd49f --- /dev/null +++ b/examples/seer_attention/block_sparse_attn_triton.py @@ -0,0 +1,347 @@ +# ruff: noqa: E712 +import math +import torch + +import triton +import triton.language as tl +import torch.nn.functional as F + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + past_len, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + PAST_LEN, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None): + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + print("PAST_LEN", PAST_LEN) + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. + raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def test_topk_sparse_attention(): + # Config + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 # Keep top 8 elements per row + BLOCK = 64 + torch.manual_seed(0) + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + print("downsample_len", downsample_len) + + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + print("x_ds.shape", x_ds.shape) + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + # print("block_mask", block_mask) + print("block_mask.shape", block_mask.shape) + + # Run Triton kernel + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # print("ref_output", ref_output) + # print("triton_output", triton_output) + + # Verify accuracy + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference" + print("Pass topk sparse attention test with qlen == klen") + + +def test_topk_sparse_attention_qlt_kl(): + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 64, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128. + TOPK = 1 + BLOCK = 64 # block size used in downsampling + torch.manual_seed(0) + + # Create inputs. + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16) + # softmax scale + sm_scale = 1.0 / (D_HEAD**0.5) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension + x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16) + # Force the first column to be high so that the first block is always selected. + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + # Run Triton kernel. + triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + past_len = K_LEN - Q_LEN + + attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale + + full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool() + full_mask_full = full_mask_full[..., :K_LEN, :K_LEN] + + effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN) + + i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1) + j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN) + causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN) + + final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN) + + attn = attn.masked_fill(~final_mask, float("-inf")) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v) + + # Verify accuracy. + assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen" + + print("Pass topk sparse attention test with qlen < klen") + + +if __name__ == "__main__": + test_topk_sparse_attention() + test_topk_sparse_attention_qlt_kl() diff --git a/examples/seer_attention/test_block_sparse_attn_tilelang.py b/examples/seer_attention/test_block_sparse_attn_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..da175d05c7f71b66deb71b6785506fd98d85be54 --- /dev/null +++ b/examples/seer_attention/test_block_sparse_attn_tilelang.py @@ -0,0 +1,12 @@ +import tilelang.testing + +import block_sparse_attn_tilelang + + +@tilelang.testing.requires_cuda +def test_block_sparse_attn_tilelang(): + block_sparse_attn_tilelang.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/sparse_tensorcore/test_example_sparse_tensorcore.py b/examples/sparse_tensorcore/test_example_sparse_tensorcore.py new file mode 100644 index 0000000000000000000000000000000000000000..72292e44868dc30c7f2b6b5044a4449c5e9f559e --- /dev/null +++ b/examples/sparse_tensorcore/test_example_sparse_tensorcore.py @@ -0,0 +1,13 @@ +import tilelang.testing +import tilelang +import tilelang_example_sparse_tensorcore + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_tilelang_example_sparse_tensorcore(): + tilelang_example_sparse_tensorcore.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py new file mode 100644 index 0000000000000000000000000000000000000000..14339ff02932819d4273acc818e8da0256354dcc --- /dev/null +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -0,0 +1,117 @@ +import torch +import tilelang +from tilelang.utils.sparse import compress_sm90 +from tilelang.layout import make_cutlass_metadata_layout +from tilelang import language as T +import tilelang.testing + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) + B_shape = (K, N) + A_shared_shape = (block_M, block_K // 2) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // 8), "uint8"), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // 8), "uint8") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K), + } + ) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // 8], E_shared) + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"): + if shape[-1] % 4 != 0: + raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") + + full_tensor = torch.randn(shape, dtype=dtype, device=device) + group_count = shape[-1] // 4 + group_shape = shape[:-1] + (group_count, 4) + + rand_vals = torch.rand(group_shape, device=device) + topk_indices = rand_vals.topk(k=2, dim=-1).indices + mask = torch.zeros(group_shape, dtype=torch.bool, device=device) + mask.scatter_(-1, topk_indices, True) + mask = mask.view(shape) + + sparse_tensor = full_tensor * mask + return sparse_tensor + + +def run_gemm_sp( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_M, + block_N, + block_K, + num_stages, + num_threads, +): + kernel = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + ) + + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda") + A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) + B = torch.randn((K, N), device="cuda", dtype=torch.float16) + + C_sp = kernel(A_sparse, E, B).half() + C = torch.matmul(A, B) + torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3) + print("pass") + + +def main(): + run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128) + + +if __name__ == "__main__": + main() diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f0c8bfb28a116ca418633db3e0450d75bbf55e --- /dev/null +++ b/examples/topk/example_topk.py @@ -0,0 +1,93 @@ +import tilelang +import tilelang.language as T +import torch +import itertools +import argparse + + +def get_configs(): + iter_params = dict( + blk_m=[64, 128, 256], + threads=[128, 256, 512], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[1, 2]) +def tl_topk( + M, + N, + topk, + blk_m, + threads=128, +): + dtype = T.float32 + + @T.prim_func + def topk_kernel( + logits: T.Tensor([M, N], dtype), + topk_gates: T.Tensor([M, topk], dtype), + topk_indices: T.Tensor([M, topk], T.int32), + ): + with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx: + logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype) + max_val = T.alloc_fragment([blk_m], dtype=dtype) + expand_max_idx = T.alloc_fragment([blk_m, N], T.int32) + max_idx = T.alloc_fragment([blk_m], T.int32) + + T.copy(logits[bx * blk_m, 0], logits_frag) + + for k in T.serial(topk): + T.fill(expand_max_idx, -1) + T.reduce_max(logits_frag, max_val, dim=1, clear=True) + + for i, j in T.Parallel(blk_m, N): + expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j]) + + T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True) + + for i, j in T.Parallel(blk_m, N): + logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j]) + + for i in T.Parallel(blk_m): + topk_gates[bx * blk_m + i, k] = max_val[i] + topk_indices[bx * blk_m + i, k] = max_idx[i] + + return topk_kernel + + +def ref_program(logits, top_k): + top_k_gates, top_k_indices = logits.topk(top_k, dim=1) + + return top_k_gates, top_k_indices.to(torch.int32) + + +def main(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=320, help="num_tokens") + parser.add_argument("--N", type=int, default=128, help="num_experts") + parser.add_argument("--topk", type=int, default=6, help="topk") + parser.add_argument("--blk_m", type=int, default=64, help="blk_m") + args = parser.parse_args(argv) + M, N, topk, blk_m = args.M, args.N, args.topk, args.blk_m + + logits = torch.rand((M, N), device="cuda", dtype=torch.float32) + + kernel = tl_topk(M=M, N=N, topk=topk, blk_m=blk_m) + tl_gates, tl_indices = kernel(logits) + + torch_gates, torch_indices = ref_program(logits, topk) + + # test accuracy + torch.testing.assert_close(tl_gates, torch_gates) + torch.testing.assert_close(tl_indices, torch_indices) + + # profile + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench() + print(f"Tilelang latency: {tilelang_latency}") + + +if __name__ == "__main__": + main() diff --git a/examples/topk/test_topk_tilelang.py b/examples/topk/test_topk_tilelang.py new file mode 100644 index 0000000000000000000000000000000000000000..54de01143ffa23496ac65233155ce0f68bc28b5c --- /dev/null +++ b/examples/topk/test_topk_tilelang.py @@ -0,0 +1,11 @@ +import tilelang.testing +import example_topk + + +@tilelang.testing.requires_cuda +def test_topk_tilelang(): + example_topk.main(argv=[]) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa1eaf854ec283042b2f3a1c7d2c9d1ae1dd457 --- /dev/null +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -0,0 +1,61 @@ +import tilelang +import tilelang.language as T + + +# use pass_configs to enable layout visualization +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg", + }, +) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(128, 128, 128, 32, 32, 32) + + import torch + + a = torch.randn(128, 128).cuda().half() + b = torch.randn(128, 128).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # print the layout visualization result and save figures to ./tmp. + """ + C_local inferenced layout: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] + """ + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcd51aa7c9b885895f19aebe5bbc50f4687f14d --- /dev/null +++ b/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -0,0 +1,372 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse + + +@tilelang.jit(out_idx=[6]) +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): + scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) + dtype = T.float16 + accum_dtype = T.float32 + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + h_dim = dim // 2 + + @T.macro + def flash_attn( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + # smem_sQ + Q_shared_l = T.alloc_shared([block_H, h_dim], dtype) + Q_shared_r = T.alloc_shared([block_H, h_dim], dtype) + Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) + Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype) + Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype) + + # smem_sK0 + KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype) + KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype) + K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype) + + # smem_sK1 + KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype) + KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype) + K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype) + + # smem_sP0 + SP0_shared = T.alloc_shared([block_H, block_N], dtype) + + # smem_sP1 reuse Q_pe_shared + SP1_shared = Q_pe_shared + + # smem_sM + scores_max = T.alloc_shared([block_H], accum_dtype) + + # smem_sScale0 + scores_scale_0 = T.alloc_shared([block_H], accum_dtype) + # smem_sScale1 + scores_scale_1 = T.alloc_shared([block_H], accum_dtype) + + logsum = T.alloc_shared([block_H], accum_dtype) + + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype) + acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype) + scores_max_0 = T.alloc_fragment([block_H], accum_dtype) + scores_max_1 = T.alloc_fragment([block_H], accum_dtype) + + scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype) + + scores_sum_0 = T.alloc_fragment([block_H], accum_dtype) + scores_sum_1 = T.alloc_fragment([block_H], accum_dtype) + logsum_0 = T.alloc_fragment([block_H], accum_dtype) + logsum_1 = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = hid // (kv_group_num // block_H) + + T.annotate_layout( + { + O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l), + O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r), + } + ) + + # barriers_Q + q_shared_ready_barrier = T.alloc_barrier(arrive_count=256) + + # barriers_K0 + kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128) + # barriers_K1 + kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128) + kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128) + + # redundant barriers + score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128) + scale_1_ready_barrier = T.alloc_barrier(arrive_count=128) + p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128) + lse_0_ready_barrier = T.alloc_barrier(arrive_count=128) + lse_1_ready_barrier = T.alloc_barrier(arrive_count=128) + s_shared_ready_barrier = T.alloc_barrier(arrive_count=128) + + tx = T.get_thread_binding() + + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l) + T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r) + T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.barrier_arrive(q_shared_ready_barrier) + T.barrier_wait(q_shared_ready_barrier, 0) + + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(seqlen_kv, (block_N * 2)) + + if tx < 128: + T.copy(Q_pe_shared, Q_pe_local_0) + T.fill(acc_o_l, 0) + T.fill(logsum_0, 0) + + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.barrier_arrive(kv_shared_1_l_is_ready) + + T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.barrier_arrive(kv_shared_1_r_is_ready) + + T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1) + T.barrier_arrive(kv_shared_1_pe_is_ready) + + for k in T.serial(loop_range): + T.barrier_wait(kv_shared_0_l_is_ready, k % 2) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1) + T.barrier_wait(kv_shared_0_r_is_ready, k % 2) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1) + + T.barrier_wait(kv_shared_0_pe_is_ready, k % 2) + T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + # Step 3. + T.copy(scores_max, scores_max_0) + T.copy(scores_max_0, scores_max_prev_0) + T.fill(scores_max_0, -T.infinity(accum_dtype)) + T.reduce_max(acc_s_0, scores_max_0, dim=1, clear=False) + T.copy(scores_max_0, scores_max) + + # Step 4. + for i, j in T.Parallel(block_H, block_N): + acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale) + for i in T.Parallel(block_H): + scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale) + + T.reduce_sum(acc_s_0, scores_sum_0, dim=1) + + # Step 5. + T.copy(acc_s_0, acc_s_0_cast) + + for i, j in T.Parallel(block_H, h_dim): + acc_o_l[i, j] *= scores_scale_0[i] + + for i in T.Parallel(block_H): + logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i] + + # Step 6. + T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l) + T.barrier_arrive(score_max_0_ready_barrier) + + T.barrier_wait(scale_1_ready_barrier, k % 2) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l) + T.barrier_arrive(kv_shared_0_l_is_ready) + + # Step 11. + for i, j in T.Parallel(block_H, block_N): + SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i] + + T.barrier_arrive(p0_1_1_ready_barrier) + + # Step 13. + for i, j in T.Parallel(block_H, h_dim): + acc_o_l[i, j] *= scores_scale_1[i] + for i in T.Parallel(block_H): + logsum_0[i] = logsum_0[i] * scores_scale_1[i] + T.barrier_wait(s_shared_ready_barrier, k % 2) + + # Step 14. + T.gemm(SP1_shared, KV_shared_1_l, acc_o_l) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l) + T.barrier_arrive(kv_shared_1_l_is_ready) + + T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1) + T.barrier_arrive(kv_shared_1_pe_is_ready) + + T.copy(logsum_0, logsum) + T.barrier_arrive(lse_0_ready_barrier) + T.barrier_wait(lse_1_ready_barrier, 0) + for i, j in T.Parallel(block_H, h_dim): + acc_o_l[i, j] /= logsum[i] + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim]) + + else: + T.copy(Q_pe_shared, Q_pe_local_1) + T.fill(acc_o_r, 0) + T.fill(logsum_1, 0) + + T.copy(KV[bid, :block_N, cur_kv_head, :h_dim], KV_shared_0_l) + T.barrier_arrive(kv_shared_0_l_is_ready) + T.copy(KV[bid, :block_N, cur_kv_head, h_dim:], KV_shared_0_r) + T.barrier_arrive(kv_shared_0_r_is_ready) + T.copy(K_pe[bid, :block_N, cur_kv_head, :], K_pe_shared_0) + T.barrier_arrive(kv_shared_0_pe_is_ready) + + for k in T.serial(loop_range): + # Step 2. + T.barrier_wait(kv_shared_1_l_is_ready, k % 2) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1) + + T.barrier_wait(kv_shared_1_r_is_ready, k % 2) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1) + + T.barrier_wait(kv_shared_1_pe_is_ready, k % 2) + T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + # Step 7. + T.barrier_wait(score_max_0_ready_barrier, k % 2) + + T.copy(scores_max, scores_max_prev_1) + T.fill(scores_max_1, -T.infinity(accum_dtype)) + T.reduce_max(acc_s_1, scores_max_1, dim=1, clear=False) + T.copy(scores_max_1, scores_max) + + for i in T.Parallel(block_H): + scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale) + + # Step 8. + for i, j in T.Parallel(block_H, block_N): + acc_s_1[i, j] = T.exp2(acc_s_1[i, j] * scale - scores_max[i] * scale) + + # Step 9. + T.reduce_sum(acc_s_1, scores_sum_1, dim=1) + + for i, j in T.Parallel(block_H, h_dim): + acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i]) + + for i in T.Parallel(block_H): + logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i] + + T.barrier_arrive(scale_1_ready_barrier) + + # Step 10. compute O1 with KV_shared_1_rd + T.copy(acc_s_1, acc_s_1_cast) + T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1) + T.copy(acc_s_1_cast, SP1_shared) + T.barrier_arrive(s_shared_ready_barrier) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r) + T.barrier_arrive(kv_shared_1_r_is_ready) + + T.barrier_wait(p0_1_1_ready_barrier, k % 2) + # Step 12. + T.gemm(SP0_shared, KV_shared_0_r, acc_o_r) + + if k < loop_range - 1: + T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r) + T.barrier_arrive(kv_shared_0_r_is_ready) + + T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0) + T.barrier_arrive(kv_shared_0_pe_is_ready) + + T.barrier_wait(lse_0_ready_barrier, 0) + for i in T.Parallel(block_H): + logsum[i] += logsum_1[i] + T.barrier_arrive(lse_1_ready_barrier) + for i, j in T.Parallel(block_H, h_dim): + acc_o_r[i, j] /= logsum[i] + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:]) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, heads, dim], dtype), + Q_pe: T.Tensor([batch, heads, pe_dim], dtype), + KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), + K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), + Output: T.Tensor([batch, heads, dim], dtype), + ): + flash_attn(Q, Q_pe, KV, K_pe, Output) + + return main_no_split + + +def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): + # """ + # Inputs: + # - q (Tensor): [batch, heads, dim] + # - q_pe (Tensor): [batch, heads, pe_dim] + # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] + # - glse (Tensor): [batch, heads, num_split] + # - Output_partial (Tensor): [batch, heads, num_split, dim] + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + dim = q.shape[-1] + pe_dim = q_pe.shape[-1] + num_head_groups = q.shape[1] // kv.shape[2] + scale = (dim + pe_dim) ** 0.5 + q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim] + + q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim] + + kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim] + + k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim] + + query = torch.concat([q, q_pe], dim=-1) + key = torch.concat([kv, k_pe], dim=-1) + + scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv] + + attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv] + + out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim] + out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim] + return out + + +def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): + qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) + pv_flops = 2 * batch * heads * kv_ctx * dim + total_flops = qk_flops + pv_flops + BLOCK_N = 64 + BLOCK_H = 64 + num_split = 1 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) + print(kernel.get_kernel_source()) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + latency = profiler.do_bench(warmup=500) + print(f"Latency: {latency} ms") + print(f"TFlops: {total_flops / latency * 1e-9} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1, help="batch size") + parser.add_argument("--heads", type=int, default=128, help="q heads number") + parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number") + parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length") + parser.add_argument("--dim", type=int, default=512, help="head dim") + parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim") + args = parser.parse_args() + batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim + main(batch, heads, kv_heads, kv_ctx, dim, pe_dim) diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2aa00d929d0d91895cc98ea3668955a8883e8f --- /dev/null +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -0,0 +1,86 @@ +import tilelang +import tilelang.language as T + +tilelang.disable_cache() + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + num_stages = 2 + mbarrier_list = [128, 128] * num_stages + + @T.prim_func + def main( + A: T.Tensor[(M, K), dtype], + B: T.Tensor[(K, N), dtype], + C: T.Tensor[(M, N), dtype], + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((num_stages, block_M, block_K), dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # create mbarrier for tma + T.create_list_of_mbarrier(mbarrier_list) + + with T.ws(0): + T.clear(C_local) + + for ko in range(T.ceildiv(K, block_K)): + with T.ws(1): + T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1) + T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :]) + T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :]) + T.mbarrier_arrive(mbarrier=ko % num_stages) + with T.ws(0): + T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages) + T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local) + T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages) + + with T.ws(0): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=16384, N=16384, K=16384): + tilelang.disable_cache() + block_M = 128 + block_N = 128 + block_K = 64 + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + + print(jit_kernel.get_kernel_source()) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py new file mode 100644 index 0000000000000000000000000000000000000000..7b22784323ba327d0054a0f46e53bd7e6eb6acdc --- /dev/null +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -0,0 +1,78 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + data_is_ready = T.alloc_barrier(arrive_count=128) + compute_is_done = T.alloc_barrier(arrive_count=128) + + with T.ws(1): + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + with T.ws(0): + T.barrier_wait(compute_is_done, (ko + 1) % 2) + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.barrier_arrive(data_is_ready) + with T.ws(1): + T.barrier_wait(data_is_ready, ko % 2) + T.gemm(A_shared, B_shared, C_local) + T.barrier_arrive(compute_is_done) + + with T.ws(1): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=1024, N=1024, K=1024): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K) + + import torch + + # Create random input tensors on the GPU + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py new file mode 100644 index 0000000000000000000000000000000000000000..02d88019c7e1793c824e088f54d8cd3b3d871212 --- /dev/null +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -0,0 +1,79 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + data_is_ready = T.alloc_barrier(arrive_count=128) + compute_is_done = T.alloc_barrier(arrive_count=128) + + with T.ws(0): + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + with T.ws(1): + T.barrier_wait(compute_is_done, (ko + 1) % 2) + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.barrier_arrive(data_is_ready) + with T.ws(0): + T.barrier_wait(data_is_ready, ko % 2) + T.gemm(A_shared, B_shared, C_local) + T.barrier_arrive(compute_is_done) + + with T.ws(0): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) + + import torch + + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5468aa6eace4ca259e885db9bd33d9e6a77b459a --- /dev/null +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -0,0 +1,96 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit( + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + }, +) +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + warp_group_num = 2 + threads = 128 * warp_group_num + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") + B_shared_g0 = T.alloc_shared((block_K, block_N // warp_group_num), dtype, "shared") + B_shared_g1 = T.alloc_shared((block_K, block_N // warp_group_num), dtype, "shared") + + C_local_g0 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype) + C_local_g1 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype) + + with T.ws(1): + T.clear(C_local_g1) + with T.ws(0): + T.clear(C_local_g0) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + with T.ws(1): + T.copy(B[ko * block_K, bx * block_N], B_shared_g1) + T.gemm(A_shared, B_shared_g1, C_local_g1) + with T.ws(0): + T.copy(B[ko * block_K, bx * block_N + block_N // warp_group_num], B_shared_g0) + T.gemm(A_shared, B_shared_g0, C_local_g0) + + with T.ws(1): + T.copy(C_local_g1, C[by * block_M, bx * block_N]) + with T.ws(0): + T.copy(C_local_g0, C[by * block_M, bx * block_N + block_N // warp_group_num]) + + return main + + +def main(): + M = 128 + N = 128 + K = 64 + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K) + print(jit_kernel.get_kernel_source()) + # 3. Test the kernel in Python with PyTorch data + import torch + + # Create random input tensors on the GPU + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + print(c) + + # Reference multiplication using PyTorch + ref_c = a @ b + print(ref_c) + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py new file mode 100644 index 0000000000000000000000000000000000000000..31d156f327a6ccde2c6d1ac236b4622e60048dfe --- /dev/null +++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -0,0 +1,82 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +@tilelang.jit(out_idx=[2]) +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor[(M, K), dtype], + B: T.Tensor[(K, N), dtype], + C: T.Tensor[(M, N), dtype], + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # create mbarrier for tma + data_is_ready = T.alloc_barrier(arrive_count=128) + compute_is_done = T.alloc_barrier(arrive_count=128) + + with T.ws(0): + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2): + with T.ws(1): + T.barrier_wait(compute_is_done, (ko + 1) % 2) + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.barrier_arrive(data_is_ready) + with T.ws(0): + T.barrier_wait(data_is_ready, ko % 2) + T.gemm(A_shared, B_shared, C_local) + T.barrier_arrive(compute_is_done) + + with T.ws(0): + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(M=16384, N=16384, K=16384): + block_M = 128 + block_N = 128 + block_K = 64 + + jit_kernel = matmul(M, N, K, block_M, block_N, block_K) + + # 3. Test the kernel in Python with PyTorch data + import torch + + # Create random input tensors on the GPU + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Run the kernel through the Profiler + c = jit_kernel(a, b) + + # Reference multiplication using PyTorch + ref_c = a @ b + + # Validate correctness + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + # 4. Retrieve and inspect the generated CUDA source (optional) + # cuda_source = jit_kernel.get_kernel_source() + # print("Generated CUDA kernel:\n", cuda_source) + + # 5.Profile latency with kernel + profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + latency = profiler.do_bench() + + print(f"Latency: {latency} ms") + + +if __name__ == "__main__": + main() diff --git a/examples/warp_specialize/test_example_warp_specialize.py b/examples/warp_specialize/test_example_warp_specialize.py new file mode 100644 index 0000000000000000000000000000000000000000..dee507790b129d462d614d183dcd322910ff0f5d --- /dev/null +++ b/examples/warp_specialize/test_example_warp_specialize.py @@ -0,0 +1,42 @@ +import tilelang.testing + +import example_warp_specialize_gemm_barrierpipe_stage2 +import example_warp_specialize_gemm_copy_0_gemm_1 +import example_warp_specialize_gemm_copy_1_gemm_0 +import example_warp_specialize_gemm_softpipe_stage2 + +# TODO: skip for now as non-deterministic on H20 +# CC @cunxiao +# @tilelang.testing.requires_cuda +# @tilelang.testing.requires_cuda_compute_version_eq(9, 0) +# def test_example_warp_specialize_flashmla(): +# import example_warp_specialize_flashmla +# example_warp_specialize_flashmla.main() + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_barrierpipe_stage2(): + example_warp_specialize_gemm_barrierpipe_stage2.main(M=1024, N=1024, K=1024) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_copy_0_gemm_1(): + example_warp_specialize_gemm_copy_0_gemm_1.main(M=1024, N=1024, K=1024) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_copy_1_gemm_0(): + example_warp_specialize_gemm_copy_1_gemm_0.main(M=1024, N=1024, K=1024) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_warp_specialize_gemm_softpipe_stage2(): + example_warp_specialize_gemm_softpipe_stage2.main(M=1024, N=1024, K=1024) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/format.sh b/format.sh new file mode 100755 index 0000000000000000000000000000000000000000..3cc4390dbe2a3b33e928c1c52f546dbf2cdc21bf --- /dev/null +++ b/format.sh @@ -0,0 +1,183 @@ +#!/usr/bin/env bash +# Usage: +# # Do work and commit your work. +# +# # Format files that differ from origin/main. +# bash format.sh +# +# # Format all files. +# bash format.sh --all +# +# +# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. + +# Cause the script to exit if a single command fails +set -eo pipefail + +if [[ -z "${BASH_VERSION}" ]]; then + echo "Please run this script using bash." >&2 + exit 1 +fi + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +ALL_FILES='' +ONLY_CHANGED='' +FILES=() +if (($# == 0)); then + # Default: allow dirty workspace; run on changed files (committed + worktree) + ONLY_CHANGED='true' +else + while (($# > 0)); do + case "$1" in + --files) + shift + while (($# > 0)); do + FILES+=("$1") + shift + done + ;; + --all) + ALL_FILES='true' + shift + ;; + *) + echo "Unknown argument: '$1'" >&2 + exit 1 + ;; + esac + done +fi + +MERGE_BASE="" +get_merge_base() { + UPSTREAM_REPO="https://github.com/tile-ai/tilelang" + if git ls-remote --exit-code "${UPSTREAM_REPO}" main &>/dev/null; then + # First try to use the upstream repository directly + MERGE_BASE="$(git fetch "${UPSTREAM_REPO}" main &>/dev/null && git merge-base FETCH_HEAD HEAD)" + elif git show-ref --verify --quiet refs/remotes/origin/main; then + # Fall back to origin/main if available + BASE_BRANCH="origin/main" + MERGE_BASE="$(git merge-base "${BASE_BRANCH}" HEAD)" + else + # Last resort, use local main + BASE_BRANCH="main" + MERGE_BASE="$(git merge-base "${BASE_BRANCH}" HEAD)" + fi + echo "${MERGE_BASE}" +} + +if [[ -n "${ALL_FILES}" ]]; then + echo "Checking all files..." >&2 +elif [[ -n "${ONLY_CHANGED}" ]]; then + MERGE_BASE="$(get_merge_base)" + echo "Checking changed files vs merge base (${MERGE_BASE}) and working tree..." >&2 +elif [[ "${#FILES[@]}" -gt 0 ]]; then + echo "Checking specified files: ${FILES[*]}..." >&2 +fi + +# Some systems set pip's default to --user, which breaks isolated virtualenvs. +export PIP_USER=0 + +# If pre-commit is not installed, install it. +if ! python3 -m pre_commit --version &>/dev/null; then + python3 -m pip install pre-commit --user +fi + +echo 'tile-lang pre-commit: Check Start' + +if [[ -n "${ALL_FILES}" ]]; then + python3 -m pre_commit run --all-files +elif [[ -n "${ONLY_CHANGED}" ]]; then + # Collect changed files (committed since merge-base + current worktree) + CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" 2>/dev/null || true)" + if [[ -n "${CHANGED_FILES}" ]]; then + echo "Running pre-commit on changed files:" + echo "${CHANGED_FILES}" + # Convert newline-separated files to space-separated and run pre-commit once + CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')" + python3 -m pre_commit run --files ${CHANGED_FILES_SPACE} + else + echo "No files changed relative to merge base and worktree. Skipping pre-commit." + fi +elif [[ "${#FILES[@]}" -gt 0 ]]; then + python3 -m pre_commit run --files "${FILES[@]}" +fi + +echo 'tile-lang pre-commit: Done' + +echo 'tile-lang clang-tidy: Check Start' +# If clang-tidy is available, run it; otherwise, skip +if [[ -x "$(command -v run-clang-tidy)" ]]; then + # Check if clang-tidy is available + if [[ ! -x "$(command -v clang-tidy)" ]]; then + python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user + fi + # Get clang-tidy version + CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')" + echo "Using clang-tidy version: ${CLANG_TIDY_VERSION}" + + # Check if build directory exists + if [[ ! -d "${ROOT}/build" ]]; then + echo "Build directory not found. Skipping clang-tidy checks." + else + # Run clang-tidy on specified files + clang_tidy_files() { + run-clang-tidy -j 64 "$@" -p build + } + + # Run clang-tidy on all C/C++ source files + clang_tidy_all() { + run-clang-tidy -j 64 src/*.cc -p build + } + + # Run clang-tidy on changed C/C++ files relative to main + clang_tidy_changed() { + # Get changed C/C++ files + CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' 2>/dev/null || true)" + + if [[ -n "${CHANGED_FILES}" ]]; then + echo "Running clang-tidy on changed files:" + echo "${CHANGED_FILES}" + # Convert newline-separated files to space-separated and run clang-tidy once + CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')" + run-clang-tidy -j 64 ${CHANGED_FILES_SPACE} -p build -fix + else + echo "No C/C++ files changed. Skipping clang-tidy." + fi + } + + if [[ -n "${ALL_FILES}" ]]; then + # If --all is given, run clang-tidy on all source files + clang_tidy_all + elif [[ -n "${ONLY_CHANGED}" ]]; then + # Otherwise, run clang-tidy only on changed C/C++ files + clang_tidy_changed + elif [[ "${#FILES[@]}" -gt 0 ]]; then + # If --files is given, run clang-tidy only on the provided files + clang_tidy_files "${FILES[@]}" + fi + fi + +else + echo "run-clang-tidy not found. Skipping clang-tidy checks." + echo "To install clang-tidy tools, you may need to install clang-tidy and run-clang-tidy." +fi +echo 'tile-lang clang-tidy: Done' + +# Check if there are any uncommitted changes after all formatting steps. +# If there are, ask the user to review and stage them. +if ! git diff --quiet &>/dev/null; then + echo 'Reformatted files. Please review and stage the changes.' + echo 'Changes not staged for commit:' + echo + git --no-pager diff --name-only + + exit 1 +fi + +echo 'tile-lang: All checks passed' diff --git a/images/MatmulExample.png b/images/MatmulExample.png new file mode 100644 index 0000000000000000000000000000000000000000..555ae30a75b2486bffb8acf27f72802d2c96ec3d Binary files /dev/null and b/images/MatmulExample.png differ diff --git a/images/MatmulExample.svg b/images/MatmulExample.svg new file mode 100644 index 0000000000000000000000000000000000000000..6e20daf554d6ebf18bb28af827f8822238861cf2 Binary files /dev/null and b/images/MatmulExample.svg differ diff --git a/images/logo-row.svg b/images/logo-row.svg new file mode 100644 index 0000000000000000000000000000000000000000..633243f3a9a003a903b859e8d8da5273b0f4cbf3 Binary files /dev/null and b/images/logo-row.svg differ diff --git a/images/mha_performance_h100.png b/images/mha_performance_h100.png new file mode 100644 index 0000000000000000000000000000000000000000..54c7cf94bf632badc2732716891b808b649de68d Binary files /dev/null and b/images/mha_performance_h100.png differ diff --git a/images/op_benchmark_a100_wq_gemv.png b/images/op_benchmark_a100_wq_gemv.png new file mode 100644 index 0000000000000000000000000000000000000000..c31c80e50f8cb792aa637f380e524f4e190d3894 Binary files /dev/null and b/images/op_benchmark_a100_wq_gemv.png differ diff --git a/images/op_benchmark_consistent_gemm_fp16.png b/images/op_benchmark_consistent_gemm_fp16.png new file mode 100644 index 0000000000000000000000000000000000000000..840e423e7199a96e8127cfe2750f7ebb60058bb3 Binary files /dev/null and b/images/op_benchmark_consistent_gemm_fp16.png differ diff --git a/images/op_benchmark_h100.png b/images/op_benchmark_h100.png new file mode 100644 index 0000000000000000000000000000000000000000..3480ec522c90475c67db341e78c2d4b28b6f7c83 Binary files /dev/null and b/images/op_benchmark_h100.png differ diff --git a/images/op_benchmark_mi300_fp16_gemm_normalized_latency.png b/images/op_benchmark_mi300_fp16_gemm_normalized_latency.png new file mode 100644 index 0000000000000000000000000000000000000000..90839aea728155fc51f944e04b37962b78e9f8c2 Binary files /dev/null and b/images/op_benchmark_mi300_fp16_gemm_normalized_latency.png differ diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..44441cdeb7978a2d441c08df6ee843970b3ab840 --- /dev/null +++ b/maint/gemm_v2/correctness_evaluation.py @@ -0,0 +1,739 @@ +# pytest correctness_evaluation.py -n 32 +import pytest +from tilelang import tvm as tvm +import tilelang.testing +from tilelang import language as T +import torch + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def _compile_and_check( + program, + trans_A, + trans_B, + in_dtype, + out_dtype, +): + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, + }, + ) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print("assert_allclose") + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul_sr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul_rr( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +M_VALUES = [64, 128, 256] +N_VALUES = [16, 32, 64, 128, 256, 512] +K_VALUES = [16, 32, 64, 128] +K_VALUES_8Bit = [32, 64, 128] +FALSE_TRUE_CASES = ( + [ + pytest.param( + k, + T.float16, + T.float16, + T.float16, + id=f"K{k}-float16-float16-float16", + ) + for k in K_VALUES + ] + + [ + pytest.param( + k, + T.int8, + T.int32, + T.int32, + id="K32-int8-int32-int32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + T.float8_e5m2, + T.float32, + T.float32, + id="K32-float8_e5m2-float32-float32", + ) + for k in K_VALUES_8Bit + ] + + [ + pytest.param( + k, + T.float8_e4m3fn, + T.float32, + T.float32, + id="K32-float8_e4m3-float32-float32", + ) + for k in K_VALUES_8Bit + ] +) + + +def _ensure_torch_dtypes(*dtype_names): + import torch + + for name in set(dtype_names): + if not hasattr(torch, name): + pytest.skip(f"Torch does not expose dtype {name}") + + +def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k) + + +def run_gemm_rs_false_false(m, n, k): + run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rs_true_false(m, n, k): + run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rs_true_true(m, n, k): + run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k) + + +def run_gemm_sr_false_false(m, n, k): + run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_sr_true_false(m, n, k): + run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_sr_true_true(m, n, k): + run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k) + + +def run_gemm_rr_false_false(m, n, k): + run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rr_true_false(m, n, k): + run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) + + +def run_gemm_rr_true_true(m, n, k): + run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) + + +TRANS_CASES = [ + pytest.param(False, False, id="nn"), + pytest.param(False, True, id="nt"), + pytest.param(True, False, id="tn"), + pytest.param(True, True, id="tt"), +] + + +@pytest.fixture(scope="module", autouse=True) +def _setup_tilelang_environment(): + tilelang.disable_cache() + tilelang.testing.set_random_seed(42) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + import torch + + required_torch_attrs = { + in_dtype, + out_dtype, + accum_dtype, + } + for attr in required_torch_attrs: + if not hasattr(torch, attr): + pytest.skip(f"Torch does not expose dtype {attr}") + run_gemm( + m, + n, + k * 3, + False, + True, + in_dtype, + out_dtype, + accum_dtype, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_false_false(m, n, k): + run_gemm( + m, + n, + k * 3, + False, + False, + T.float16, + T.float16, + T.float16, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_true_false(m, n, k): + run_gemm( + m, + n, + k * 3, + True, + False, + T.float16, + T.float16, + T.float16, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_true_true(m, n, k): + run_gemm( + m, + n, + k * 3, + True, + True, + T.float16, + T.float16, + T.float16, + m, + n, + k, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_false_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_true_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_true_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_true_true(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_true_true(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_sr_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_sr_false_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_sr_true_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_sr_true_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_sr_true_true(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_sr_true_true(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rr_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rr_false_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rr_true_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rr_true_false(m, n, k) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rr_true_true(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rr_true_true(m, n, k) + + +if __name__ == "__main__": + tilelang.testing.main() + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} False False =============================") + # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} True False =============================") + # run_gemm(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m}, {n} {k} Pass") + # print(f"Test {n} Pass") + + # # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} True True =============================") + # run_gemm(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m}, {n} {k} Pass") + # print(f"Test {n} Pass") + + # Test Pass + # for m in [64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm_rs(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) + # print(f"Test {64} {n} {k} Pass") + + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) + # print(f"Test {64} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_sm70.py b/maint/gemm_v2/correctness_evaluation_sm70.py new file mode 100644 index 0000000000000000000000000000000000000000..606d10261100e617cdd3360823d41aa16261004e --- /dev/null +++ b/maint/gemm_v2/correctness_evaluation_sm70.py @@ -0,0 +1,350 @@ +# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32 +import pytest +from tilelang import tvm as tvm +import tilelang.testing +from tilelang import language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def _compile_and_check( + program, + trans_A, + trans_B, + in_dtype, + out_dtype, +): + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + # tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False, + }, + ) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print("assert_allclose") + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn") + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn") + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B) + # T.gemm(A_frag, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +M_VALUES = [64, 128] +N_VALUES = [32, 64, 128] +K_VALUES = [16, 32, 64] +FALSE_TRUE_CASES = [ + pytest.param( + k, + T.float16, + T.float16, + T.float16, + id=f"K{k}-float16-float16-float16", + ) + for k in K_VALUES +] + [ + pytest.param( + k, + T.float16, + T.float16, + T.float32, + id=f"K{k}-float16-float16-float32", + ) + for k in K_VALUES +] + + +def _ensure_torch_dtypes(*dtype_names): + import torch + + for name in set(dtype_names): + if not hasattr(torch, name): + pytest.skip(f"Torch does not expose dtype {name}") + + +def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128) + + +def run_gemm_rs_false_false(m, n, k): + run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + + +TRANS_CASES = [ + pytest.param(False, False, id="nn"), + pytest.param(False, True, id="nt"), + pytest.param(True, False, id="tn"), + pytest.param(True, True, id="tt"), +] + + +@pytest.fixture(scope="module", autouse=True) +def _setup_tilelang_environment(): + tilelang.disable_cache() + tilelang.testing.set_random_seed(42) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + import torch + + required_torch_attrs = { + in_dtype, + out_dtype, + accum_dtype, + } + for attr in required_torch_attrs: + if not hasattr(torch, attr): + pytest.skip(f"Torch does not expose dtype {attr}") + run_gemm( + m, + n, + k * 3, + False, + True, + in_dtype, + out_dtype, + accum_dtype, + m, + n, + k, + 2, + 128, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_false_false(m, n, k): + run_gemm( + m, + n, + k * 3, + False, + False, + T.float16, + T.float16, + T.float16, + m, + n, + k, + 2, + 128, + ) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + _ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype) + run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype) + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") +def test_gemm_rs_false_false(m, n, k): + _ensure_torch_dtypes(T.float16) + run_gemm_rs_false_false(m, n, k) + + +if __name__ == "__main__": + tilelang.testing.main() + + # # Test Pass + # for m in [64, 128]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64]: + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [64, 128]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64]: + # print(f"======================= Test {m} {n} {k} False False =============================") + # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9728182b2606b6398e470639c1e049c0bd7ec6 --- /dev/null +++ b/maint/gemm_v2/correctness_evaluation_tcgen05.py @@ -0,0 +1,218 @@ +# pytest correctness_evaluation.py -n 32 +import pytest +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def _compile_and_check( + program, + trans_A, + trans_B, + in_dtype, + out_dtype, +): + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + print(kernel.get_kernel_source()) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + if in_dtype == T.float32: + A = (A.view(torch.int32) - 0x1000).view(torch.float32) + B = (B.view(torch.int32) - 0x1000).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print("assert_allclose") + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=2, + num_threads=128, +): + if block_N >= 256 or block_M >= 256 or block_K >= 256: + num_stages = 0 + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + num_stages, + num_threads, + ) + + _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype) + + +M_VALUES = [32, 64, 128, 256] +N_VALUES = [64, 128, 256, 512] +K_VALUES = [16, 32, 64, 128] +K_VALUES_8Bit = [32, 64, 128] +FALSE_TRUE_CASES = [ + pytest.param( + k, + T.float16, + T.float32, + T.float32, + id=f"K{k}-float16-float-float", + ) + for k in K_VALUES +] + [ + pytest.param( + k, + T.float8_e5m2, + T.float32, + T.float32, + id="K32-float8_e5m2-float32-float32", + ) + for k in K_VALUES_8Bit +] + +TRANS_CASES = [ + pytest.param(False, True, id="nt"), +] + + +@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}") +@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") +@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES) +def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): + import torch + + required_torch_attrs = { + in_dtype, + out_dtype, + accum_dtype, + } + for attr in required_torch_attrs: + if not hasattr(torch, attr): + pytest.skip(f"Torch does not expose dtype {attr}") + run_gemm( + m, + n, + k * 3, + False, + True, + in_dtype, + out_dtype, + accum_dtype, + m, + n, + k, + ) + + +if __name__ == "__main__": + tilelang.testing.main() + + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [16, 32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [32, 64, 128]: + # for k in [16, 32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256) + # print(f"Test {m} {n} {k} Pass") + + # # Test Pass + # for m in [32, 64, 128, 256]: + # for n in [16, 32, 64, 128]: + # for k in [32, 64, 128]: + # if m in [32, 64] and (n not in [64, 128, 256]): + # continue + # print(f"======================= Test {m} {n} {k} False True =============================") + # run_gemm(m, n, k * 3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128) diff --git a/maint/gemm_v2/latency.py b/maint/gemm_v2/latency.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b2a2af95971456aec80ebab4117fd7cff81edd --- /dev/null +++ b/maint/gemm_v2/latency.py @@ -0,0 +1,98 @@ +import tilelang +import tilelang.language as T +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--use_v2", action="store_true") +args = parser.parse_args() + +use_v2 = args.use_v2 + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + if use_v2: + T.gemm_v2(A_shared, B_shared, C_local) + else: + T.gemm_v1(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 16384 # M = T.dynamic("m") if you want to use dynamic shape +N = 16384 +K = 16384 +block_M = 128 +block_N = 128 +block_K = 32 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) + +# 3. Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0450e0230e18c053b9114642a7931b616926ba --- /dev/null +++ b/maint/gemm_v2/latency_gemm.py @@ -0,0 +1,98 @@ +import tilelang +import tilelang.language as T +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--use_v2", action="store_true") +args = parser.parse_args() + +use_v2 = args.use_v2 + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". +# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + if use_v2: + T.gemm_v2(A_shared, B_shared, C_local) + else: + T.gemm_v1(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 16384 # M = T.dynamic("m") if you want to use dynamic shape +N = 16384 +K = 16384 +block_M = 128 +block_N = 128 +block_K = 64 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) + +# 3. Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel through the Profiler +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5.Profile latency with kernel +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py new file mode 100644 index 0000000000000000000000000000000000000000..7a83d7cec8ad840e09dc2ae5c6da6ef156760fb7 --- /dev/null +++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py @@ -0,0 +1,228 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + +parser = argparse.ArgumentParser() +parser.add_argument("--batch", type=int, default=128, help="batch size") +parser.add_argument("--heads", type=int, default=16, help="heads") +parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length") +parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length") +parser.add_argument("--dim", type=int, default=256, help="dim") +parser.add_argument("--is_causal", action="store_true", help="causal") +parser.add_argument("--tune", action="store_true", help="tune configs") +parser.add_argument("--use_v2", action="store_true") + +args = parser.parse_args() + +use_v2 = args.use_v2 + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = T.float16 + accum_dtype = T.float32 + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + if use_v2: + T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + else: + T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared) + # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if use_v2: + T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + else: + T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) + if is_causal + else T.ceildiv(seq_kv, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128) + print(kernel.get_kernel_source()) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print(f"Ref: {latency:.2f} ms") + print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops") + latency = profiler.do_bench(warmup=500) + print(f"Tile-lang: {latency:.2f} ms") + print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops") + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + tilelang.disable_cache() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/maint/host_checks/01_num_args_mismatch.py b/maint/host_checks/01_num_args_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..9528652eea985b6c1a57c80897329cd3c1eafef4 --- /dev/null +++ b/maint/host_checks/01_num_args_mismatch.py @@ -0,0 +1,22 @@ +"""Reproduce: Argument count mismatch. + +Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output. +Calling with the wrong number of inputs raises a ValueError before host entry. +""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + # Missing b + # Expected: ValueError with message about expected vs. actual inputs + fn(a) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/02_pointer_type_error.py b/maint/host_checks/02_pointer_type_error.py new file mode 100644 index 0000000000000000000000000000000000000000..188a4f8cc02254a47665f82b538a0210fa51d768 --- /dev/null +++ b/maint/host_checks/02_pointer_type_error.py @@ -0,0 +1,23 @@ +"""Reproduce: Pointer-type argument expected but scalar provided. + +We pass an integer for A; wrapper forwards it to the host where a pointer is expected. +Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param). +""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 256 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # Wrong type for A (int instead of tensor) + a = 1 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/03_ndim_mismatch.py b/maint/host_checks/03_ndim_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..76637e8deda8a52165209c026661e45a2cf6da75 --- /dev/null +++ b/maint/host_checks/03_ndim_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: ndim (rank) mismatch for A.""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A has rank 3 instead of 2 + a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/04_dtype_mismatch.py b/maint/host_checks/04_dtype_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..f3554c1d6ace76971b1b787d8febfb7ce286dc09 --- /dev/null +++ b/maint/host_checks/04_dtype_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: dtype mismatch for A (float32 vs expected float16).""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + print(fn.get_host_source()) + + a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16 + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/05_shape_mismatch.py b/maint/host_checks/05_shape_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..a48248176501d66c2a46207da7357d22b8519f7c --- /dev/null +++ b/maint/host_checks/05_shape_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: shape constant/symbol mismatch on A.""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + # A's second dimension is wrong (K+1 instead of K) + a = torch.empty((M, K + 1), device="cuda", dtype=torch.float16) + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/06_strides_mismatch.py b/maint/host_checks/06_strides_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..7e523cd64ee2f9e2c762f8813901db4b50938a74 --- /dev/null +++ b/maint/host_checks/06_strides_mismatch.py @@ -0,0 +1,19 @@ +"""Reproduce: strides check failure (non-contiguous A via transpose).""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 128 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda", dtype=torch.float16) + a_nc = a.t() # non-contiguous after transpose + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a_nc, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/07_device_type_mismatch.py b/maint/host_checks/07_device_type_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..af8e5efd5dfcf9514de445aec8eb8bedd676f1f0 --- /dev/null +++ b/maint/host_checks/07_device_type_mismatch.py @@ -0,0 +1,18 @@ +"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cpu", dtype=torch.float16) + b = torch.empty((K, N), device="cpu", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/08_device_id_mismatch.py b/maint/host_checks/08_device_id_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..280aca1570d9200aaca9f21260c6270c760e8c5b --- /dev/null +++ b/maint/host_checks/08_device_id_mismatch.py @@ -0,0 +1,25 @@ +"""Reproduce: device_id mismatch (requires >=2 CUDA devices).""" + +import torch +from common import build_matmul_kernel + + +def main(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + if torch.cuda.device_count() < 2: + print("[SKIP] Need at least 2 CUDA devices to reproduce device_id mismatch.") + return + + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = torch.empty((M, K), device="cuda:0", dtype=torch.float16) + b = torch.empty((K, N), device="cuda:1", dtype=torch.float16) + # Output device is derived by the adapter; mismatch occurs in host checks + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/09_null_data_pointer.py b/maint/host_checks/09_null_data_pointer.py new file mode 100644 index 0000000000000000000000000000000000000000..09f5de1aff04438b3f24113c06979c0bd1c61dd5 --- /dev/null +++ b/maint/host_checks/09_null_data_pointer.py @@ -0,0 +1,26 @@ +"""Reproduce: NULL data pointer (advanced). + +Passing None for a tensor argument will be forwarded through the adapter. Depending on +FFI handling, this commonly triggers a pointer-type assertion (e.g., "Expect buffer to be pointer or tensor") +or a host-side non-NULL pointer check. + +Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script +demonstrates passing None, which still reproduces the intended class of failure. +""" + +import torch +from common import build_matmul_kernel + + +def main(): + M = N = K = 64 + fn = build_matmul_kernel(M, N, K, target="cuda") + + a = None # attempt to pass a null-like pointer + b = torch.empty((K, N), device="cuda", dtype=torch.float16) + + fn(a, b) + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/10_scalar_type_mismatch.py b/maint/host_checks/10_scalar_type_mismatch.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2c90b8d1df8a51eb0bdffa94e2223ecbcf47fb --- /dev/null +++ b/maint/host_checks/10_scalar_type_mismatch.py @@ -0,0 +1,15 @@ +"""Reproduce: scalar parameter type mismatch (int/bool).""" + +from common import build_scalar_check_kernel + + +def main(): + fn = build_scalar_check_kernel(target="cuda") + + # Wrong types + fn(1.0, True) # x should be int -> Expect arg[0] to be int + fn(1, 2.5) # flag should be bool -> Expect arg[1] to be boolean + + +if __name__ == "__main__": + main() diff --git a/maint/host_checks/README.md b/maint/host_checks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ac23d6fd2ad5a44a3d4fcb98486ebf500375ce39 --- /dev/null +++ b/maint/host_checks/README.md @@ -0,0 +1,21 @@ +# Host-Side Check Repro Scripts + +This folder contains standalone scripts that deliberately trigger host-side (and adapter-side) validation errors described in `docs/compiler_internals/tensor_checks.md`. Each script can be run directly and will reproduce the corresponding error with a minimal example. + +Prerequisites +- CUDA-capable environment (most scripts compile a CUDA-targeted kernel) +- Python packages: torch, tilelang + +Usage +- Run any script, e.g.: + - `python 01_num_args_mismatch.py` + - `python 02_pointer_type_error.py` + - ... up to `10_scalar_type_mismatch.py` + +- Or run all at once with a summary: + - `python run_all.py` + - Logs per test are saved under `logs/` as `