Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
...@@ -4,7 +4,9 @@ ExtraArgs: [] ...@@ -4,7 +4,9 @@ ExtraArgs: []
FormatStyle: file FormatStyle: file
UseColor: true UseColor: true
WarningsAsErrors: '*' WarningsAsErrors: '*'
# FIXME: Use `ExcludeHeaderFilterRegex` instead when all maintainers upgraded their `clang-tidy`
HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*' HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*'
# ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$'
# NOTE: there must be no spaces before the '-', so put the comma last. # NOTE: there must be no spaces before the '-', so put the comma last.
Checks: >- Checks: >-
......
blank_issues_enabled: false blank_issues_enabled: true
...@@ -40,7 +40,7 @@ jobs: ...@@ -40,7 +40,7 @@ jobs:
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: recursive submodules: recursive
...@@ -93,7 +93,7 @@ jobs: ...@@ -93,7 +93,7 @@ jobs:
name: self-hosted-amd name: self-hosted-amd
# Format: [Nightly-]ROCm-<major>.<minor>[.<patch>]. E.g., "ROCm-6.4" or "Nightly-ROCm-7.0". # Format: [Nightly-]ROCm-<major>.<minor>[.<patch>]. E.g., "ROCm-6.4" or "Nightly-ROCm-7.0".
# Use "Nightly-" prefix to use torch nightly builds. # Use "Nightly-" prefix to use torch nightly builds.
toolkit: ROCm-6.3 toolkit: Nightly-ROCm-7.1
- tags: [macos-latest] - tags: [macos-latest]
name: macos-latest name: macos-latest
toolkit: Metal # or Nightly-Metal toolkit: Metal # or Nightly-Metal
...@@ -104,7 +104,7 @@ jobs: ...@@ -104,7 +104,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: recursive submodules: recursive
...@@ -288,35 +288,59 @@ jobs: ...@@ -288,35 +288,59 @@ jobs:
echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure." echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure."
uv cache clean uv cache clean
- name: Enable core dump generation (Linux / GitHub-hosted runners)
if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kernel.core_uses_pid=0
sudo sysctl -w fs.suid_dumpable=1
sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable
- name: Enable core dump generation (macOS / GitHub-hosted runners)
if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kern.coredump=1
sudo sysctl -w kern.sugid_coredump=1
sysctl kern.corefile kern.coredump kern.sugid_coredump
- name: Install project (wheel form)
run: |
uv pip install -v .
- name: Run clang-tidy - name: Run clang-tidy
id: clang-tidy id: clang-tidy
if: runner.os == 'Linux' if: runner.os == 'Linux'
run: | run: |
echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version
if [[ -x "$(command -v run-clang-tidy)" ]]; then # Download run-clang-tidy script
echo "Using run-clang-tidy from $(command -v run-clang-tidy)" RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py
CLANG_TIDY=(run-clang-tidy) echo "Downloading run-clang-tidy script from ${RCT_URL}"
else echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script -
RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py RUN_CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py)
echo "Downloading run-clang-tidy script from ${RCT_URL}"
echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script -
CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py)
fi
if [[ -x "$(command -v clang-apply-replacements)" ]]; then if [[ -x "$(command -v clang-apply-replacements)" ]]; then
echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)" echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)"
CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)") RUN_CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)")
else else
echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled." echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled."
fi fi
# Run cmake to create the build directory with compile_commands.json # Run cmake to create the build directory with compile_commands.json
cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here
echo "::group::compile_commands.json"
ls -alh cmake-build/compile_commands.json
uv run --no-project -m -- json.tool --no-ensure-ascii cmake-build/compile_commands.json
echo "::endgroup::"
CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h") CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h")
rc=0 rc=0
"${CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \ echo "::group::run-clang-tidy"
"${RUN_CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \
-exclude-header-filter='^(3rdparty|tvm)/.*$' \
-p="cmake-build" ${CXX_FILES} || rc="$?" -p="cmake-build" ${CXX_FILES} || rc="$?"
echo "::endgroup::"
rm -rf cmake-build run-clang-tidy.py rm -rf cmake-build run-clang-tidy.py
if (( rc != 0 )); then if (( rc != 0 )); then
echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them." echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them."
...@@ -324,26 +348,6 @@ jobs: ...@@ -324,26 +348,6 @@ jobs:
exit "${rc}" exit "${rc}"
fi fi
- name: Enable core dump generation (Linux / GitHub-hosted runners)
if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kernel.core_uses_pid=0
sudo sysctl -w fs.suid_dumpable=1
sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable
- name: Enable core dump generation (macOS / GitHub-hosted runners)
if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kern.coredump=1
sudo sysctl -w kern.sugid_coredump=1
sysctl kern.corefile kern.coredump kern.sugid_coredump
- name: Install project (wheel form)
run: |
uv pip install -v .
- name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) - name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
if: contains(matrix.runner.toolkit, 'CUDA') if: contains(matrix.runner.toolkit, 'CUDA')
run: | run: |
...@@ -366,8 +370,27 @@ jobs: ...@@ -366,8 +370,27 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
) )
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
--ignore=./python/jit/test_tilelang_jit_cutedsl.py \
./python ./python
# CuTeDSL JIT tests require GEMM v1 (must be set before importing tilelang).
# Run them in a dedicated step to avoid changing the default GEMM selection
# (and to keep the rest of the CUDA tests on GEMM v2).
- name: Run CuTeDSL JIT tests (GEMM v1) with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: cutedsl-tests
if: contains(matrix.runner.toolkit, 'CUDA')
env:
TILELANG_USE_GEMM_V1: "1"
run: |
cd testing
PYTEST=(
uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
# Avoid xdist contention on a single GPU by running this file in one worker.
"${PYTEST[@]}" --maxfail=3 --numprocesses=1 \
./python/jit/test_tilelang_jit_cutedsl.py
# AMD ROCm tests # AMD ROCm tests
- name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) - name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: rocm-tests id: rocm-tests
......
name: Dist name: Dist
on: on:
workflow_dispatch:
schedule: schedule:
# gemini said this is 6:00 china time # gemini said this is 6:00 china time
- cron: "0 22 * * *" - cron: "0 22 * * *"
...@@ -52,7 +53,7 @@ jobs: ...@@ -52,7 +53,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v6
with: with:
fetch-depth: 1 fetch-depth: 1
submodules: recursive submodules: recursive
...@@ -91,7 +92,7 @@ jobs: ...@@ -91,7 +92,7 @@ jobs:
- name: Upload SDist - name: Upload SDist
# Not PR to save artifact storage, as SDist is only needed for releases. # Not PR to save artifact storage, as SDist is only needed for releases.
if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]')
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v6
with: with:
name: sdist name: sdist
path: dist/*.tar.gz path: dist/*.tar.gz
...@@ -105,14 +106,12 @@ jobs: ...@@ -105,14 +106,12 @@ jobs:
strategy: strategy:
matrix: matrix:
target: target:
- { runner: ubuntu-latest, toolkit: "CUDA-12.1" } - { runner: ubuntu-latest, toolkit: "CUDA-12.8" }
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: macos-latest, toolkit: "Metal" } - { runner: macos-latest, toolkit: "Metal" }
python-version: python-version:
# Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8. # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8.
# Only build wheels against Python 3.8 Limited API to save CI resources. # Only build wheels against Python 3.8 Limited API to save CI resources.
# FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
# Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
- "3.9" - "3.9"
fail-fast: false fail-fast: false
timeout-minutes: 120 timeout-minutes: 120
...@@ -122,7 +121,7 @@ jobs: ...@@ -122,7 +121,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v6
with: with:
fetch-depth: 1 fetch-depth: 1
submodules: recursive submodules: recursive
...@@ -160,7 +159,7 @@ jobs: ...@@ -160,7 +159,7 @@ jobs:
fi fi
- name: Build wheels - name: Build wheels
uses: pypa/cibuildwheel@v3.2 uses: pypa/cibuildwheel@v3.3
with: with:
package-dir: . package-dir: .
output-dir: wheelhouse output-dir: wheelhouse
...@@ -169,7 +168,7 @@ jobs: ...@@ -169,7 +168,7 @@ jobs:
- name: Upload wheels - name: Upload wheels
# Not PR to save artifact storage, as wheels are only needed for releases. # Not PR to save artifact storage, as wheels are only needed for releases.
if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]') if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]')
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v6
with: with:
name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }} name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}
path: wheelhouse/*.whl path: wheelhouse/*.whl
...@@ -184,7 +183,7 @@ jobs: ...@@ -184,7 +183,7 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Download built SDist - name: Download built SDist
uses: actions/download-artifact@v6 uses: actions/download-artifact@v7
with: with:
# unpacks default artifact into dist/ # unpacks default artifact into dist/
# if `name: artifact` is omitted, the action will create extra parent dir # if `name: artifact` is omitted, the action will create extra parent dir
...@@ -192,7 +191,7 @@ jobs: ...@@ -192,7 +191,7 @@ jobs:
path: dist path: dist
- name: Download built wheels - name: Download built wheels
uses: actions/download-artifact@v6 uses: actions/download-artifact@v7
with: with:
pattern: wheels-* pattern: wheels-*
path: dist path: dist
...@@ -202,7 +201,7 @@ jobs: ...@@ -202,7 +201,7 @@ jobs:
run: ls -lh dist/* run: ls -lh dist/*
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v6
with: with:
name: artifacts name: artifacts
path: dist/* path: dist/*
......
...@@ -33,7 +33,7 @@ jobs: ...@@ -33,7 +33,7 @@ jobs:
runs-on: [self-hosted, nvidia] runs-on: [self-hosted, nvidia]
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v6
with: with:
ref: refs/pull/${{ github.event.issue.number }}/merge ref: refs/pull/${{ github.event.issue.number }}/merge
fetch-depth: 0 fetch-depth: 0
......
...@@ -25,7 +25,7 @@ jobs: ...@@ -25,7 +25,7 @@ jobs:
runs-on: [self-hosted, nvidia] runs-on: [self-hosted, nvidia]
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
submodules: recursive submodules: recursive
......
...@@ -108,3 +108,15 @@ cmake-build-*/ ...@@ -108,3 +108,15 @@ cmake-build-*/
# pre-commit cache # pre-commit cache
.pre-commit-cache/* .pre-commit-cache/*
# host checks logs
maint/host_checks/logs/*
# ncu
*.ncu-rep
# csv
*.csv
# clang-tidy
/run-clang-tidy.py
...@@ -32,30 +32,17 @@ repos: ...@@ -32,30 +32,17 @@ repos:
args: [--ignore-case] args: [--ignore-case]
files: ^docs/spelling_wordlist\.txt$ files: ^docs/spelling_wordlist\.txt$
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v21.1.2 # sync with requirements-lint.txt rev: v21.1.7 # sync with requirements-lint.txt
hooks: hooks:
- id: clang-format - id: clang-format
exclude: | types_or: [c++, c]
(?ix)(
^.+\.(cu|cuh)$|
^.+\.json$
)
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3 # sync with requirements-lint.txt rev: v0.14.9 # sync with requirements-lint.txt
hooks: hooks:
- id: ruff-check - id: ruff-check
args: [--fix, --exit-non-zero-on-fix] args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/google/yapf - id: ruff-format
rev: v0.43.0 # sync with requirements-lint.txt args: [--exit-non-zero-on-format]
hooks:
- id: yapf
name: yapf-multiproc-bugfix
# yapf is not multiprocess safe, so we run a dummy yapf first.
args: [--in-place, docs/conf.py]
always_run: true
pass_filenames: false
- id: yapf
args: [--recursive, --in-place]
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.4.1 # sync with requirements-lint.txt rev: v2.4.1 # sync with requirements-lint.txt
hooks: hooks:
......
Subproject commit 1c45ca35dd5c215e0c1db1f40f01556f467f52a8 Subproject commit b38bb492a1a55b5abb0c345962143c0f9c482cfb
Subproject commit 093b2cdb2187140b197336496d65d61ace89e8ff Subproject commit 79ed747db67e60d3a1889d8afd33473bc2424ade
...@@ -136,14 +136,21 @@ file(GLOB TILE_LANG_SRCS ...@@ -136,14 +136,21 @@ file(GLOB TILE_LANG_SRCS
src/*.cc src/*.cc
src/layout/*.cc src/layout/*.cc
src/transform/*.cc src/transform/*.cc
src/transform/common/*.cc
src/op/*.cc src/op/*.cc
src/target/utils.cc src/target/utils.cc
src/target/codegen_c_host.cc
src/target/codegen_cpp.cc src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc src/target/rt_mod_cpp.cc
# intrin_rule doesn't have system dependency # intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc src/target/intrin_rule*.cc
) )
# Always include CPU-safe runtime helpers
list(APPEND TILE_LANG_SRCS
src/runtime/error_helpers.cc
)
# Track if the user explicitly selected a backend via cache options. # Track if the user explicitly selected a backend via cache options.
set(TILELANG_BACKEND_USER_SELECTED OFF) set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS) foreach(BACKEND IN LISTS TILELANG_BACKENDS)
...@@ -205,16 +212,28 @@ elseif(USE_CUDA) ...@@ -205,16 +212,28 @@ elseif(USE_CUDA)
cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)
file(GLOB TILE_LANG_CUDA_SRCS file(GLOB TILE_LANG_CUDA_SRCS
src/runtime/*.cc src/runtime/runtime.cc
src/target/ptx.cc src/target/ptx.cc
src/target/codegen_cuda.cc src/target/codegen_cuda.cc
src/target/codegen_py.cc
src/target/codegen_utils.cc
src/target/codegen_cutedsl.cc
src/target/rt_mod_cuda.cc src/target/rt_mod_cuda.cc
src/target/rt_mod_cutedsl.cc
) )
list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS}) list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS})
list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS}) list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS})
endif() endif()
set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations")
set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package")
if(USE_Z3 AND USE_PYPI_Z3)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3")
find_package(Z3 REQUIRED)
endif()
# Include tvm after configs have been populated # Include tvm after configs have been populated
add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
...@@ -222,7 +241,11 @@ add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) ...@@ -222,7 +241,11 @@ add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>) add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS}) add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS})
# Set debug mode compile definitions
# We open the deubg option of TVM, i.e. TVM_LOG_DEBUG
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
message(STATUS "Building TileLang with DEBUG mode")
target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG")
endif() endif()
...@@ -232,6 +255,18 @@ add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>) ...@@ -232,6 +255,18 @@ add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>) add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang PUBLIC tvm_runtime tvm) target_link_libraries(tilelang PUBLIC tvm_runtime tvm)
target_link_libraries(tilelang_module PUBLIC tvm) target_link_libraries(tilelang_module PUBLIC tvm)
# Place dev build outputs under build/lib for consistency
set_target_properties(tilelang PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
)
set_target_properties(tilelang_module PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
)
# Build cython extension # Build cython extension
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT})
...@@ -251,25 +286,46 @@ if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "") ...@@ -251,25 +286,46 @@ if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "")
endif() endif()
python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI) python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI)
# Install extension into the tilelang package directory
# Ensure dev builds drop the extension into build/lib alongside other shared libs
set_target_properties(tilelang_cython_wrapper PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
)
# Install the extension into tilelang/lib inside the wheel
install(TARGETS tilelang_cython_wrapper install(TARGETS tilelang_cython_wrapper
LIBRARY DESTINATION tilelang LIBRARY DESTINATION tilelang/lib
RUNTIME DESTINATION tilelang RUNTIME DESTINATION tilelang/lib
ARCHIVE DESTINATION tilelang) ARCHIVE DESTINATION tilelang/lib)
# add python z3 lib path to rpath for running tests and dev in current folder
if(USE_Z3 AND USE_PYPI_Z3)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/lib)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/bin)
endif()
# let libtilelang to search tvm/tvm_runtime in same dir
if(APPLE) if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") if(USE_Z3 AND USE_PYPI_Z3)
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") # some z3 is placed in lib/ and some in bin/, we add both in rpath
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin")
endif()
elseif(UNIX) elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") if(USE_Z3 AND USE_PYPI_Z3)
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") # cmake uses ; by default, we explicitly use : for linux
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib")
endif()
endif() endif()
# let libtilelang to search tvm/tvm_runtime in same dir
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
install( install(
TARGETS tvm tvm_runtime tilelang_module tilelang TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib LIBRARY DESTINATION tilelang/lib
......
...@@ -81,6 +81,8 @@ in the main directory. This installation is removable by: ...@@ -81,6 +81,8 @@ in the main directory. This installation is removable by:
python3 -m pip uninstall tilelang python3 -m pip uninstall tilelang
``` ```
We also recommend installing TileLang in a more manual way for better control over the build process, by compiling the C++ extensions first and set the `PYTHONPATH`. See [Working from Source via `PYTHONPATH`](https://tilelang.com/get_started/Installation.html#working-from-source-via-pythonpath) for detailed instructions.
## Lint Check ## Lint Check
To check the linting, run: To check the linting, run:
......
...@@ -13,6 +13,9 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to ...@@ -13,6 +13,9 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png /> <img src=./images/MatmulExample.png />
## Latest News ## Latest News
- 12/18/2025 🚀: Added [CuTeDSL backend](https://github.com/tile-ai/tilelang/pull/1421) support, enabling compilation to NVIDIA CUTLASS CuTe DSL! Join us in building and optimizing this exciting new backend: [Issue #1454](https://github.com/tile-ai/tilelang/issues/1454).
- 12/17/2025 🔬: Integrated [Z3 theorem prover](https://github.com/tile-ai/tilelang/pull/1367) into TVM Arith Analyzer, bringing SMT-based symbolic reasoning for enhanced optimizations and automatic correctness verification!
- 10/31/2025 🔧: Migrated to [apache-tvm-ffi](https://github.com/tile-ai/tilelang/pull/1108), significantly reducing CPU overhead!
- 10/30/2025 📦: We have released v0.1.6.post2, which is the last version compatible with Python 3.8. - 10/30/2025 📦: We have released v0.1.6.post2, which is the last version compatible with Python 3.8.
- 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details. - 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
- 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported! - 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported!
...@@ -137,7 +140,7 @@ import tilelang.language as T ...@@ -137,7 +140,7 @@ import tilelang.language as T
# target currently can be "cuda" or "hip" or "cpu". # target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time # if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
@T.prim_func @T.prim_func
def matmul_relu_kernel( def matmul_relu_kernel(
...@@ -209,7 +212,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) ...@@ -209,7 +212,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.") print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional) # 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source() # cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source) # print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel # 5.Profile latency with kernel
......
...@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): ...@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK # N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True) dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block: if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True dense_mask[:, :, -2:, :] = True
...@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F ...@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def benchmark_topk_sparse_attention(): def benchmark_topk_sparse_attention():
from benchmark_configs import configs from benchmark_configs import configs
torch.manual_seed(0) torch.manual_seed(0)
# Config # Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs # Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
import flash_attn import flash_attn
......
...@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): ...@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK # N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True) dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block: if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True dense_mask[:, :, -2:, :] = True
...@@ -39,16 +36,15 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -39,16 +36,15 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N = 64 block_N = 64
num_stages = 2 num_stages = 2
threads = 128 threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len] block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
block_mask_dtype = "bool" block_mask_dtype = T.bool
def kernel_func(block_M, block_N, num_stages, threads): def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
...@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -79,22 +74,24 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -79,22 +74,24 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf # To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done # This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps. # in the first ceil_div(kBlockM, kBlockN) steps.
...@@ -114,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -114,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(shape, dtype), Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype), BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype), Output: T.Tensor(shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -144,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -144,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype) block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -153,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -153,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask[vj] = BlockSparseMask[bz, by, bx, vj] block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = ( loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) )
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k]: if block_mask[k]:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
scores_sum, logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
...@@ -175,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -175,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
def benchmark_topk_sparse_attention(): def benchmark_topk_sparse_attention():
from benchmark_configs import configs from benchmark_configs import configs
torch.manual_seed(0) torch.manual_seed(0)
# Config # Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs # Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
# Create sparse mask (downsampled to block level) # Create sparse mask (downsampled to block level)
downsample_factor = BLOCK downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor) downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100 x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
program = blocksparse_flashattn( program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=4) kernel = tilelang.compile(program, out_idx=4)
def benchmark_fn(): def benchmark_fn():
......
...@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): ...@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK # N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True) dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block: if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True dense_mask[:, :, -2:, :] = True
...@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F ...@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def benchmark_topk_sparse_attention(): def benchmark_topk_sparse_attention():
from benchmark_configs import configs from benchmark_configs import configs
torch.manual_seed(0) torch.manual_seed(0)
# Config # Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs # Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5) sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level) # Create sparse mask (downsampled to block level)
downsample_factor = BLOCK downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor) downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100 x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
def benchmark_fn(): def benchmark_fn():
# Compute reference # Compute reference
# Expand block mask to full attention matrix # Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation # PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf')) attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1) attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
return ref_output return ref_output
ref_latency = do_bench( ref_latency = do_bench(
......
...@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): ...@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK # N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True) dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block: if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True dense_mask[:, :, -2:, :] = True
...@@ -56,7 +53,6 @@ def _fwd_kernel_inner( ...@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
): ):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
if mask_val == True: if mask_val == True:
...@@ -72,8 +68,7 @@ def _fwd_kernel_inner( ...@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK: if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
float('-inf'))
m_ij = tl.maximum(m_i, tl.max(qk, 1)) m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None] qk -= m_ij[:, None]
...@@ -153,7 +148,7 @@ def _fwd_kernel( ...@@ -153,7 +148,7 @@ def _fwd_kernel(
v_ptrs = V + off_v v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
...@@ -191,24 +186,12 @@ def _fwd_kernel( ...@@ -191,24 +186,12 @@ def _fwd_kernel(
acc = acc * l_recip acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty) acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
None, :] * stride_od
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx, def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1] assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2] assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous() o = out if out is not None else torch.empty_like(q).contiguous()
...@@ -253,7 +236,6 @@ def _forward(ctx, ...@@ -253,7 +236,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function): class _sparse_attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale): def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints # shape constraints
...@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply ...@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
def benchmark_topk_sparse_attention(): def benchmark_topk_sparse_attention():
from benchmark_configs import configs from benchmark_configs import configs
torch.manual_seed(0) torch.manual_seed(0)
# Config # Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs # Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5) sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level) # Create sparse mask (downsampled to block level)
downsample_factor = BLOCK downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor) downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100 x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......
...@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): ...@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum) decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril( causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0) scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), out = torch.einsum(
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( out_prev = (
C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
)
out = out + out_prev out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p") out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None: if D is not None:
...@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D): ...@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
@helion.kernel() @helion.kernel()
def helion_mamba2_chunk_scan_kernel( def helion_mamba2_chunk_scan_kernel(
cb: torch.Tensor, cb: torch.Tensor,
...@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): ...@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
dtype = cb.dtype dtype = cb.dtype
accum_dtype = torch.float32 accum_dtype = torch.float32
assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype
dtype)
out = torch.empty_like(x) out = torch.empty_like(x)
...@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): ...@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
[nheads, chunk_size, headdim, batch, nchunks], [nheads, chunk_size, headdim, batch, nchunks],
block_size=[1, block_m, block_n, 1, 1], block_size=[1, block_m, block_n, 1, 1],
): ):
acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
tile_m].to(torch.float32)
scale_m_local = torch.exp2(dA_cumsum_local_m * p) scale_m_local = torch.exp2(dA_cumsum_local_m * p)
C_local = C[ C_local = C[
...@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): ...@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
tile_m, tile_m,
tile_k, tile_k,
] ]
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
tile_k].to(torch.float32) cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
dA_cumsum_local_k[None, :] * p)
dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32) dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
cb_local = (cb_local * dt_local[None, :]).to(dtype) cb_local = (cb_local * dt_local[None, :]).to(dtype)
pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
...@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): ...@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
acc_o = hl.dot(cb_local, x_local, acc=acc_o) acc_o = hl.dot(cb_local, x_local, acc=acc_o)
D_local = D[tile_h.begin].to(torch.float32) D_local = D[tile_h.begin].to(torch.float32)
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
tile_n].to(torch.float32)
acc_o += x_residual * D_local acc_o += x_residual * D_local
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)
tile_n] = acc_o.to(dtype=dtype)
return out return out
...@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D): ...@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
def get_configs(): def get_configs():
iter_params = dict( iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
block_M=[64, 128, 256],
block_N=[32, 64],
block_K=[64, 128, 256],
block_Dstate=[128],
num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
...@@ -198,40 +187,42 @@ def get_configs(): ...@@ -198,40 +187,42 @@ def get_configs():
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}, },
) )
def chunk_scan_fwd(batch, def chunk_scan_fwd(
seqlen, batch,
chunk_size, seqlen,
ngroups, chunk_size,
nheads, ngroups,
headdim, nheads,
dstate, headdim,
block_M=64, dstate,
block_N=64, block_M=64,
block_K=64, block_N=64,
block_Dstate=128, block_K=64,
num_stages=2, block_Dstate=128,
threads=128): num_stages=2,
dtype = "float16" threads=128,
accum_dtype = "float" ):
dtype = T.float16
accum_dtype = T.float32
nchunks = T.ceildiv(seqlen, chunk_size) nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504 p = 1.44269504
@T.prim_func @T.prim_func
def main( def main(
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
): ):
with T.Kernel( with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
nheads, bz,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), bx,
batch * nchunks, by,
threads=threads) as (bz, bx, by): ):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype) acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
...@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch, ...@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
m_idx = bx // T.ceildiv(headdim, block_N) m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N) n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({ T.annotate_layout(
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), {
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
}) x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
}
)
T.no_set_max_nreg() T.no_set_max_nreg()
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local) T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o) T.clear(acc_o)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy( T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + C[
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) batch_idx,
T.copy( chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, bz // (nheads // ngroups),
0:block_Dstate], prev_state_shared) 0:block_Dstate,
],
C_shared,
)
T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i] acc_o[i, j] *= scale_m_local[i]
...@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch, ...@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups), cb[
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], batch_idx,
cb_shared) chunk_idx,
bz // (nheads // ngroups),
m_idx * block_M : (m_idx + 1) * block_M,
k * block_K : (k + 1) * block_K,
],
cb_shared,
)
T.copy(cb_shared, cb_local) T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local) T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
cb_local[i, cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
j] = cb_local[i, T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local) T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j] cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
cb_local[i, j], 0)
T.copy( T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + x[
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) batch_idx,
chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_shared,
)
T.gemm(cb_local, x_shared, acc_o) T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz] D_local[0] = D[bz]
T.copy( T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + x[
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], batch_idx,
x_residual_shared) chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_residual_shared,
)
T.copy(x_residual_shared, x_residual_local) T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0] acc_o[i, j] += x_residual_local[i, j] * D_local[0]
...@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch, ...@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
T.copy(acc_o, acc_o_shared) T.copy(acc_o, acc_o_shared)
T.copy( T.copy(
acc_o_shared, acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + Output[
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
)
return main return main
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=80, help='heads') parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument('--groups', type=int, default=1, help='groups') parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument('--dim', type=int, default=64, help='dim') parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument('--dstate', type=int, default=128, help='dstate') parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate batch, heads, groups, seq_len, chunk_size, dim, dstate = (
args.batch,
args.heads,
args.groups,
args.seq_len,
args.chunk_size,
args.dim,
args.dstate,
)
nchunks = math.ceil(seq_len / chunk_size) nchunks = math.ceil(seq_len / chunk_size)
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
...@@ -360,8 +382,7 @@ if __name__ == "__main__": ...@@ -360,8 +382,7 @@ if __name__ == "__main__":
D = torch.randn(heads).half().cuda() D = torch.randn(heads).half().cuda()
print("Benchmarking Triton...") print("Benchmarking Triton...")
triton_latency = do_bench( triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}") print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
print("Benchmarking Helion...") print("Benchmarking Helion...")
......
...@@ -6,6 +6,7 @@ import tilelang ...@@ -6,6 +6,7 @@ import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.autotuner import autotune from tilelang.autotuner import autotune
from tilelang import jit from tilelang import jit
# Configure logger # Configure logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
...@@ -61,9 +62,9 @@ def get_configs(args, kwargs): ...@@ -61,9 +62,9 @@ def get_configs(args, kwargs):
M=M, M=M,
N=N, N=N,
K=K, K=K,
in_dtype="float16", in_dtype=T.float16,
out_dtype="float16", out_dtype=T.float16,
accum_dtype="float", accum_dtype=T.float32,
).with_arch(arch) ).with_arch(arch)
func = carve_template.equivalent_function() func = carve_template.equivalent_function()
...@@ -101,9 +102,7 @@ def get_configs(args, kwargs): ...@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
policy=[T.GemmWarpPolicy.Square], policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False], enable_rasteration=[True, False],
) )
return [{ return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs return configs
...@@ -112,7 +111,9 @@ def get_configs(args, kwargs): ...@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
warmup=3, warmup=3,
rep=20, rep=20,
) )
@jit(out_idx=[2],) @jit(
out_idx=[2],
)
def matmul( def matmul(
M, M,
N, N,
...@@ -154,14 +155,14 @@ def matmul( ...@@ -154,14 +155,14 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth, # Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy # accumulate in float for better numerical accuracy
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
""" """
The compiled TVM function for block-level matrix multiplication. The compiled TVM function for block-level matrix multiplication.
...@@ -176,7 +177,6 @@ def matmul( ...@@ -176,7 +177,6 @@ def matmul(
# Bind x-dimension to block index in N, # Bind x-dimension to block index in N,
# y-dimension to block index in M. # y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K) # Allocate shared memory for B sub-block of shape (block_N, block_K)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment