Unverified commit 667632cc authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
......@@ -4,7 +4,9 @@ ExtraArgs: []
FormatStyle: file
UseColor: true
WarningsAsErrors: '*'
# FIXME: Use `ExcludeHeaderFilterRegex` instead once all maintainers have upgraded their `clang-tidy`
HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*'
# ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$'
# NOTE: there must be no spaces before the '-', so put the comma last.
Checks: >-
......
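The negative-lookahead `HeaderFilterRegex` above is meant to keep diagnostics for in-tree headers while excluding anything under `3rdparty/` or `tvm/`. A quick sanity check of the intended semantics, using Python's `re` as a stand-in (clang-tidy's own regex engine may behave differently; the paths are illustrative):

```python
import re

# Same pattern as HeaderFilterRegex in the .clang-tidy hunk above.
header_filter = re.compile(r'^(?!.*(?:/|^)(3rdparty|tvm)/).*')

assert header_filter.match("src/op/gemm.cc")          # in-tree sources pass
assert not header_filter.match("3rdparty/foo/bar.h")  # vendored code is excluded
assert not header_filter.match("a/tvm/runtime.h")     # tvm/ is excluded anywhere in the path
```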
blank_issues_enabled: false
blank_issues_enabled: true
......@@ -40,7 +40,7 @@ jobs:
timeout-minutes: 30
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 0
submodules: recursive
......@@ -93,7 +93,7 @@ jobs:
name: self-hosted-amd
# Format: [Nightly-]ROCm-<major>.<minor>[.<patch>]. E.g., "ROCm-6.4" or "Nightly-ROCm-7.0".
# Use "Nightly-" prefix to use torch nightly builds.
toolkit: ROCm-6.3
toolkit: Nightly-ROCm-7.1
- tags: [macos-latest]
name: macos-latest
toolkit: Metal # or Nightly-Metal
......@@ -104,7 +104,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 0
submodules: recursive
......@@ -288,35 +288,59 @@ jobs:
echo "Clearing uv cache at ${UV_CACHE_DIR} due to failure."
uv cache clean
- name: Enable core dump generation (Linux / GitHub-hosted runners)
if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kernel.core_uses_pid=0
sudo sysctl -w fs.suid_dumpable=1
sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable
- name: Enable core dump generation (macOS / GitHub-hosted runners)
if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kern.coredump=1
sudo sysctl -w kern.sugid_coredump=1
sysctl kern.corefile kern.coredump kern.sugid_coredump
- name: Install project (wheel form)
run: |
uv pip install -v .
- name: Run clang-tidy
id: clang-tidy
if: runner.os == 'Linux'
run: |
echo "\$ $(command -v clang-tidy) --version" && clang-tidy --version
if [[ -x "$(command -v run-clang-tidy)" ]]; then
echo "Using run-clang-tidy from $(command -v run-clang-tidy)"
CLANG_TIDY=(run-clang-tidy)
else
RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py
echo "Downloading run-clang-tidy script from ${RCT_URL}"
echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script -
CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py)
fi
# Download run-clang-tidy script
RCT_URL=https://raw.githubusercontent.com/llvm/llvm-project/refs/heads/release/21.x/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py
echo "Downloading run-clang-tidy script from ${RCT_URL}"
echo "import urllib.request; url = '${RCT_URL}'.rstrip('/'); urllib.request.urlretrieve(url, url.split('/')[-1])" | uv run --no-project --script -
RUN_CLANG_TIDY=(uv run --no-project --script -- run-clang-tidy.py)
if [[ -x "$(command -v clang-apply-replacements)" ]]; then
echo "Using clang-apply-replacements from $(command -v clang-apply-replacements)"
CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)")
RUN_CLANG_TIDY+=(-fix -clang-apply-replacements-binary="$(command -v clang-apply-replacements)")
else
echo "::warning::clang-apply-replacements not found in PATH, automatic fixing disabled."
fi
# Run cmake to create the build directory with compile_commands.json
cmake -S . -B cmake-build --fresh ${CLANG_TIDY_CMAKE_OPTIONS} # no quotes here
echo "::group::compile_commands.json"
ls -alh cmake-build/compile_commands.json
uv run --no-project -m -- json.tool --no-ensure-ascii cmake-build/compile_commands.json
echo "::endgroup::"
CXX_FILES=$(find src -type f -iname "*.[ch]pp" -o -iname "*.cc" -o -iname "*.c" -o -iname "*.h")
rc=0
"${CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \
echo "::group::run-clang-tidy"
"${RUN_CLANG_TIDY[@]}" -clang-tidy-binary="$(command -v clang-tidy)" \
-exclude-header-filter='^(3rdparty|tvm)/.*$' \
-p="cmake-build" ${CXX_FILES} || rc="$?"
echo "::endgroup::"
rm -rf cmake-build run-clang-tidy.py
if (( rc != 0 )); then
echo "::error::clang-tidy found issues (exit code: ${rc}). Please run 'clang-tidy --fix' locally to fix them."
......@@ -324,26 +348,6 @@ jobs:
exit "${rc}"
fi
- name: Enable core dump generation (Linux / GitHub-hosted runners)
if: ${{ runner.os == 'Linux' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kernel.core_pattern="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kernel.core_uses_pid=0
sudo sysctl -w fs.suid_dumpable=1
sysctl kernel.core_pattern kernel.core_uses_pid fs.suid_dumpable
- name: Enable core dump generation (macOS / GitHub-hosted runners)
if: ${{ runner.os == 'macOS' && !startsWith(matrix.runner.name, 'self-hosted') }}
run: |
sudo sysctl -w kern.corefile="core.${{ matrix.python-version }}.${{ matrix.runner.toolkit }}.%P"
sudo sysctl -w kern.coredump=1
sudo sysctl -w kern.sugid_coredump=1
sysctl kern.corefile kern.coredump kern.sugid_coredump
- name: Install project (wheel form)
run: |
uv pip install -v .
- name: Run examples with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
if: contains(matrix.runner.toolkit, 'CUDA')
run: |
......@@ -366,8 +370,27 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
--ignore=./python/jit/test_tilelang_jit_cutedsl.py \
./python
# CuTeDSL JIT tests require GEMM v1 (must be set before importing tilelang).
# Run them in a dedicated step to avoid changing the default GEMM selection
# (and to keep the rest of the CUDA tests on GEMM v2).
- name: Run CuTeDSL JIT tests (GEMM v1) with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: cutedsl-tests
if: contains(matrix.runner.toolkit, 'CUDA')
env:
TILELANG_USE_GEMM_V1: "1"
run: |
cd testing
PYTEST=(
uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
# Avoid xdist contention on a single GPU by running this file in one worker.
"${PYTEST[@]}" --maxfail=3 --numprocesses=1 \
./python/jit/test_tilelang_jit_cutedsl.py
# AMD ROCm tests
- name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: rocm-tests
......
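The dedicated CuTeDSL step above hinges on `TILELANG_USE_GEMM_V1` being visible before `tilelang` is imported. A minimal local repro sketch, assuming (as the workflow comment states) that the variable is read at import time:

```python
import os

# Must be set before tilelang is imported; setting it afterwards has no effect.
os.environ["TILELANG_USE_GEMM_V1"] = "1"

import tilelang  # noqa: E402  -- GEMM v1 is now selected for this process
```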
name: Dist
on:
workflow_dispatch:
schedule:
# 22:00 UTC is 06:00 the next day in China Standard Time (UTC+8)
- cron: "0 22 * * *"
......@@ -52,7 +53,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 1
submodules: recursive
......@@ -91,7 +92,7 @@ jobs:
- name: Upload SDist
# Skip on PRs to save artifact storage, as the SDist is only needed for releases.
if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]')
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: sdist
path: dist/*.tar.gz
......@@ -105,14 +106,12 @@ jobs:
strategy:
matrix:
target:
- { runner: ubuntu-latest, toolkit: "CUDA-12.1" }
- { runner: ubuntu-latest, toolkit: "CUDA-12.8" }
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: macos-latest, toolkit: "Metal" }
python-version:
# Wheels are built with the Python 3.8 Limited API, so they should work with all Python >= 3.8.
# Only build wheels against Python 3.8 Limited API to save CI resources.
# FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
# Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
- "3.9"
fail-fast: false
timeout-minutes: 120
......@@ -122,7 +121,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 1
submodules: recursive
......@@ -160,7 +159,7 @@ jobs:
fi
- name: Build wheels
uses: pypa/cibuildwheel@v3.2
uses: pypa/cibuildwheel@v3.3
with:
package-dir: .
output-dir: wheelhouse
......@@ -169,7 +168,7 @@ jobs:
- name: Upload wheels
# Skip on PRs to save artifact storage, as wheels are only needed for releases.
if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]')
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: wheels-${{ matrix.python-version }}-${{ runner.os }}-${{ runner.arch }}-${{ matrix.target.toolkit }}
path: wheelhouse/*.whl
......@@ -184,7 +183,7 @@ jobs:
timeout-minutes: 15
steps:
- name: Download built SDist
uses: actions/download-artifact@v6
uses: actions/download-artifact@v7
with:
# unpacks default artifact into dist/
# if `name: artifact` is omitted, the action will create an extra parent dir
......@@ -192,7 +191,7 @@ jobs:
path: dist
- name: Download built wheels
uses: actions/download-artifact@v6
uses: actions/download-artifact@v7
with:
pattern: wheels-*
path: dist
......@@ -202,7 +201,7 @@ jobs:
run: ls -lh dist/*
- name: Upload artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: artifacts
path: dist/*
......
......@@ -33,7 +33,7 @@ jobs:
runs-on: [self-hosted, nvidia]
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
ref: refs/pull/${{ github.event.issue.number }}/merge
fetch-depth: 0
......
......@@ -25,7 +25,7 @@ jobs:
runs-on: [self-hosted, nvidia]
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 0
submodules: recursive
......
......@@ -108,3 +108,15 @@ cmake-build-*/
# pre-commit cache
.pre-commit-cache/*
# host checks logs
maint/host_checks/logs/*
# ncu
*.ncu-rep
# csv
*.csv
# clang-tidy
/run-clang-tidy.py
......@@ -32,30 +32,17 @@ repos:
args: [--ignore-case]
files: ^docs/spelling_wordlist\.txt$
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v21.1.2 # sync with requirements-lint.txt
rev: v21.1.7 # sync with requirements-lint.txt
hooks:
- id: clang-format
exclude: |
(?ix)(
^.+\.(cu|cuh)$|
^.+\.json$
)
types_or: [c++, c]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3 # sync with requirements-lint.txt
rev: v0.14.9 # sync with requirements-lint.txt
hooks:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/google/yapf
rev: v0.43.0 # sync with requirements-lint.txt
hooks:
- id: yapf
name: yapf-multiproc-bugfix
# yapf is not multiprocess safe, so we run a dummy yapf first.
args: [--in-place, docs/conf.py]
always_run: true
pass_filenames: false
- id: yapf
args: [--recursive, --in-place]
- id: ruff-format
args: [--exit-non-zero-on-format]
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1 # sync with requirements-lint.txt
hooks:
......
Subproject commit 1c45ca35dd5c215e0c1db1f40f01556f467f52a8
Subproject commit b38bb492a1a55b5abb0c345962143c0f9c482cfb
Subproject commit 093b2cdb2187140b197336496d65d61ace89e8ff
Subproject commit 79ed747db67e60d3a1889d8afd33473bc2424ade
......@@ -136,14 +136,21 @@ file(GLOB TILE_LANG_SRCS
src/*.cc
src/layout/*.cc
src/transform/*.cc
src/transform/common/*.cc
src/op/*.cc
src/target/utils.cc
src/target/codegen_c_host.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# intrin_rule has no system dependencies
src/target/intrin_rule*.cc
)
# Always include CPU-safe runtime helpers
list(APPEND TILE_LANG_SRCS
src/runtime/error_helpers.cc
)
# Track if the user explicitly selected a backend via cache options.
set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
......@@ -205,16 +212,28 @@ elseif(USE_CUDA)
cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)
file(GLOB TILE_LANG_CUDA_SRCS
src/runtime/*.cc
src/runtime/runtime.cc
src/target/ptx.cc
src/target/codegen_cuda.cc
src/target/codegen_py.cc
src/target/codegen_utils.cc
src/target/codegen_cutedsl.cc
src/target/rt_mod_cuda.cc
src/target/rt_mod_cutedsl.cc
)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS})
list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS})
endif()
set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations")
set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package")
if(USE_Z3 AND USE_PYPI_Z3)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3")
find_package(Z3 REQUIRED)
endif()
# Include tvm after configs have been populated
add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
......@@ -222,7 +241,11 @@ add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS})
# Set debug mode compile definitions
# Enable TVM's debug option, i.e. TVM_LOG_DEBUG
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
message(STATUS "Building TileLang with DEBUG mode")
target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG")
endif()
......@@ -232,6 +255,18 @@ add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang PUBLIC tvm_runtime tvm)
target_link_libraries(tilelang_module PUBLIC tvm)
# Place dev build outputs under build/lib for consistency
set_target_properties(tilelang PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
)
set_target_properties(tilelang_module PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
)
# Build cython extension
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT})
......@@ -251,25 +286,46 @@ if(NOT "${SKBUILD_SABI_VERSION}" STREQUAL "")
endif()
python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI)
# Install extension into the tilelang package directory
# Ensure dev builds drop the extension into build/lib alongside other shared libs
set_target_properties(tilelang_cython_wrapper PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
)
# Install the extension into tilelang/lib inside the wheel
install(TARGETS tilelang_cython_wrapper
LIBRARY DESTINATION tilelang
RUNTIME DESTINATION tilelang
ARCHIVE DESTINATION tilelang)
LIBRARY DESTINATION tilelang/lib
RUNTIME DESTINATION tilelang/lib
ARCHIVE DESTINATION tilelang/lib)
# Add the Python z3 lib paths to the rpath so tests and dev builds work from the source tree
if(USE_Z3 AND USE_PYPI_Z3)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/lib)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/bin)
endif()
# Let libtilelang search for tvm/tvm_runtime in the same dir
if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
if(USE_Z3 AND USE_PYPI_Z3)
# Some z3 installs place files in lib/ and some in bin/; add both to the rpath
list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin")
endif()
elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
if(USE_Z3 AND USE_PYPI_Z3)
# CMake joins rpath entries with ';' by default; explicitly use ':' for Linux
string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib")
endif()
endif()
# Let libtilelang search for tvm/tvm_runtime in the same dir
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
install(
TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib
......
......@@ -81,6 +81,8 @@ in the main directory. This installation is removable by:
python3 -m pip uninstall tilelang
```
We also recommend installing TileLang in a more manual way for better control over the build process: compile the C++ extensions first and set the `PYTHONPATH`. See [Working from Source via `PYTHONPATH`](https://tilelang.com/get_started/Installation.html#working-from-source-via-pythonpath) for detailed instructions.
## Lint Check
To check the linting, run:
......
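A minimal sketch of what the `PYTHONPATH`-based workflow buys you, assuming a checkout at `~/tilelang` (the path is illustrative): the in-tree package should shadow any installed copy.

```python
import os
import sys

# Equivalent to `export PYTHONPATH=$HOME/tilelang` before launching Python.
repo_root = os.path.expanduser("~/tilelang")  # assumption: your checkout location
sys.path.insert(0, repo_root)

import tilelang  # noqa: E402
print(tilelang.__file__)  # should point into the checkout, not site-packages
```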
......@@ -13,6 +13,9 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />
## Latest News
- 12/18/2025 🚀: Added [CuTeDSL backend](https://github.com/tile-ai/tilelang/pull/1421) support, enabling compilation to NVIDIA CUTLASS CuTe DSL! Join us in building and optimizing this exciting new backend: [Issue #1454](https://github.com/tile-ai/tilelang/issues/1454).
- 12/17/2025 🔬: Integrated [Z3 theorem prover](https://github.com/tile-ai/tilelang/pull/1367) into TVM Arith Analyzer, bringing SMT-based symbolic reasoning for enhanced optimizations and automatic correctness verification!
- 10/31/2025 🔧: Migrated to [apache-tvm-ffi](https://github.com/tile-ai/tilelang/pull/1108), significantly reducing CPU overhead!
- 10/30/2025 📦: We have released v0.1.6.post2, which is the last version compatible with Python 3.8.
- 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
- 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported!
......@@ -137,7 +140,7 @@ import tilelang.language as T
# target can currently be "cuda", "hip", or "cpu".
# If not specified, it will be inferred from the input tensors at compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
@T.prim_func
def matmul_relu_kernel(
......@@ -209,7 +212,7 @@ torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
......
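This hunk is part of a broader move from string dtypes to `tilelang.language` dtype objects, a change that recurs throughout this commit. A minimal sketch of the before/after convention (assuming `tilelang.language` is importable as `T`, as in the README example above):

```python
import tilelang.language as T

# Old style: dtype strings.
# dtype, accum_dtype = "float16", "float"

# New style: dtype objects from tilelang.language.
dtype = T.float16
accum_dtype = T.float32
```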
......@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
import flash_attn
......
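For reference, the `get_sparse_attn_mask_from_topk` helper above builds a boolean mask by scattering `True` at the top-k indices of each row. A tiny CPU-only sketch of that mechanism (shapes shrunk for illustration):

```python
import torch

x = torch.tensor([[[[0.9, 0.1, 0.5, 0.3]]]])        # [bsz=1, heads=1, rows=1, cols=4]
sparse_index = torch.topk(x, k=2, dim=-1).indices   # the two largest scores per row
dense_mask = torch.full(x.shape, False, dtype=torch.bool)
dense_mask.scatter_(-1, sparse_index, True)
print(dense_mask)  # tensor([[[[ True, False,  True, False]]]])
```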
......@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -39,16 +36,15 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N = 64
num_stages = 2
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "bool"
dtype = T.float16
accum_dtype = T.float32
block_mask_dtype = T.bool
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
......@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -79,22 +74,24 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set scores_max to 0 if it is -inf.
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
......@@ -114,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -144,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -153,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k]:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
......@@ -175,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
program = blocksparse_flashattn(
BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=4)
def benchmark_fn():
......
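The `scale` above folds log2(e) into the softmax scale so the kernel can use base-2 exponentials (`exp2`) instead of `exp`. A standalone numeric check of that identity:

```python
import math

dim = 64
softmax_scale = (1.0 / dim) ** 0.5
scale = softmax_scale * 1.44269504  # log2(e), as in the kernel above

x = 0.8125
# exp(x * softmax_scale) == 2 ** (x * scale), since e**y == 2**(y * log2(e))
assert abs(math.exp(x * softmax_scale) - 2.0 ** (x * scale)) < 1e-9
```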
......@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
def benchmark_fn():
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
return ref_output
ref_latency = do_bench(
......
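The reference path above expands the block-level mask to a full attention mask with `torch.kron`, tiling each block decision across a BLOCK x BLOCK patch. A small sketch of that expansion:

```python
import torch

block_mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])          # 2x2 grid of blocks
full_mask = torch.kron(block_mask, torch.ones(2, 2)).bool()  # BLOCK=2 tiles
print(full_mask)
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [False, False,  True,  True],
#         [False, False,  True,  True]])
```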
......@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
if mask_val == True:
......@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
float('-inf'))
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
......@@ -153,7 +148,7 @@ def _fwd_kernel(
v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
......@@ -191,24 +186,12 @@ def _fwd_kernel(
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
None, :] * stride_od
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx,
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous()
......@@ -253,7 +236,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints
......@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......
......@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril(
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
out = torch.einsum(
"bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
out_prev = (
torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
)
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
......@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
@helion.kernel()
def helion_mamba2_chunk_scan_kernel(
cb: torch.Tensor,
......@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
dtype = cb.dtype
accum_dtype = torch.float32
assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
dtype)
assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype
out = torch.empty_like(x)
......@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
[nheads, chunk_size, headdim, batch, nchunks],
block_size=[1, block_m, block_n, 1, 1],
block_size=[1, block_m, block_n, 1, 1],
):
acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
tile_m].to(torch.float32)
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
scale_m_local = torch.exp2(dA_cumsum_local_m * p)
C_local = C[
......@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
tile_m,
tile_k,
]
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
tile_k].to(torch.float32)
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
dA_cumsum_local_k[None, :] * p)
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
cb_local = (cb_local * dt_local[None, :]).to(dtype)
pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
......@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
acc_o = hl.dot(cb_local, x_local, acc=acc_o)
D_local = D[tile_h.begin].to(torch.float32)
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
tile_n].to(torch.float32)
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
acc_o += x_residual * D_local
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
tile_n] = acc_o.to(dtype=dtype)
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)
return out
......@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
def get_configs():
iter_params = dict(
block_M=[64, 128, 256],
block_N=[32, 64],
block_K=[64, 128, 256],
block_Dstate=[128],
num_stages=[1, 2, 3, 4, 5])
iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
......@@ -198,40 +187,42 @@ def get_configs():
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def chunk_scan_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128):
dtype = "float16"
accum_dtype = "float"
def chunk_scan_fwd(
batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128,
):
dtype = T.float16
accum_dtype = T.float32
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
@T.prim_func
def main(
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
bz,
bx,
by,
):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
......@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
T.annotate_layout(
{
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
}
)
T.no_set_max_nreg()
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
C[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz // (nheads // ngroups),
0:block_Dstate,
],
C_shared,
)
T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
......@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
cb[
batch_idx,
chunk_idx,
bz // (nheads // ngroups),
m_idx * block_M : (m_idx + 1) * block_M,
k * block_K : (k + 1) * block_K,
],
cb_shared,
)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i,
j] = cb_local[i,
j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
x[
batch_idx,
chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_shared,
)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
x[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_residual_shared,
)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
......@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
Output[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
)
return main
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=80, help='heads')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--dstate', type=int, default=128, help='dstate')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
batch, heads, groups, seq_len, chunk_size, dim, dstate = (
args.batch,
args.heads,
args.groups,
args.seq_len,
args.chunk_size,
args.dim,
args.dstate,
)
nchunks = math.ceil(seq_len / chunk_size)
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
......@@ -360,8 +382,7 @@ if __name__ == "__main__":
D = torch.randn(heads).half().cuda()
print("Benchmarking Triton...")
triton_latency = do_bench(
lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
print("Benchmarking Helion...")
......
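As a sanity check on the FLOP count feeding the TFLOPs numbers above, here is the same formula evaluated with this script's default shapes (assuming the defaults shown in the argument parser):

```python
batch, heads, seq_len, chunk_size, dim, dstate = 8, 80, 4096, 256, 64, 128
total_flops = (
    2 * batch * seq_len * chunk_size * heads * dim * 0.5  # causal chunk-by-chunk GEMMs
    + 2 * batch * seq_len * heads * dim * dstate          # previous-state contribution GEMMs
)
print(f"{total_flops / 1e12:.3f} TFLOP per call")  # ~0.086 TFLOP
```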
......@@ -6,6 +6,7 @@ import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
......@@ -61,9 +62,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
return configs
......@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
@jit(
out_idx=[2],
)
def matmul(
M,
N,
......@@ -154,14 +155,14 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
......@@ -176,7 +177,6 @@ def matmul(
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......
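The `get_configs` helper above enumerates the Cartesian product of per-parameter candidate lists. A minimal standalone sketch of that expansion:

```python
import itertools

iter_params = dict(block_M=[64, 128], block_N=[32, 64], num_stages=[1, 2])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
print(len(configs))  # 2 * 2 * 2 = 8 candidate configs
print(configs[0])    # {'block_M': 64, 'block_N': 32, 'num_stages': 1}
```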