Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
...@@ -7,7 +7,7 @@ env: ...@@ -7,7 +7,7 @@ env:
jobs: jobs:
format-check: format-check:
runs-on: ubuntu-latest runs-on: self-hosted
permissions: permissions:
contents: write contents: write
...@@ -26,21 +26,37 @@ jobs: ...@@ -26,21 +26,37 @@ jobs:
with: with:
python-version: ${{ env.PYTHON_VERSION }} python-version: ${{ env.PYTHON_VERSION }}
- name: Install dependencies - name: Ensure venv (local & persistent)
run: | run: |
python -m pip install --upgrade pip set -e
pip install yapf==0.40.2 toml==0.10.2 tomli==2.0.1 ruff==0.6.5 codespell==2.3.0 clang-format==15.0.7 REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
echo "venv exists and hash matches – reuse it"
else
echo "venv stale or missing – recreating"
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
# shellcheck source=/dev/null
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install . --no-user
touch "$MARKER"
fi
- name: Run format check - name: Run format check
run: | run: |
git clone https://github.com/tile-ai/tilelang.git main_repo source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cp main_repo/format.sh .
rm -rf main_repo
if ! output=$(./format.sh 2>&1); then if ! output=$(./format.sh 2>&1); then
echo "------------------------------------"
echo "message:" echo "message:"
echo "$output" echo "$output"
echo "------------------------------------"
printf '%s\n' "$output" | grep "Please review and stage the changes." printf '%s\n' "$output" | grep "Please review and stage the changes."
echo "------------------------------------"
exit 1
fi fi
- name: Commit and Push Changes - name: Commit and Push Changes
......
Subproject commit 979c8e7f94473db7d71a41b26ccf51db7e17a734 Subproject commit a08b7c34d4a59f89f4dea252fa1a7e458e298ef0
...@@ -11,6 +11,14 @@ endif() ...@@ -11,6 +11,14 @@ endif()
# Enable compile command export # Enable compile command export
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Fallback for when the caller did not supply Python_EXECUTABLE (e.g. via
# -DPython_EXECUTABLE=...): ask the shell for the first `python` on PATH.
# NOTE(review): `which` may be unavailable on some platforms, and if it fails
# the variable is cached empty — confirm this is acceptable.
if(NOT Python_EXECUTABLE)
execute_process(
COMMAND which python
OUTPUT_VARIABLE Python_EXECUTABLE
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Persist the detected interpreter path in the CMake cache.
set(Python_EXECUTABLE "${Python_EXECUTABLE}" CACHE FILEPATH "Path to the Python executable")
endif()
# Define a custom macro for globbing files with conditional CONFIGURE_DEPENDS # Define a custom macro for globbing files with conditional CONFIGURE_DEPENDS
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.12.0") if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.12.0")
...@@ -39,7 +47,8 @@ else() ...@@ -39,7 +47,8 @@ else()
# Set default build type to RelWithDebInfo if not provided # Set default build type to RelWithDebInfo if not provided
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "Build type" FORCE) # Set default build type to Release if not provided
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
message(STATUS "Setting default build type to ${CMAKE_BUILD_TYPE}") message(STATUS "Setting default build type to ${CMAKE_BUILD_TYPE}")
endif() endif()
endif() endif()
...@@ -145,6 +154,7 @@ message(STATUS "TVM_SOURCE_DIR: ${TVM_SOURCE_DIR}") ...@@ -145,6 +154,7 @@ message(STATUS "TVM_SOURCE_DIR: ${TVM_SOURCE_DIR}")
# Include directories for TileLang # Include directories for TileLang
set(TILE_LANG_INCLUDES set(TILE_LANG_INCLUDES
${TVM_SOURCE_DIR}/include ${TVM_SOURCE_DIR}/include
${TVM_SOURCE_DIR}/ffi/include
${TVM_SOURCE_DIR}/src ${TVM_SOURCE_DIR}/src
${TVM_SOURCE_DIR}/3rdparty/dlpack/include ${TVM_SOURCE_DIR}/3rdparty/dlpack/include
${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include ${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include
...@@ -212,6 +222,11 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug") ...@@ -212,6 +222,11 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
target_compile_definitions(tilelang_static PRIVATE "TVM_LOG_DEBUG") target_compile_definitions(tilelang_static PRIVATE "TVM_LOG_DEBUG")
endif() endif()
# Building tvm_cython modules: make the `tilelang` target depend on the
# `tvm_cython` target unless TVM_PREBUILD_PATH is defined.
# NOTE(review): presumably a prebuilt TVM already ships its Cython modules and
# the `tvm_cython` target is not created in that configuration — confirm.
if(NOT DEFINED TVM_PREBUILD_PATH)
add_dependencies(tilelang tvm_cython)
endif()
# Module shared library # Module shared library
add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>) add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang_module PUBLIC tvm) target_link_libraries(tilelang_module PUBLIC tvm)
......
...@@ -54,10 +54,8 @@ def get_configs(args, kwargs): ...@@ -54,10 +54,8 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
import torch import torch
if torch.version.hip is not None: arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
arch=CDNA("hip")
else:
arch = CUDA("cuda")
topk = 10 topk = 10
carve_template = MatmulTemplate( carve_template = MatmulTemplate(
...@@ -158,7 +156,7 @@ def matmul( ...@@ -158,7 +156,7 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth, # Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy # accumulate in float for better numerical accuracy
dtype = "e4m3_float8" dtype = "float8_e4m3"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
......
lm_eval==0.3.0 lm_eval==0.3.0
flash_attn flash_attn
transformers==4.52.1 transformers==4.52.1
\ No newline at end of file
...@@ -17,7 +17,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): ...@@ -17,7 +17,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
@T.prim_func @T.prim_func
def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor( def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor( (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor(
(BG, M_max, T.ceildiv(N, group_size)), accum_dtype)): (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
with T.Kernel( with T.Kernel(
T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
...@@ -28,7 +28,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): ...@@ -28,7 +28,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
y_s_local = T.alloc_fragment((blk_m,), accum_dtype) y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8") y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_local((1,), "int32") row_offset = T.alloc_local((1,), "int32")
T.annotate_layout({ T.annotate_layout({
......
...@@ -15,7 +15,7 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -15,7 +15,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
fp8_max = 448.0 fp8_max = 448.0
@T.prim_func @T.prim_func
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"), def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"),
X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)): X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx row = bx
...@@ -24,7 +24,7 @@ def per_token_cast_to_fp8(M, N, blk_m): ...@@ -24,7 +24,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), dtype) y_amax_local = T.alloc_fragment((blk_m,), dtype)
y_s_local = T.alloc_fragment((blk_m,), dtype) y_s_local = T.alloc_fragment((blk_m,), dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), dtype) y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8") y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
T.annotate_layout({ T.annotate_layout({
y_local: y_local:
......
...@@ -20,8 +20,8 @@ def tl_gemm( ...@@ -20,8 +20,8 @@ def tl_gemm(
accum_dtype, accum_dtype,
): ):
assert in_dtype in [ assert in_dtype in [
"e4m3_float8", "float8_e4m3",
], "Currently only e4m3_float8 is supported" ], "Currently only float8_e4m3 is supported"
assert out_dtype in [ assert out_dtype in [
"bfloat16", "bfloat16",
"float32", "float32",
...@@ -179,11 +179,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp ...@@ -179,11 +179,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
def main(): def main():
assert_tl_gemm_correctness(1024, 1024, 8192, 128, "e4m3_float8", "bfloat16", "float32") assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32")
if __name__ == "__main__": if __name__ == "__main__":
for dtype in ["e4m3_float8"]: for dtype in ["float8_e4m3"]:
for out_dtype in ["bfloat16", "float32"]: for out_dtype in ["bfloat16", "float32"]:
for block_N in [16, 32, 64, 128]: for block_N in [16, 32, 64, 128]:
assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32") assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
...@@ -11,7 +11,7 @@ import argparse ...@@ -11,7 +11,7 @@ import argparse
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16" dtype = "float16"
q_dtype = "e4m3_float8" q_dtype = "float8_e4m3"
accum_dtype = "float" accum_dtype = "float"
kv_group_num = heads // kv_head_num kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num) VALID_BLOCK_H = min(block_H, kv_group_num)
......
...@@ -57,8 +57,8 @@ def test_gemm_fp8(M, N, K, dtype): ...@@ -57,8 +57,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main(): def main():
test_gemm_fp8(1024, 1024, 1024, 'e4m3_float8') test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 1024, 'e5m2_float8') test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2')
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype): ...@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main(): def main():
test_gemm_fp8(1024, 1024, 8192, 'e4m3_float8') test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 8192, 'e5m2_float8') test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2')
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -40,8 +40,8 @@ def tl_matmul( ...@@ -40,8 +40,8 @@ def tl_matmul(
): ):
assert in_dtype in [ assert in_dtype in [
"float16", "float16",
"e4m3_float8", "float8_e4m3",
"e5m2_float8", "float8_e5m2",
"int8", "int8",
], "Currently only float16 and int8 are supported" ], "Currently only float16 and int8 are supported"
assert out_dtype in [ assert out_dtype in [
...@@ -52,7 +52,7 @@ def tl_matmul( ...@@ -52,7 +52,7 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"] is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8: if out_dtype == "int32" or is_float8:
micro_size_k = 32 micro_size_k = 32
...@@ -216,8 +216,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): ...@@ -216,8 +216,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def main(): def main():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32") assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32") assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -9,6 +9,7 @@ import argparse ...@@ -9,6 +9,7 @@ import argparse
tilelang.disable_cache() tilelang.disable_cache()
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16" dtype = "float16"
...@@ -79,7 +80,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -79,7 +80,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128) p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
lse_0_ready_barrier = T.alloc_barrier(arrive_count=128) lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
lse_1_ready_barrier = T.alloc_barrier(arrive_count=128) lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256) q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
k_pe_shared_1_free_barrier = T.alloc_barrier(arrive_count=128) k_pe_shared_1_free_barrier = T.alloc_barrier(arrive_count=128)
k_pe_shared_0_free_barrier = T.alloc_barrier(arrive_count=128) k_pe_shared_0_free_barrier = T.alloc_barrier(arrive_count=128)
...@@ -401,8 +401,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64): ...@@ -401,8 +401,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
BLOCK_H = 64 BLOCK_H = 64
num_split = 1 num_split = 1
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500) latency = profiler.do_bench(warmup=500)
......
# Should be mirrored in pyproject.toml # Should be mirrored in pyproject.toml
Cython
build build
cmake>=3.26 cmake>=3.26
packaging packaging
......
# lint requirements # lint requirements
-r requirements-lint.txt -r requirements-lint.txt
# build requirements # build requirements
Cython
cmake>=3.26 cmake>=3.26
# runtime requirements # runtime requirements
cffi cffi
......
...@@ -815,7 +815,7 @@ class TilelangExtensionBuild(build_ext): ...@@ -815,7 +815,7 @@ class TilelangExtensionBuild(build_ext):
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go # -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used # -DPYTHON_EXECUTABLE ensures that the correct Python is used
cmake_args = [ cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPYTHON_EXECUTABLE={sys.executable}", f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPython_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}" f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}"
] ]
if not USE_ROCM: if not USE_ROCM:
......
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
#include "./transform/common/attr.h" #include "./transform/common/attr.h"
#include "op/builtin.h" #include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h> #include <tvm/script/ir_builder/tir/ir.h>
namespace tvm { namespace tvm {
...@@ -65,7 +67,7 @@ ForFrame ParallelFor(Array<PrimExpr> extents, ...@@ -65,7 +67,7 @@ ForFrame ParallelFor(Array<PrimExpr> extents,
Var var = vars[i]; Var var = vars[i];
body = body =
For(var, dom->min, dom->extent, ForKind::kParallel, std::move(body), For(var, dom->min, dom->extent, ForKind::kParallel, std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/annotations); /*thread_binding=*/std::nullopt, /*annotations=*/annotations);
} }
return body; return body;
}; };
...@@ -99,7 +101,7 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages, ...@@ -99,7 +101,7 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages,
anno.Set("tl_pipeline_group", groups); anno.Set("tl_pipeline_group", groups);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial,
std::move(body), std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/anno); /*thread_binding=*/std::nullopt, /*annotations=*/anno);
return body; return body;
}; };
return ForFrame(n); return ForFrame(n);
...@@ -157,7 +159,7 @@ ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size, ...@@ -157,7 +159,7 @@ ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size,
Stmt()); Stmt());
Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, Stmt outer = For(loop_var, 0, waves, ForKind::kSerial,
SeqStmt({out_if, body}), NullOpt, anno); SeqStmt({out_if, body}), std::nullopt, anno);
for (int i = 0; i < vars.size() - 1; ++i) { for (int i = 0; i < vars.size() - 1; ++i) {
outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer); outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
} }
...@@ -178,9 +180,10 @@ class KernelLaunchFrameNode : public TIRFrameNode { ...@@ -178,9 +180,10 @@ class KernelLaunchFrameNode : public TIRFrameNode {
public: public:
Array<TIRFrame> frames; Array<TIRFrame> frames;
void VisitAttrs(tvm::AttrVisitor *v) { static void RegisterReflection() {
TIRFrameNode::VisitAttrs(v); namespace refl = tvm::ffi::reflection;
v->Visit("frames", &frames); refl::ObjectDef<KernelLaunchFrameNode>().def_ro(
"frames", &KernelLaunchFrameNode::frames);
} }
static constexpr const char *_type_key = "tl.KernelLaunchFrame"; static constexpr const char *_type_key = "tl.KernelLaunchFrame";
...@@ -213,14 +216,16 @@ public: ...@@ -213,14 +216,16 @@ public:
}; };
KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
Array<PrimExpr> block_size, Optional<Array<PrimExpr>> block_size_opt,
Map<String, ObjectRef> attrs) { Map<String, ffi::Any> attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>(); ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads. // If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame = bool is_cpu_kernel_frame =
attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame); attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);
auto block_size = block_size_opt.value_or(Array<PrimExpr>());
if (is_cpu_kernel_frame) { if (is_cpu_kernel_frame) {
// Launch CPU Kernel // Launch CPU Kernel
ICHECK(grid_size.size() >= 0); ICHECK(grid_size.size() >= 0);
...@@ -279,18 +284,23 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, ...@@ -279,18 +284,23 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode); TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
TVM_REGISTER_GLOBAL("tl.Parallel").set_body_typed(ParallelFor); TVM_FFI_STATIC_INIT_BLOCK({
TVM_REGISTER_GLOBAL("tl.Pipelined").set_body_typed(PipelinedFor); namespace refl = tvm::ffi::reflection;
TVM_REGISTER_GLOBAL("tl.Persistent").set_body_typed(PersistentFor); refl::GlobalDef()
TVM_REGISTER_GLOBAL("tl.KernelLaunch").set_body_typed(KernelLaunch); .def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch);
});
class WarpSpecializeFrameNode : public TIRFrameNode { class WarpSpecializeFrameNode : public TIRFrameNode {
public: public:
Array<TIRFrame> frames; Array<TIRFrame> frames;
void VisitAttrs(tvm::AttrVisitor *v) { static void RegisterReflection() {
TIRFrameNode::VisitAttrs(v); namespace refl = tvm::ffi::reflection;
v->Visit("frames", &frames); refl::ObjectDef<WarpSpecializeFrameNode>().def_ro(
"frames", &WarpSpecializeFrameNode::frames);
} }
static constexpr const char *_type_key = "tl.WarpSpecializeFrame"; static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
...@@ -359,7 +369,12 @@ WarpSpecializeFrame WarpSpecialize(Array<IntImm> warp_group_ids, ...@@ -359,7 +369,12 @@ WarpSpecializeFrame WarpSpecialize(Array<IntImm> warp_group_ids,
} }
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode); TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_REGISTER_GLOBAL("tl.WarpSpecialize").set_body_typed(WarpSpecialize); TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection();
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include "layout.h" #include "layout.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/arith/pattern.h> #include <tvm/arith/pattern.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -73,9 +74,11 @@ Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) { ...@@ -73,9 +74,11 @@ Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
data_ = std::move(n); data_ = std::move(n);
} }
void LayoutNode::VisitAttrs(AttrVisitor *v) { void LayoutNode::RegisterReflection() {
v->Visit("input_size", &input_size_); namespace refl = tvm::ffi::reflection;
v->Visit("forward_index", &forward_index_); refl::ObjectDef<LayoutNode>()
.def_ro("input_size", &LayoutNode::input_size_)
.def_ro("forward_index", &LayoutNode::forward_index_);
} }
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const { void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
...@@ -155,7 +158,7 @@ Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats, ...@@ -155,7 +158,7 @@ Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
auto new_forward_thread = auto new_forward_thread =
Substitute(forward_thread_, vmap) + thread_size * repeats_index; Substitute(forward_thread_, vmap) + thread_size * repeats_index;
return Fragment(new_input_size, new_forward_index, new_forward_thread, return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt); replicate_size_, std::nullopt);
} else { } else {
ICHECK(OutputDim() == 1); ICHECK(OutputDim() == 1);
PrimExpr frag_len = OutputShape()[0]; PrimExpr frag_len = OutputShape()[0];
...@@ -163,7 +166,7 @@ Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats, ...@@ -163,7 +166,7 @@ Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
frag_len * repeats_index}; frag_len * repeats_index};
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap); PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
return Fragment(new_input_size, new_forward_index, new_forward_thread, return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt); replicate_size_, std::nullopt);
} }
} }
...@@ -176,7 +179,7 @@ Fragment FragmentNode::Replicate(int repeats) const { ...@@ -176,7 +179,7 @@ Fragment FragmentNode::Replicate(int repeats) const {
Substitute(forward_thread_, vmap) + Substitute(forward_thread_, vmap) +
ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent()); ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
return Fragment(input_size_, forward_index_, new_forward_thread, return Fragment(input_size_, forward_index_, new_forward_thread,
ReplicateExtent() * repeats, NullOpt); ReplicateExtent() * repeats, std::nullopt);
} }
Fragment FragmentNode::DeReplicate() const { Fragment FragmentNode::DeReplicate() const {
...@@ -198,7 +201,7 @@ Fragment FragmentNode::DeReplicate() const { ...@@ -198,7 +201,7 @@ Fragment FragmentNode::DeReplicate() const {
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap); PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)}; Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
return Fragment(input_size_, new_forward_index, new_forward_thread, return Fragment(input_size_, new_forward_index, new_forward_thread,
int(*rep_size) / factor, NullOpt); int(*rep_size) / factor, std::nullopt);
} }
Fragment FragmentNode::BindThreadRange(Range thread_range) const { Fragment FragmentNode::BindThreadRange(Range thread_range) const {
...@@ -304,18 +307,11 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, ...@@ -304,18 +307,11 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n); data_ = std::move(n);
} }
void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v);
v->Visit("forward_thread", &forward_thread_);
v->Visit("replicate_size", &replicate_size_);
}
PrimExpr FragmentNode::ThreadExtent() const { PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1); Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer; arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer); UpdateAnalyzer(&analyzer);
auto ist = analyzer.int_set(forward_thread_ + 1); auto ist = analyzer.int_set(forward_thread_ + 1);
// CHECK(is_one(ist.min()));
return ist.max(); return ist.max();
} }
...@@ -435,64 +431,69 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const { ...@@ -435,64 +431,69 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
return ret; return ret;
} }
// Registers the reflected fields of FragmentNode: forward_thread_ and
// replicate_size_ are exposed under the names "forward_thread" and
// "replicate_size". Only these two members are listed here.
// NOTE(review): `def_ro` presumably registers the fields as read-only —
// confirm against tvm::ffi::reflection.
void FragmentNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FragmentNode>()
.def_ro("forward_thread", &FragmentNode::forward_thread_)
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode); TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_FFI_STATIC_INIT_BLOCK({
*ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1])); namespace refl = tvm::ffi::reflection;
}); refl::GlobalDef()
.def_packed("tl.Layout",
TVM_REGISTER_GLOBAL("tl.Layout_input_shape").set_body_typed([](Layout layout) { [](PackedArgs args, Any *rv) {
return layout->InputShape(); *rv = Layout(args[0].cast<Array<IterVar>>(),
}); args[1].cast<Array<PrimExpr>>());
})
TVM_REGISTER_GLOBAL("tl.Layout_output_shape").set_body_typed([](Layout layout) { .def("tl.Layout_input_shape",
return layout->OutputShape(); [](Layout layout) { return layout->InputShape(); })
}); .def("tl.Layout_output_shape",
[](Layout layout) { return layout->OutputShape(); })
TVM_REGISTER_GLOBAL("tl.Layout_inverse").set_body_typed([](Layout layout) { .def("tl.Layout_inverse", [](Layout layout) { return layout->Inverse(); })
return layout->Inverse(); .def("tl.Layout_index",
}); [](Layout layout) { return layout->GetForwardIndex(); })
.def("tl.Layout_forward_vars",
TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) { [](Layout layout) { return layout->GetForwardVars(); })
return layout->GetForwardIndex(); .def_packed("tl.Fragment",
}); [](PackedArgs args, Any *rv) {
*rv = Fragment(
TVM_REGISTER_GLOBAL("tl.Layout_forward_vars").set_body_typed([](Layout layout) { /*forward_var=*/args[0].cast<Array<IterVar>>(),
return layout->GetForwardVars(); /*forward_index=*/args[1].cast<Array<PrimExpr>>(),
/*forward_thread=*/args[2].cast<PrimExpr>(),
/*thread_replicate=*/args[3].cast<IterVar>());
})
.def("tl.Fragment_thread_size",
[](Fragment fragment) { return fragment->ThreadExtent(); })
.def("tl.Fragment_thread",
[](Fragment fragment) { return fragment->GetForwardThread(); })
.def("tl.Fragment_repeat",
[](Fragment fragment, Array<PrimExpr> repeats, bool repeat_on_thread,
bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread,
lower_dim_first);
})
.def("tl.Fragment_replicate",
[](Fragment fragment, int repeats) {
return fragment->Replicate(repeats);
})
.def("tl.Fragment_condense_rep_var",
[](Fragment fragment) { return fragment->CondenseReplicateVar(); })
.def("tl.make_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, 0);
});
}); });
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_FFI_STATIC_INIT_BLOCK({
*ret = Fragment(args[0], args[1], args[2], args[3]); namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection();
}); });
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
.set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });
TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
return fragment->GetForwardThread();
});
TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
.set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
bool repeat_on_thread, bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
});
TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
.set_body_typed([](Fragment fragment, int repeats) {
return fragment->Replicate(repeats);
});
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
.set_body_typed([](Fragment fragment) {
return fragment->CondenseReplicateVar();
});
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
.set_body_typed([](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous, element_size, 0);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
static constexpr bool _type_has_method_sequal_reduce = true; static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout"; static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const; bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor *v); static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
protected: protected:
...@@ -101,7 +101,7 @@ public: ...@@ -101,7 +101,7 @@ public:
bool IsEqual(const FragmentNode *other, bool skip_index = false) const; bool IsEqual(const FragmentNode *other, bool skip_index = false) const;
void VisitAttrs(tvm::AttrVisitor *v); static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const; bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment"; static constexpr const char *_type_key = "tl.Fragment";
......
...@@ -97,8 +97,9 @@ SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size, ...@@ -97,8 +97,9 @@ SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
data_ = std::move(n); data_ = std::move(n);
} }
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor *v) { void SwizzledLayoutNode::RegisterReflection() {
LayoutNode::VisitAttrs(v); namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SwizzledLayoutNode>();
} }
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other, bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment