Unverified commit a7c9a8b9 authored by Siyuan Feng, committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
......@@ -7,7 +7,7 @@ env:
jobs:
format-check:
runs-on: ubuntu-latest
runs-on: self-hosted
permissions:
contents: write
......@@ -26,21 +26,37 @@ jobs:
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install dependencies
- name: Ensure venv (local & persistent)
run: |
python -m pip install --upgrade pip
pip install yapf==0.40.2 toml==0.10.2 tomli==2.0.1 ruff==0.6.5 codespell==2.3.0 clang-format==15.0.7
set -e
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
echo "venv exists and hash matches – reuse it"
else
echo "venv stale or missing – recreating"
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
# shellcheck source=/dev/null
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install . --no-user
touch "$MARKER"
fi
- name: Run format check
run: |
git clone https://github.com/tile-ai/tilelang.git main_repo
cp main_repo/format.sh .
rm -rf main_repo
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
if ! output=$(./format.sh 2>&1); then
echo "------------------------------------"
echo "message:"
echo "$output"
echo "------------------------------------"
printf '%s\n' "$output" | grep "Please review and stage the changes."
echo "------------------------------------"
exit 1
fi
- name: Commit and Push Changes
......
Subproject commit 979c8e7f94473db7d71a41b26ccf51db7e17a734
Subproject commit a08b7c34d4a59f89f4dea252fa1a7e458e298ef0
......@@ -11,6 +11,14 @@ endif()
# Enable compile command export
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(NOT Python_EXECUTABLE)
execute_process(
COMMAND which python
OUTPUT_VARIABLE Python_EXECUTABLE
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(Python_EXECUTABLE "${Python_EXECUTABLE}" CACHE FILEPATH "Path to the Python executable")
endif()
# Define a custom macro for globbing files with conditional CONFIGURE_DEPENDS
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.12.0")
......@@ -39,7 +47,8 @@ else()
# Set default build type to RelWithDebInfo if not provided
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "Build type" FORCE)
# Set default build type to Release if not provided
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
message(STATUS "Setting default build type to ${CMAKE_BUILD_TYPE}")
endif()
endif()
......@@ -145,6 +154,7 @@ message(STATUS "TVM_SOURCE_DIR: ${TVM_SOURCE_DIR}")
# Include directories for TileLang
set(TILE_LANG_INCLUDES
${TVM_SOURCE_DIR}/include
${TVM_SOURCE_DIR}/ffi/include
${TVM_SOURCE_DIR}/src
${TVM_SOURCE_DIR}/3rdparty/dlpack/include
${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include
......@@ -212,6 +222,11 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
target_compile_definitions(tilelang_static PRIVATE "TVM_LOG_DEBUG")
endif()
# Building tvm_cython modules
if(NOT DEFINED TVM_PREBUILD_PATH)
add_dependencies(tilelang tvm_cython)
endif()
# Module shared library
add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang_module PUBLIC tvm)
......
......@@ -54,10 +54,8 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
topk = 10
carve_template = MatmulTemplate(
......@@ -158,7 +156,7 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "e4m3_float8"
dtype = "float8_e4m3"
accum_dtype = "float"
@T.prim_func
......
......@@ -17,7 +17,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
@T.prim_func
def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor(
(BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
with T.Kernel(
T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
......@@ -28,7 +28,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8")
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_local((1,), "int32")
T.annotate_layout({
......
......@@ -15,7 +15,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
fp8_max = 448.0
@T.prim_func
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"),
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"),
X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx
......@@ -24,7 +24,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), dtype)
y_s_local = T.alloc_fragment((blk_m,), dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8")
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
T.annotate_layout({
y_local:
......
......@@ -20,8 +20,8 @@ def tl_gemm(
accum_dtype,
):
assert in_dtype in [
"e4m3_float8",
], "Currently only e4m3_float8 is supported"
"float8_e4m3",
], "Currently only float8_e4m3 is supported"
assert out_dtype in [
"bfloat16",
"float32",
......@@ -179,11 +179,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
def main():
assert_tl_gemm_correctness(1024, 1024, 8192, 128, "e4m3_float8", "bfloat16", "float32")
assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32")
if __name__ == "__main__":
for dtype in ["e4m3_float8"]:
for dtype in ["float8_e4m3"]:
for out_dtype in ["bfloat16", "float32"]:
for block_N in [16, 32, 64, 128]:
assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
......@@ -11,7 +11,7 @@ import argparse
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
q_dtype = "e4m3_float8"
q_dtype = "float8_e4m3"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
......
......@@ -57,8 +57,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 1024, 'e4m3_float8')
test_gemm_fp8(1024, 1024, 1024, 'e5m2_float8')
test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2')
if __name__ == "__main__":
......
......@@ -74,8 +74,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 8192, 'e4m3_float8')
test_gemm_fp8(1024, 1024, 8192, 'e5m2_float8')
test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2')
if __name__ == "__main__":
......
......@@ -40,8 +40,8 @@ def tl_matmul(
):
assert in_dtype in [
"float16",
"e4m3_float8",
"e5m2_float8",
"float8_e4m3",
"float8_e5m2",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
......@@ -52,7 +52,7 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
......@@ -216,8 +216,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def main():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
if __name__ == "__main__":
......
......@@ -9,6 +9,7 @@ import argparse
tilelang.disable_cache()
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
......@@ -79,7 +80,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
k_pe_shared_1_free_barrier = T.alloc_barrier(arrive_count=128)
k_pe_shared_0_free_barrier = T.alloc_barrier(arrive_count=128)
......@@ -401,8 +401,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
BLOCK_H = 64
num_split = 1
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
......
# Should be mirrored in pyproject.toml
Cython
build
cmake>=3.26
packaging
......
# lint requirements
-r requirements-lint.txt
# build requirements
Cython
cmake>=3.26
# runtime requirements
cffi
......
......@@ -815,7 +815,7 @@ class TilelangExtensionBuild(build_ext):
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", f"-DPython_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={'Debug' if DEBUG_MODE else 'Release'}"
]
if not USE_ROCM:
......
......@@ -6,7 +6,9 @@
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>
namespace tvm {
......@@ -65,7 +67,7 @@ ForFrame ParallelFor(Array<PrimExpr> extents,
Var var = vars[i];
body =
For(var, dom->min, dom->extent, ForKind::kParallel, std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/annotations);
/*thread_binding=*/std::nullopt, /*annotations=*/annotations);
}
return body;
};
......@@ -99,7 +101,7 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages,
anno.Set("tl_pipeline_group", groups);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial,
std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/anno);
/*thread_binding=*/std::nullopt, /*annotations=*/anno);
return body;
};
return ForFrame(n);
......@@ -157,7 +159,7 @@ ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size,
Stmt());
Stmt outer = For(loop_var, 0, waves, ForKind::kSerial,
SeqStmt({out_if, body}), NullOpt, anno);
SeqStmt({out_if, body}), std::nullopt, anno);
for (int i = 0; i < vars.size() - 1; ++i) {
outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
}
......@@ -178,9 +180,10 @@ class KernelLaunchFrameNode : public TIRFrameNode {
public:
Array<TIRFrame> frames;
void VisitAttrs(tvm::AttrVisitor *v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("frames", &frames);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<KernelLaunchFrameNode>().def_ro(
"frames", &KernelLaunchFrameNode::frames);
}
static constexpr const char *_type_key = "tl.KernelLaunchFrame";
......@@ -213,14 +216,16 @@ public:
};
KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
Array<PrimExpr> block_size,
Map<String, ObjectRef> attrs) {
Optional<Array<PrimExpr>> block_size_opt,
Map<String, ffi::Any> attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);
auto block_size = block_size_opt.value_or(Array<PrimExpr>());
if (is_cpu_kernel_frame) {
// Launch CPU Kernel
ICHECK(grid_size.size() >= 0);
......@@ -279,18 +284,23 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
TVM_REGISTER_GLOBAL("tl.Parallel").set_body_typed(ParallelFor);
TVM_REGISTER_GLOBAL("tl.Pipelined").set_body_typed(PipelinedFor);
TVM_REGISTER_GLOBAL("tl.Persistent").set_body_typed(PersistentFor);
TVM_REGISTER_GLOBAL("tl.KernelLaunch").set_body_typed(KernelLaunch);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch);
});
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
Array<TIRFrame> frames;
void VisitAttrs(tvm::AttrVisitor *v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("frames", &frames);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<WarpSpecializeFrameNode>().def_ro(
"frames", &WarpSpecializeFrameNode::frames);
}
static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
......@@ -359,7 +369,12 @@ WarpSpecializeFrame WarpSpecialize(Array<IntImm> warp_group_ids,
}
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_REGISTER_GLOBAL("tl.WarpSpecialize").set_body_typed(WarpSpecialize);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection();
});
} // namespace tl
} // namespace tvm
......@@ -4,6 +4,7 @@
*/
#include "layout.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
......@@ -73,9 +74,11 @@ Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
data_ = std::move(n);
}
void LayoutNode::VisitAttrs(AttrVisitor *v) {
v->Visit("input_size", &input_size_);
v->Visit("forward_index", &forward_index_);
void LayoutNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LayoutNode>()
.def_ro("input_size", &LayoutNode::input_size_)
.def_ro("forward_index", &LayoutNode::forward_index_);
}
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
......@@ -155,7 +158,7 @@ Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
auto new_forward_thread =
Substitute(forward_thread_, vmap) + thread_size * repeats_index;
return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt);
replicate_size_, std::nullopt);
} else {
ICHECK(OutputDim() == 1);
PrimExpr frag_len = OutputShape()[0];
......@@ -163,7 +166,7 @@ Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
frag_len * repeats_index};
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt);
replicate_size_, std::nullopt);
}
}
......@@ -176,7 +179,7 @@ Fragment FragmentNode::Replicate(int repeats) const {
Substitute(forward_thread_, vmap) +
ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
return Fragment(input_size_, forward_index_, new_forward_thread,
ReplicateExtent() * repeats, NullOpt);
ReplicateExtent() * repeats, std::nullopt);
}
Fragment FragmentNode::DeReplicate() const {
......@@ -198,7 +201,7 @@ Fragment FragmentNode::DeReplicate() const {
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
return Fragment(input_size_, new_forward_index, new_forward_thread,
int(*rep_size) / factor, NullOpt);
int(*rep_size) / factor, std::nullopt);
}
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
......@@ -304,18 +307,11 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n);
}
void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v);
v->Visit("forward_thread", &forward_thread_);
v->Visit("replicate_size", &replicate_size_);
}
PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
auto ist = analyzer.int_set(forward_thread_ + 1);
// CHECK(is_one(ist.min()));
return ist.max();
}
......@@ -435,64 +431,69 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
return ret;
}
void FragmentNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FragmentNode>()
.def_ro("forward_thread", &FragmentNode::forward_thread_)
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
});
TVM_REGISTER_GLOBAL("tl.Layout_input_shape").set_body_typed([](Layout layout) {
return layout->InputShape();
});
TVM_REGISTER_GLOBAL("tl.Layout_output_shape").set_body_typed([](Layout layout) {
return layout->OutputShape();
});
TVM_REGISTER_GLOBAL("tl.Layout_inverse").set_body_typed([](Layout layout) {
return layout->Inverse();
});
TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
return layout->GetForwardIndex();
});
TVM_REGISTER_GLOBAL("tl.Layout_forward_vars").set_body_typed([](Layout layout) {
return layout->GetForwardVars();
});
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Fragment(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
.set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });
TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
return fragment->GetForwardThread();
});
TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
.set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
bool repeat_on_thread, bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
});
TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
.set_body_typed([](Fragment fragment, int repeats) {
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("tl.Layout",
[](PackedArgs args, Any *rv) {
*rv = Layout(args[0].cast<Array<IterVar>>(),
args[1].cast<Array<PrimExpr>>());
})
.def("tl.Layout_input_shape",
[](Layout layout) { return layout->InputShape(); })
.def("tl.Layout_output_shape",
[](Layout layout) { return layout->OutputShape(); })
.def("tl.Layout_inverse", [](Layout layout) { return layout->Inverse(); })
.def("tl.Layout_index",
[](Layout layout) { return layout->GetForwardIndex(); })
.def("tl.Layout_forward_vars",
[](Layout layout) { return layout->GetForwardVars(); })
.def_packed("tl.Fragment",
[](PackedArgs args, Any *rv) {
*rv = Fragment(
/*forward_var=*/args[0].cast<Array<IterVar>>(),
/*forward_index=*/args[1].cast<Array<PrimExpr>>(),
/*forward_thread=*/args[2].cast<PrimExpr>(),
/*thread_replicate=*/args[3].cast<IterVar>());
})
.def("tl.Fragment_thread_size",
[](Fragment fragment) { return fragment->ThreadExtent(); })
.def("tl.Fragment_thread",
[](Fragment fragment) { return fragment->GetForwardThread(); })
.def("tl.Fragment_repeat",
[](Fragment fragment, Array<PrimExpr> repeats, bool repeat_on_thread,
bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread,
lower_dim_first);
})
.def("tl.Fragment_replicate",
[](Fragment fragment, int repeats) {
return fragment->Replicate(repeats);
})
.def("tl.Fragment_condense_rep_var",
[](Fragment fragment) { return fragment->CondenseReplicateVar(); })
.def("tl.make_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, 0);
});
});
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
.set_body_typed([](Fragment fragment) {
return fragment->CondenseReplicateVar();
});
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
.set_body_typed([](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous, element_size, 0);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection();
});
} // namespace tl
} // namespace tvm
......@@ -44,7 +44,7 @@ public:
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor *v);
static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
protected:
......@@ -101,7 +101,7 @@ public:
bool IsEqual(const FragmentNode *other, bool skip_index = false) const;
void VisitAttrs(tvm::AttrVisitor *v);
static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment";
......
......@@ -97,8 +97,9 @@ SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
data_ = std::move(n);
}
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v);
void SwizzledLayoutNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SwizzledLayoutNode>();
}
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment