Unverified commit 74da3696 authored by Lei Wang, committed by GitHub

[FFI] Use tvm ffi as the default execution backend (#1259)

* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.
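
The move from `TVMValue` to `TVMFFIAny` described above can be pictured with a small sketch: an `Any`-style slot carries a type index alongside its payload, so callees can validate arguments instead of trusting a separately passed type-code array. The field names (`type_index`, `v_int64`) mirror the generated host code later in this diff; the Python class itself is purely hypothetical.

```python
# Illustrative model of a TVMFFIAny-style tagged argument slot.
# Not tilelang/TVM code: field names follow the generated C in this
# commit, but the class and type indices are made up for illustration.
from dataclasses import dataclass

TYPE_NONE = 0
TYPE_INT = 1


@dataclass
class FFIAny:
    type_index: int = TYPE_NONE  # which kind of value this slot holds
    v_int64: int = 0             # payload (one member of the real union)


def make_int_arg(value: int) -> FFIAny:
    # The type travels with the value, unlike a bare TVMValue union.
    return FFIAny(type_index=TYPE_INT, v_int64=value)
```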

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.
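
The effect of `tilelang.disable_cache()` on tests can be sketched with a toy compile cache; everything below (the registry, the key scheme) is a hypothetical stand-in, not tilelang's actual cache implementation.

```python
# Toy compile cache with a global disable switch, loosely modeling
# what `tilelang.disable_cache()` affects in tests. All names are
# illustrative only.
_cache = {}
_cache_enabled = True


def disable_cache():
    global _cache_enabled
    _cache_enabled = False
    _cache.clear()


def compile_kernel(key, compile_fn):
    # With caching enabled, repeated compiles of the same key are
    # served from the cache; disabling forces a fresh compile each time.
    if _cache_enabled and key in _cache:
        return _cache[key]
    mod = compile_fn(key)
    if _cache_enabled:
        _cache[key] = mod
    return mod
```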

* [Refactor] Clean up imports and improve code formatting

* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.

* Update execution backend options and improve resolution logic

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target.
- Updated documentation to reflect changes in execution backend options and their defaults.
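
The "auto" resolution rule can be sketched as below. The candidate set matches the backends listed in this commit, but the exact preference order (CUDA/HIP targets resolving to `tvm_ffi`, others falling back to `cython`) is an assumption for illustration, not the verbatim `KernelCache`/`AutoTuner` logic.

```python
# Hedged sketch of "auto" execution-backend resolution.
SUPPORTED_BACKENDS = {"auto", "tvm_ffi", "cython", "ctypes", "torch", "nvrtc"}


def resolve_execution_backend(requested: str, target_kind: str) -> str:
    if requested not in SUPPORTED_BACKENDS:
        raise ValueError(f"unknown execution backend: {requested}")
    if requested != "auto":
        # An explicit choice is honored as-is.
        return requested
    # "auto" defers the decision until the target is known.
    if target_kind in ("cuda", "hip"):
        return "tvm_ffi"
    return "cython"
```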

* lint fix

* fix

* Enhance argument handling in CUDA and HIP runtime modules

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.
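
The bool-to-int32 mapping mentioned in the first bullet amounts to a simple dtype normalization before building the launch-parameter layout; the sketch below is an illustration of the idea, not the actual `ExtractFuncInfo` code.

```python
# Sketch of the bool -> int32 argument-type normalization: device
# runtimes expect boolean kernel parameters to occupy a full 32-bit
# slot in the launch parameter buffer. Illustrative only.
def normalize_arg_dtype(dtype: str) -> str:
    if dtype == "bool":
        return "int32"
    return dtype


def normalize_signature(dtypes):
    # Applied over a whole signature before argument packing.
    return [normalize_arg_dtype(d) for d in dtypes]
```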

* lint fix

* lint fix

* lint fix

* lint fix

* minor fix

* fix

* recover check

* Refactor argument binding and validation in `arg_binder.cc`

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.

* lint fix

* stride fix

* minor fix

* fix

* lint fix

* lint fix

* Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.
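
The clamping arithmetic these helpers perform (visible in the C++ later in this diff) can be stated in a few lines: the requested persisting-L2 limit defaults to the window size and is capped at the device's `MAX_PERSISTING_L2_CACHE_SIZE` attribute. This pure-Python restatement is for illustration only.

```python
# Pure-Python restatement of the limit clamping done by the new
# set_access_policy_window helper. `max_persisting` stands in for the
# CUDA device attribute MAX_PERSISTING_L2_CACHE_SIZE.
def clamp_l2_limit(num_bytes, l2_limit_bytes=None, max_persisting=0):
    # The limit defaults to the size of the access policy window.
    limit = num_bytes if l2_limit_bytes is None else l2_limit_bytes
    # A non-positive attribute means "unknown"; otherwise cap at it.
    if max_persisting > 0:
        limit = min(limit, max_persisting)
    return limit
```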

* check with symbolic

* support null ptr

* Update CMakeLists and lower.py for code generation and subproject status

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility.
- Marked the TVM subproject as dirty to indicate local modifications.
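
The `target.build.tilelang_c` hookup follows TVM's name-based global-function registration pattern: the C++ side registers a builder under a string name, and `lower.py` looks it up by that name. The toy registry below illustrates the pattern; it is not TVM's actual global function table.

```python
# Toy name-based function registry illustrating how
# `target.build.tilelang_c` is registered and looked up. Hypothetical
# stand-in for TVM's global function table.
_GLOBAL_FUNCS = {}


def register_func(name):
    def wrap(fn):
        _GLOBAL_FUNCS[name] = fn
        return fn
    return wrap


def get_global_func(name):
    return _GLOBAL_FUNCS[name]


@register_func("target.build.tilelang_c")
def build_tilelang_c(mod):
    # Placeholder for host C code generation.
    return f"compiled:{mod}"
```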

* lint fix

* Update comments for clarity in quickstart.py
parent 921b96a3

-Subproject commit 093b2cdb2187140b197336496d65d61ace89e8ff
+Subproject commit f4105f89a646622acc9818584d1d91e2ca3f533d
@@ -138,6 +138,7 @@ file(GLOB TILE_LANG_SRCS
    src/transform/*.cc
    src/op/*.cc
    src/target/utils.cc
+   src/target/codegen_c_host.cc
    src/target/codegen_cpp.cc
    src/target/rt_mod_cpp.cc
    # intrin_rule doesn't have system dependency
...
@@ -166,7 +166,6 @@ def main():
                           enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
        block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
        print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
    # Create block mask with desired sparsity
    mask_shape = (M // block_M, N // block_N, K // block_K)
    block_mask = torch.rand(mask_shape).cuda() > sparsity
...
@@ -468,7 +468,6 @@ def run_test(
    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
                                        gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw,
                                        block_DK, block_DV, threads, num_stages)
-   print(kernel.get_kernel_source())
    dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
    if use_g:
...
...@@ -117,6 +117,7 @@ def test_example_chunk_o_bwd_compilation(): ...@@ -117,6 +117,7 @@ def test_example_chunk_o_bwd_compilation():
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages) block_DK, block_DV, threads, num_stages)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
W) # noqa: F841 W) # noqa: F841
if use_g: if use_g:
......
@@ -55,10 +55,9 @@ block_M = 128
block_N = 128
block_K = 32
-# 1. Define the kernel (matmul) and compile/lower it into an executable module
+# Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
-# 3. Test the kernel in Python with PyTorch data
+# Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
...
@@ -104,6 +104,7 @@ tilelang = "tilelang"
# TVM
"tilelang/3rdparty/tvm/src" = "3rdparty/tvm/src"
"tilelang/3rdparty/tvm/python" = "3rdparty/tvm/python"
+"tilelang/3rdparty/tvm/include" = "3rdparty/tvm/include"
"tilelang/3rdparty/tvm/version.py" = "3rdparty/tvm/version.py"
# CUTLASS
"tilelang/3rdparty/cutlass/include" = "3rdparty/cutlass/include"
...
@@ -13,6 +13,12 @@
namespace tvm {
namespace tl {
#if 1
// Thread-local storage for restoring the L2 persisting cache limit
static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
#endif
#if (CUDA_MAJOR_VERSION >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
  std::stringstream ss;
@@ -91,19 +97,21 @@ struct TensorMapArgs {
// set device api
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
-                                                                Any *ret) {
-    TensorMapArgs T = TensorMapArgs::Extract(args);
-    CUresult result = cuTensorMapEncodeTiled(
-        T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
-        T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle,
-        T.l2Promotion, T.oobFill);
-    if (result != CUDA_SUCCESS) {
-      LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n'
-                << T.ToDebugString();
-    }
-    *ret = static_cast<int>(result);
-  });
+  // Register using the canonical names defined in runtime.h
+  refl::GlobalDef().def_packed(
+      tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) {
+        TensorMapArgs T = TensorMapArgs::Extract(args);
+        CUresult result = cuTensorMapEncodeTiled(
+            T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
+            T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
+            T.swizzle, T.l2Promotion, T.oobFill);
+        if (result != CUDA_SUCCESS) {
+          LOG_FATAL << "Failed to initialize the TMA descriptor " << result
+                    << '\n'
+                    << T.ToDebugString();
+        }
+        *ret = static_cast<int>(result);
+      });
}
struct TensorMapIm2ColArgs {
@@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs {
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
-      "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
+      tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) {
        TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
        CUresult result = cuTensorMapEncodeIm2col(
            T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
@@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() {
#endif // (CUDA_MAJOR_VERSION >= 12)
//
// CUDA L2 Persisting Cache Access Policy Window helpers.
// Exposed as TVM FFI packed functions similar to TMA initialization.
//
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
// Set stream access policy window and adjust persisting L2 cache size
// Args:
// [0]: void* base_ptr (required)
// [1]: int64 num_bytes (required)
// [2]: float hit_ratio (optional, default 0.8)
// [3]: void* stream (optional, default 0 => default stream)
// [4]: int64 l2_limit_bytes (optional, default = num_bytes)
refl::GlobalDef().def_packed(
tl::tvm_cuda_stream_set_access_policy_window,
[](PackedArgs args, Any *ret) {
ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes";
void *base_ptr = args[0].cast<void *>();
size_t num_bytes = static_cast<size_t>(args[1].cast<int64_t>());
float hit_ratio = 0.8f;
if (args.size() >= 3) {
// Accept double/float
hit_ratio = static_cast<float>(args[2].cast<double>());
}
CUstream stream = nullptr;
if (args.size() >= 4) {
stream = reinterpret_cast<CUstream>(args[3].cast<void *>());
}
size_t l2_limit_bytes = num_bytes;
if (args.size() >= 5) {
l2_limit_bytes = static_cast<size_t>(args[4].cast<int64_t>());
}
// Clamp requested limit to device capability
CUdevice device;
CUresult result = cuCtxGetDevice(&device);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to get current CUDA device: " << result;
}
int max_persisting = 0;
result = cuDeviceGetAttribute(
&max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE,
device);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: "
<< result;
}
if (max_persisting > 0 &&
l2_limit_bytes > static_cast<size_t>(max_persisting)) {
l2_limit_bytes = static_cast<size_t>(max_persisting);
}
// Save current limit to restore later
size_t init_persisting_l2_cache_size = 0;
result = cuCtxGetLimit(&init_persisting_l2_cache_size,
CU_LIMIT_PERSISTING_L2_CACHE_SIZE);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to get current persisting L2 cache size limit: "
<< result;
}
__tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size;
__tl_prev_persisting_l2_cache_saved = true;
// Set new limit
result =
cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to set persisting L2 cache size limit: "
<< result;
}
// Apply access policy window to stream
CUstreamAttrValue stream_attribute;
memset(&stream_attribute, 0, sizeof(stream_attribute));
stream_attribute.accessPolicyWindow.base_ptr = base_ptr;
stream_attribute.accessPolicyWindow.num_bytes = l2_limit_bytes;
stream_attribute.accessPolicyWindow.hitRatio = hit_ratio;
stream_attribute.accessPolicyWindow.hitProp =
CU_ACCESS_PROPERTY_PERSISTING;
stream_attribute.accessPolicyWindow.missProp =
CU_ACCESS_PROPERTY_STREAMING;
result = cuStreamSetAttribute(stream,
CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
&stream_attribute);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to set stream access policy window: " << result;
}
*ret = static_cast<int>(result);
});
// Reset stream access policy window and restore the previous L2 cache size
// Args:
// [0]: void* stream (optional, default 0)
refl::GlobalDef().def_packed(
tl::tvm_cuda_stream_reset_access_policy_window,
[](PackedArgs args, Any *ret) {
CUstream stream = nullptr;
if (args.size() >= 1) {
stream = reinterpret_cast<CUstream>(args[0].cast<void *>());
}
CUstreamAttrValue stream_attribute;
memset(&stream_attribute, 0, sizeof(stream_attribute));
// num_bytes = 0 disables the access policy window on the stream
stream_attribute.accessPolicyWindow.num_bytes = 0;
CUresult result = cuStreamSetAttribute(
stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
&stream_attribute);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to reset stream access policy window: "
<< result;
}
result = cuCtxResetPersistingL2Cache();
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result;
}
if (__tl_prev_persisting_l2_cache_saved) {
result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE,
__tl_prev_persisting_l2_cache_size);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to restore persisting L2 cache size limit: "
<< result;
}
__tl_prev_persisting_l2_cache_saved = false;
}
*ret = static_cast<int>(result);
});
}
} // namespace tl
} // namespace tvm
@@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled =
constexpr const char *tvm_tensormap_create_im2col =
    "__tvm_tensormap_create_im2col";
#endif // (CUDA_MAJOR_VERSION >= 12)
// CUDA stream access policy window helpers
constexpr const char *tvm_cuda_stream_set_access_policy_window =
"__tvm_cuda_stream_set_access_policy_window";
constexpr const char *tvm_cuda_stream_reset_access_policy_window =
"__tvm_cuda_stream_reset_access_policy_window";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c_host.cc
*/
#include "codegen_c_host.h"
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
// For escaping strings embedded into generated C sources
#include "support/str_escape.h"
namespace tvm {
namespace tl {
CodeGenCHost::CodeGenCHost() {
module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx);
}
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts,
bool emit_fwd_func_decl, std::string target_str,
const std::unordered_set<std::string> &devices) {
emit_asserts_ = emit_asserts;
emit_fwd_func_decl_ = emit_fwd_func_decl;
declared_globals_.clear();
decl_stream << "// tilelang target: " << target_str << "\n";
decl_stream << "#define TVM_EXPORTS\n";
decl_stream << "#include \"tvm/runtime/base.h\"\n";
decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
decl_stream << "#include \"tvm/ffi/c_api.h\"\n";
decl_stream << "#include <math.h>\n";
// snprintf for richer assert messages with actual values
decl_stream << "#include <stdio.h>\n";
decl_stream << "#include <stdbool.h>\n";
CodeGenCHost::InitGlobalContext();
tvm::codegen::CodeGenC::Init(output_ssa);
}
void CodeGenCHost::InitGlobalContext() {
decl_stream << "void* " << tvm::ffi::symbol::tvm_ffi_library_ctx
<< " = NULL;\n";
}
void CodeGenCHost::DefineModuleName() {
decl_stream << "void* " << module_name_ << " = NULL;\n";
}
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &func) {
return AddFunction(gvar, func, /*emit_fwd_func_decl=*/false);
}
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &func,
bool emit_fwd_func_decl) {
auto global_symbol =
func->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol);
if (global_symbol) {
function_names_.push_back(global_symbol.value());
}
emit_fwd_func_decl_ = emit_fwd_func_decl;
tvm::codegen::CodeGenC::AddFunction(gvar, func);
if (func->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc) && !has_main_func_) {
ICHECK(global_symbol.has_value())
<< "CodeGenCHost: The entry func must have the global_symbol "
"attribute, "
<< "but function " << gvar << " only has attributes " << func->attrs;
function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(func->ret_type, stream);
stream << " " << tvm::ffi::symbol::tvm_ffi_main
         << "(void* self, void* args, int num_args, void* result) {\n";
stream << " return " << static_cast<std::string>(global_symbol.value())
<< "(self, args, num_args, result);\n";
stream << "}\n";
has_main_func_ = true;
}
}
void CodeGenCHost::GenerateForwardFunctionDeclarations(
tvm::ffi::String global_symbol, const tvm::ffi::Array<tvm::Type> &arg_types,
const tvm::Type &ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
for (auto &func_already_defined : GetFunctionNames()) {
if (global_symbol == func_already_defined) {
return;
}
}
this->PrintFuncPrefix(fwd_decl_stream);
this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 0; i < arg_types.size(); ++i) {
if (i > 0) {
fwd_decl_stream << ", ";
}
tvm::codegen::CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
void CodeGenCHost::PrintFuncPrefix(std::ostream &os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n";
}
void CodeGenCHost::PrintType(tvm::DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK_EQ(lanes, 1) << "does not support vector types";
os << "void*";
return;
}
if (t.is_void()) {
os << "void";
return;
}
if (t == tvm::DataType::Bool()) {
os << "bool";
return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16:
os << "half";
break;
case 32:
os << "float";
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
}
if (!fail && lanes == 1)
return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes;
return;
}
}
if (t.is_bfloat16()) {
os << "__bf16";
return;
}
if (t.is_int() || t.is_uint()) {
if (t.is_uint()) {
os << 'u';
}
switch (t.bits()) {
case 8:
os << "int8_t";
break;
case 16:
os << "int16_t";
break;
case 32:
os << "int32_t";
break;
case 64:
os << "int64_t";
break;
case 1:
os << "int32_t";
break;
default:
fail = true;
break;
}
if (!fail && lanes == 1)
return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes;
return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
void CodeGenCHost::VisitExpr_(const tvm::tir::BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
os << "((";
PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << v;
}
os << "))";
}
void CodeGenCHost::PrintGetFuncFromBackend(
const std::string &func_name, const std::string &packed_func_name) {
this->PrintIndent();
this->stream << "if (" << packed_func_name << " == NULL) {\n";
int packed_func_if_scope = this->BeginScope();
this->PrintIndent();
this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \""
<< func_name << "\""
<< ", &" << packed_func_name << ") != 0) {\n";
int get_func_env_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(get_func_env_scope);
this->PrintIndent();
this->stream << "}\n";
this->EndScope(packed_func_if_scope);
this->PrintIndent();
this->stream << "}\n";
}
void CodeGenCHost::PrintCallPacked(const tvm::tir::CallNode *op) {
using namespace tvm::tir;
const StringImmNode *func_name = op->args[0].as<StringImmNode>();
ICHECK(func_name != nullptr)
<< "tvm_call_[c]packed_lowered expects first argument as function name";
int64_t begin = op->args[2].as<IntImmNode>()->value;
int64_t end = op->args[3].as<IntImmNode>()->value;
int64_t num_args = end - begin;
ICHECK_GE(num_args, 0);
std::string packed_func_name;
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
packed_func_name = GetPackedName(op);
this->PrintGetFuncFromBackend(func_name->value, packed_func_name);
} else {
// directly use the original symbol
ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered()));
packed_func_name =
tvm::ffi::symbol::tvm_ffi_symbol_prefix + func_name->value;
}
std::string args_stack = PrintExpr(op->args[1]);
this->PrintIndent();
std::string result = name_supply_->FreshName("result");
this->stream << "TVMFFIAny " << result << ";\n";
this->PrintIndent();
// must make sure type_index is set to none
this->stream << result << ".type_index = kTVMFFINone;\n";
this->PrintIndent();
this->stream << result << ".zero_padding = 0;\n";
this->PrintIndent();
this->stream << result << ".v_int64 = 0;\n";
this->PrintIndent();
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", ";
} else {
this->stream << "if (" << packed_func_name << "(NULL, ";
}
this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", "
<< "&" << result << ") != 0) {\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(func_call_scope);
this->PrintIndent();
this->stream << "}\n";
}
std::string CodeGenCHost::GetPackedName(const tvm::tir::CallNode *op) {
using namespace tvm::tir;
const StringImmNode *s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr)
<< "tvm_call_packed_lowered expects first argument as function name";
std::string func_name = s->value;
std::string packed_func_name = func_name + "_packed";
std::string unique_name;
auto it = declared_globals_.find(packed_func_name);
if (it != declared_globals_.end()) {
unique_name = it->second;
} else {
unique_name = name_supply_->FreshName(packed_func_name);
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
return unique_name;
}
void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op,
std::ostream &os) { // NOLINT(*)
using namespace tvm::tir;
if (op->op.same_as(builtin::tvm_stack_alloca())) {
std::string stack_name = name_supply_->FreshName("stack");
const std::string &type = op->args[0].as<StringImmNode>()->value;
const IntImmNode *num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr);
static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMFFIAny);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit;
} else if (type == "tvm_ffi_any") {
size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
this->PrintIndent();
this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
} else {
tvm::codegen::CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
using namespace tvm::tir;
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
{
// Prepare the base error message
const auto *msg_node = op->message.as<StringImmNode>();
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
const std::string &raw_msg = msg_node->value;
const std::string esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);
// If the assertion condition contains any equality checks anywhere
// in a composite boolean expression, append the actual LHS/RHS values
// Collect all EQ nodes within the condition (including inside And/Or/Not)
std::vector<const EQNode *> eq_nodes;
{
std::vector<PrimExpr> stk;
stk.push_back(op->condition);
while (!stk.empty()) {
PrimExpr cur = stk.back();
stk.pop_back();
if (const auto *eq = cur.as<EQNode>()) {
eq_nodes.push_back(eq);
continue;
}
if (const auto *an = cur.as<AndNode>()) {
stk.push_back(an->a);
stk.push_back(an->b);
continue;
}
if (const auto *on = cur.as<OrNode>()) {
stk.push_back(on->a);
stk.push_back(on->b);
continue;
}
if (const auto *nn = cur.as<NotNode>()) {
stk.push_back(nn->a);
continue;
}
}
}
if (!eq_nodes.empty()) {
// Build a single detailed message that includes all LHS/RHS pairs
PrintIndent();
stream << "char __tvm_assert_msg_buf[1024];\n";
PrintIndent();
stream << "int __tvm_assert_msg_len = snprintf(__tvm_assert_msg_buf, "
"sizeof(__tvm_assert_msg_buf), \"%s\", \""
<< esc_msg << "\");\n";
auto escape_for_printf_literal = [&](const std::string &s) {
std::string out;
out.reserve(s.size());
for (char c : s) {
if (c == '%') {
out += "%%";
} else if (c == '"') {
out += "\\\"";
} else if (c == '\\') {
out += "\\\\";
} else {
out.push_back(c);
}
}
return out;
};
for (const auto *eq : eq_nodes) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
std::string lhs_disp = escape_for_printf_literal(lhs);
std::string rhs_disp = escape_for_printf_literal(rhs);
PrintIndent();
stream << "__tvm_assert_msg_len += snprintf(__tvm_assert_msg_buf + "
"__tvm_assert_msg_len, "
"sizeof(__tvm_assert_msg_buf) - __tvm_assert_msg_len, \"; ("
<< lhs_disp << " == " << rhs_disp
<< ") got: %lld, expected: %lld\", (long long)(" << lhs
<< "), (long long)(" << rhs << "));\n";
}
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
// Fallback: just emit the base message
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg
<< "\");\n";
}
}
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
PrintIndent();
stream << "}\n";
}
this->PrintStmt(op->body);
}
void CodeGenCHost::VisitExpr_(const tvm::tir::MinNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
void CodeGenCHost::VisitExpr_(const tvm::tir::MaxNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
template <typename T>
inline void CodeGenCHost::PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os) { // NOLINT(*)
std::ostringstream temp_a;
VisitExpr(op->a, temp_a);
std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
std::ostringstream temp_b;
VisitExpr(op->b, temp_b);
std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());
os << "((" << a_id << ") " << compare << " (" << b_id << ") "
<< "? (" << a_id << ") : (" << b_id << "))";
}
} // namespace tl
} // namespace tvm
namespace tvm {
namespace tl {
using tvm::codegen::CodeGenSourceBase;
using tvm::codegen::CSourceModuleCreate;
using tvm::ffi::Array;
using tvm::ffi::Map;
using tvm::ffi::Module;
using tvm::ffi::String;
// Build function that mirrors TVM's host C codegen, registered under a
// TileLang-specific name.
::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod,
::tvm::Target target) {
bool output_ssa = false;
bool emit_asserts = true;
bool emit_fwd_func_decl = true;
std::unordered_set<std::string> devices;
if (mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>(
"device_contexts") != nullptr) {
::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String> device_contexts =
mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>(
"device_contexts")
.value();
for (auto const &context : device_contexts) {
devices.insert(context.second.data());
}
}
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
cg.SetConstantsByteAlignment(
target->GetAttr<::tvm::Integer>("constants-byte-alignment").value_or(16));
auto is_aot_executor_fn = [](::tvm::tir::PrimFunc const &func) -> bool {
return func->GetAttr<::tvm::Bool>("runner_function", ::tvm::Bool(false))
.value();
};
std::vector<std::pair<::tvm::GlobalVar, ::tvm::tir::PrimFunc>> funcs;
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>())
<< "CodegenCHost: Can only take PrimFunc";
auto prim_func = ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func);
funcs.push_back({gvar, prim_func});
}
auto sort_key = [&is_aot_executor_fn](const auto &kv) {
return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
};
std::sort(funcs.begin(), funcs.end(),
[&sort_key](const auto &kv_a, const auto &kv_b) {
return sort_key(kv_a) < sort_key(kv_b);
});
for (const auto &[gvar, prim_func] : funcs) {
cg.DeclareFunction(gvar, prim_func);
}
for (const auto &[gvar, prim_func] : funcs) {
cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}
std::string code = cg.Finish();
return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangCHost);
}
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c_host.h
* \brief Generate C host code (TileLang copy).
*/
#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "target/source/codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"
namespace tvm {
namespace tl {
// TileLang copy of TVM's CodeGenCHost, under the tl namespace.
// Inherits from tvm::codegen::CodeGenC.
class CodeGenCHost : public tvm::codegen::CodeGenC {
public:
CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str,
const std::unordered_set<std::string> &devices);
void InitGlobalContext();
void AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &f) override;
void AddFunction(const tvm::GlobalVar &gvar, const tvm::tir::PrimFunc &f,
bool emit_fwd_func_decl);
/*!
* \brief Add functions from the (unordered) range to the current module in a
* deterministic order. This helps with debugging.
*
 * \param functions The functions of the current module, in no particular
 * order.
*/
void AddFunctionsOrdered(
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();
using tvm::codegen::CodeGenC::PrintType;
void PrintType(tvm::DataType t, std::ostream &os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream &os) final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const tvm::tir::BroadcastNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const tvm::tir::CallNode *op,
std::ostream &os) override; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
void VisitExpr_(const tvm::tir::MinNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const tvm::tir::MaxNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitStmt_(const tvm::tir::AssertStmtNode *op) final; // NOLINT(*)
void GenerateForwardFunctionDeclarations(
tvm::ffi::String global_symbol,
const tvm::ffi::Array<tvm::Type> &arg_types,
const tvm::Type &ret_type) override;
tvm::ffi::Array<tvm::ffi::String> GetFunctionNames() {
return function_names_;
}
private:
std::string module_name_;
/* \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/* \brief names of the functions declared in this module */
tvm::ffi::Array<tvm::ffi::String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forward function declarations in the resulting C
 * code */
bool emit_fwd_func_decl_;
/*! \brief whether an entry (main) function was encountered */
bool has_main_func_ = false;
std::string GetPackedName(const tvm::tir::CallNode *op);
void PrintGetFuncFromBackend(const std::string &func_name,
const std::string &packed_func_name);
void PrintCallPacked(const tvm::tir::CallNode *op);
/*!
* \brief Print ternary conditional operator implementing binary `op`
* Forces the operands to be in SSA form.
* \param op binary operator being expressed
* \param compare string representation of comparison operator
* \param os stream reference to print into
*/
template <typename T>
inline void PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os); // NOLINT(*)
};
} // namespace tl
} // namespace tvm
#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
@@ -203,12 +203,12 @@ void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name,
   this->PrintIndent();
   std::string ret_val = name_supply_->FreshName("ret_val");
   std::string ret_type_code = name_supply_->FreshName("ret_type_code");
-  this->stream << "TVMValue " << ret_val << ";\n";
+  this->stream << "TVMFFIAny " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
   this->PrintIndent();
   this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
-               << "(TVMValue*) stack_value"
+               << "(TVMFFIAny*) stack_value"
                << ", "
                << "(int*) stack_tcode"
                << ", " << num_args << ", "
@@ -228,13 +228,13 @@ void CodeGenTileLangCPP::PrintFuncCallC(
   this->PrintIndent();
   std::string ret_val = name_supply_->FreshName("ret_val");
   std::string ret_type_code = name_supply_->FreshName("ret_type_code");
-  this->stream << "TVMValue " << ret_val << ";\n";
+  this->stream << "TVMFFIAny " << ret_val << ";\n";
   this->PrintIndent();
   this->stream << "int " << ret_type_code << ";\n";
   this->PrintIndent();
   this->stream << "if (" << packed_func_name << "( "
-               << "(TVMValue*) stack_value "
+               << "(TVMFFIAny*) stack_value "
                << ", "
                << "(int*) stack_tcode"
                << ", " << num_args << ", "
...
@@ -24,7 +24,11 @@ ExtractFuncInfo(const IRModule &mod) {
         continue;
       }
     }
-    info.arg_types.push_back(f->params[i].dtype());
+    DataType dtype = f->params[i].dtype();
+    // Device runtime cannot directly take bool arguments, map to int32.
+    if (dtype.is_bool())
+      dtype = DataType::Int(32);
+    info.arg_types.push_back(dtype);
   }
   if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
           tir::attr::kKernelLaunchParams)) {
...
@@ -35,7 +35,11 @@ ExtractFuncInfo(const IRModule &mod) {
         continue;
       }
     }
-    info.arg_types.push_back(f->params[i].dtype());
+    DataType dtype = f->params[i].dtype();
+    // Device runtime cannot directly take bool arguments, map to int32.
+    if (dtype.is_bool())
+      dtype = DataType::Int(32);
+    info.arg_types.push_back(dtype);
   }
   if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
           tir::attr::kKernelLaunchParams)) {
...
@@ -154,6 +154,10 @@ public:
     return def_handle_dtype_;
   }
 
+  bool BindNullable(const PrimExpr &arg, const PrimExpr &value,
+                    const std::string &arg_name, bool with_lets,
+                    const PrimExpr &nullable_guard);
+
 private:
   // Internal bind function
   bool Bind_(const PrimExpr &arg, const PrimExpr &value,
...
@@ -26,10 +26,13 @@ public:
     LowerHopperIntrin substituter(disable_shuffle_elect);
     fptr->body = substituter.VisitStmt(f->body);
     Map<Var, Array<PrimExpr>> init_desc_arg_map;
+    // Collect prologue/epilogue statements for host-side setup/teardown
+    Array<Stmt> prologue_stmts;
+    Array<Stmt> epilogue_stmts;
     for (const auto &[call, var] : substituter.desc_map_) {
       // Should allocate 128 bytes for TensorMap on stack
       Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
-                             {StringImm("arg_value"), 16});
+                             {StringImm("tvm_ffi_any"), 16});
       Array<PrimExpr> init_desc_args;
       if (call->op.same_as(create_tma_descriptor())) {
         init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
@@ -44,11 +47,66 @@ public:
       // add to function attribute
       Call init_desc =
           Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
-      fptr->body =
-          LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
+      // Accumulate TMA descriptor init into prologue
+      prologue_stmts.push_back(LetStmt(var, alloc_desc, Evaluate(init_desc)));
       init_desc_arg_map.Set(var, init_desc_args);
     }
     f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
+
+    // Additionally, if L2 persistent cache annotations were lowered earlier,
+    // materialize TVM FFI calls to set the stream access policy window.
+    if (f->attrs.defined() && f->attrs->dict.count("l2_persistent_map")) {
+      auto l2_map =
+          f->GetAttr<Map<String, Array<PrimExpr>>>("l2_persistent_map");
+      if (l2_map.defined()) {
+        // Build a lookup from buffer name to Buffer object
+        std::unordered_map<std::string, Buffer> name2buf;
+        for (const auto &kv : f->buffer_map) {
+          name2buf.emplace(kv.second->name, kv.second);
+        }
+        for (const auto &kv : l2_map.value()) {
+          const std::string buf_name = kv.first;
+          const Array<PrimExpr> &args = kv.second;
+          if (name2buf.count(buf_name) == 0) {
+            continue;
+          }
+          const Buffer &buf = name2buf.at(buf_name);
+          // Build base pointer expression (read access)
+          PrimExpr base_ptr = buf.access_ptr(1);
+          // Args packed: func_name, base_ptr, num_bytes, hit_ratio
+          Array<PrimExpr> packed_args;
+          packed_args.push_back(
+              StringImm(tvm_cuda_stream_set_access_policy_window));
+          packed_args.push_back(base_ptr);
+          // size_in_bytes (args[1]) then hit_ratio (args[0])
+          ICHECK_GE(args.size(), 2);
+          packed_args.push_back(args[1]);
+          packed_args.push_back(args[0]);
+          prologue_stmts.push_back(Evaluate(Call(
+              DataType::Int(32), builtin::tvm_call_packed(), packed_args)));
+        }
+        // Add a single epilogue call to reset the access policy window and
+        // restore L2 limit
+        Array<PrimExpr> reset_args;
+        reset_args.push_back(
+            StringImm(tvm_cuda_stream_reset_access_policy_window));
+        epilogue_stmts.push_back(Evaluate(
+            Call(DataType::Int(32), builtin::tvm_call_packed(), reset_args)));
+      }
+    }
+
+    // Stitch prologue statements before the original body
+    if (!prologue_stmts.empty()) {
+      // Chain the Let/Evaluate statements sequentially
+      Stmt seq = prologue_stmts.size() == 1 ? prologue_stmts[0]
+                                            : SeqStmt(prologue_stmts);
+      fptr->body = SeqStmt({seq, fptr->body});
+    }
+    if (!epilogue_stmts.empty()) {
+      Stmt seq_end = epilogue_stmts.size() == 1 ? epilogue_stmts[0]
+                                                : SeqStmt(epilogue_stmts);
+      fptr->body = SeqStmt({fptr->body, seq_end});
+    }
+
     return f;
   }
...
@@ -20,6 +20,7 @@
 /*!
  * \file make_packed_api.cc Lower PrimFunc to use the packed function API.
  */
+#include <tvm/ffi/extra/module.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/device_api.h>
@@ -32,6 +33,7 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <unordered_set>
 #include <utility>
 #include <vector>
@@ -43,13 +45,11 @@ namespace tvm {
 namespace tl {
 using namespace tir;
 using namespace ffi;
-static constexpr const char *kDeviceContextVar = "device_api_context";
 namespace {
 class ReturnRewriter : public StmtMutator {
 public:
-  explicit ReturnRewriter(Var ret_var, Var ret_tcode)
-      : ret_var_(std::move(ret_var)), ret_tcode_(std::move(ret_tcode)) {}
+  explicit ReturnRewriter(Var ret_var) : ret_var_(ret_var) {}
   Stmt VisitStmt_(const ForNode *node) override {
     if (node->kind == ForKind::kParallel)
@@ -79,8 +79,6 @@ private:
   struct ConvertedInfo {
     int type_index{-1};
     PrimExpr expr;
-    Buffer dummy_val_buffer;
-    Buffer dummy_tcode_buffer;
   };
   ConvertedInfo ConvertForFFI(const PrimExpr &val) {
@@ -88,7 +86,11 @@ private:
     // convert val's data type to FFI data type, return type code
     DataType dtype = val.dtype();
-    if (dtype.is_int() || dtype.is_uint()) {
+    if (dtype.is_bool()) {
+      info.type_index = ffi::TypeIndex::kTVMFFIBool;
+      info.expr = Cast(DataType::Int(64), val);
+    } else if (dtype.is_int() || dtype.is_uint()) {
       info.type_index = ffi::TypeIndex::kTVMFFIInt;
       info.expr = Cast(DataType::Int(64), val);
     } else if (dtype.is_float()) {
@@ -101,56 +103,39 @@ private:
       LOG(FATAL) << "data type " << dtype << " not supported yet";
     }
-    // If multiple return locations have the same data type, use the
-    // same dummy buffer declaration.
-    auto it = dummy_val_buffer_map_.find(info.type_index);
-    if (it != dummy_val_buffer_map_.end()) {
-      info.dummy_val_buffer = it->second;
-    } else {
-      info.dummy_val_buffer =
-          Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0),
-                 ret_var_->name_hint, 0, 0, kDefault);
-      dummy_val_buffer_map_[info.type_index] = info.dummy_val_buffer;
-    }
-    // The type_index is always a 32-bit int, so we don't need to have a
-    // separate map.
-    if (!dummy_tcode_buffer_.defined()) {
-      dummy_tcode_buffer_ =
-          Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0),
-                 ret_tcode_->name_hint, 0, 0, kDefault);
-    }
-    info.dummy_tcode_buffer = dummy_tcode_buffer_;
     return info;
   }
-  Stmt WriteToOut(const PrimExpr &val) {
+  Stmt WriteToOut(PrimExpr val) {
     auto info = ConvertForFFI(val);
-    Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
-    Stmt store_tcode =
-        BufferStore(info.dummy_tcode_buffer, info.type_index, {0});
+    Stmt store_tindex = tir::Evaluate(
+        tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
+                  {ret_var_, IntImm(DataType::Int(32), 0),
+                   IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex),
+                   IntImm(DataType::Int(32), info.type_index)}));
+    Stmt store_zero_padding = tir::Evaluate(tir::Call(
+        DataType::Int(32), tir::builtin::tvm_struct_set(),
+        {ret_var_, IntImm(DataType::Int(32), 0),
+         IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding),
+         IntImm(DataType::Int(32), 0)}));
+    Stmt store_val = tir::Evaluate(tir::Call(
+        DataType::Int(32), tir::builtin::tvm_struct_set(),
+        {ret_var_, IntImm(DataType::Int(32), 0),
+         IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue),
+         info.expr}));
     Stmt ret_zero = Evaluate(tvm::ret(0));
-    return SeqStmt({store_val, store_tcode, ret_zero});
+    return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero});
   }
   Var ret_var_;
-  Var ret_tcode_;
   int in_parallel_{0};
-  std::unordered_map<int, Buffer> dummy_val_buffer_map_;
-  Buffer dummy_tcode_buffer_;
 };
-Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
-  ReturnRewriter rewriter(std::move(ret_var), std::move(ret_tcode));
-  return rewriter(std::move(body));
-}
 class SubroutineCallRewriter : public StmtExprMutator {
 public:
-  static Optional<Stmt> Apply(const Map<GlobalVar, String> &packed_func_methods,
-                              Stmt stmt) {
+  static ffi::Optional<Stmt>
+  Apply(const ffi::Map<GlobalVar, ffi::String> &packed_func_methods,
+        Stmt stmt) {
     SubroutineCallRewriter rewriter(packed_func_methods);
     stmt = rewriter.VisitStmt(stmt);
     if (rewriter.made_change_) {
@@ -162,16 +147,16 @@ public:
 private:
   explicit SubroutineCallRewriter(
-      const Map<GlobalVar, String> &packed_func_methods)
+      const ffi::Map<GlobalVar, ffi::String> &packed_func_methods)
       : packed_func_methods(packed_func_methods) {}
   PrimExpr VisitExpr_(const CallNode *op) override {
     auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
     if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
-      auto gvar = tvm::ffi::GetRef<GlobalVar>(gvar_ptr);
+      auto gvar = ffi::GetRef<GlobalVar>(gvar_ptr);
       if (auto symbol = packed_func_methods.Get(gvar)) {
-        Array<PrimExpr> cpacked_args;
+        ffi::Array<PrimExpr> cpacked_args;
         cpacked_args.push_back(tir::StringImm(symbol.value()));
         for (auto arg : node->args) {
           cpacked_args.push_back(arg);
@@ -187,19 +172,18 @@ private:
     return node;
   }
-  const Map<GlobalVar, String> &packed_func_methods;
+  const ffi::Map<GlobalVar, ffi::String> &packed_func_methods;
   bool made_change_{false};
 };
 } // namespace
-inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, const std::string &msg) {
-  return AssertStmt(std::move(lhs) == std::move(rhs), tvm::tir::StringImm(msg),
-                    Evaluate(0));
+inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
+  return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
 }
-inline Stmt MakeAssertNotNull(PrimExpr ptr, const std::string &msg) {
-  Call isnull(DataType::Bool(), builtin::isnullptr(), {std::move(ptr)});
+inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
+  Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
   return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
 }
@@ -254,21 +238,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   }
   auto *func_ptr = func.CopyOnWrite();
+  // set the global symbol to the packed function name
   const Stmt nop = Evaluate(0);
   int num_args = static_cast<int>(func_ptr->params.size());
   // Data field definitions
   // The packed fields
+  Var v_self_handle("self_handle", DataType::Handle());
   Var v_packed_args("args", DataType::Handle());
-  Buffer buf_packed_arg_type_ids =
-      decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())},
-                  DataType::Int(32), "arg_type_ids");
   Var v_num_packed_args("num_args", DataType::Int(32));
-  Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void())));
-  Var v_out_ret_tcode("out_ret_tcode",
-                      PointerType(PrimType(DataType::Int(32))));
-  Var v_resource_handle("resource_handle", DataType::Handle());
-  // The arguments of the function.
+  Var v_result("result", PointerType(PrimType(DataType::Void())));
   // The device context
   Var device_id("dev_id");
@@ -278,37 +257,24 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
   std::unordered_map<const VarNode *, PrimExpr> vmap;
   ArgBinder binder(&vmap);
-  std::vector<Stmt> shape_checks;
-  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
-  bool disable_dynamic_tail_split =
-      ctxt->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
   // ---------------------------
   // local function definitions
   // load i-th argument as type t
-  auto f_arg_value = [&](DataType t, int i) {
-    Array<PrimExpr> call_args{
+  auto f_load_arg_value = [&](DataType arg_type, int i) {
+    ffi::Array<PrimExpr> call_args{
         v_packed_args, IntImm(DataType::Int(32), i),
-        IntImm(DataType::Int(32), builtin::kTVMValueContent)};
+        IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
     // load 64 bit version
-    DataType api_type = APIType(t);
+    DataType api_type = APIType(arg_type);
     PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
     // cast to the target version.
-    if (api_type != t) {
-      res = Cast(t, res);
+    if (api_type != arg_type) {
+      res = Cast(arg_type, res);
     }
     return res;
   };
-  // Find the device API context argument based on name
-  for (const auto &param : func_ptr->params) {
-    if (param->name_hint == kDeviceContextVar) {
-      num_args--;
-      v_resource_handle = param;
-      break;
-    }
-  }
   // Assert correct type codes for each argument. This must be done
   // *before* any initialization steps produced by
   // `binder.BindDLTensor()`. The validity of those initialization
@@ -321,12 +287,10 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     return error_message.str();
   }()));
-  seq_init.push_back(MakeAssertNotNull(
-      v_packed_args, name_hint + ": TVMValue* arg pointer was NULL"));
-  seq_init.push_back(MakeAssertNotNull(
-      buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL"));
-  seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));
+  if (num_args > 0) {
+    seq_init.push_back(
+        MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL"));
+  }
   // Need to delay binding of the buffers, in case some arguments also
   // appear in the buffer.
@@ -335,26 +299,17 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
     Var param = func_ptr->params[i];
+    PrimExpr arg_value;
-    // Ignore the device context argument, as it will still be passed
-    // as a native argument.
-    if (param->name_hint == kDeviceContextVar) {
-      continue;
-    }
-    var_def.emplace_back(f_arg_value(param.dtype(), i), param);
-    if (func_ptr->buffer_map.count(param)) {
-      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
-    }
-    // type code checks
-    Var type_index(param->name_hint + ".code", DataType::Int(32));
-    seq_init.emplace_back(LetStmt(
+    // type index checks
+    Var type_index(param->name_hint + ".type_index", DataType::Int(32));
+    seq_init.push_back(LetStmt(
         type_index,
-        BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}),
+        tir::Call(DataType::Int(32), builtin::tvm_struct_get(),
+                  {v_packed_args, IntImm(DataType::Int(32), i),
+                   IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}),
        nop));
-    DataType t = param.dtype();
-    if (t.is_handle()) {
+    DataType dtype = param.dtype();
+    if (dtype.is_handle()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_init.emplace_back(
@@ -363,23 +318,63 @@ PrimFunc MakePackedAPI(PrimFunc func) {
              type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                  type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
              tvm::tir::StringImm(msg.str()), nop));
-    } else if (t.is_int() || t.is_uint()) {
+      // if type_index is Tensor, we need to add the offset of the DLTensor
+      // header which always equals 16 bytes, this ensures that T.handle always
+      // shows up as a DLTensor*
+      const int64_t object_cell_offset = sizeof(TVMFFIObject);
+      static_assert(object_cell_offset == 24);
+      arg_value = f_load_arg_value(param.dtype(), i);
+      PrimExpr handle_from_tensor =
+          Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
+               {arg_value, IntImm(DataType::Int(32), object_cell_offset)});
+      arg_value = Select(type_index == ffi::TypeIndex::kTVMFFITensor,
+                         handle_from_tensor, arg_value);
+    } else if (dtype.is_bool()) {
+      std::ostringstream msg;
+      msg << name_hint << ": Expect arg[" << i << "] to be boolean";
+      seq_init.emplace_back(
+          AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool ||
+                         type_index == ffi::TypeIndex::kTVMFFIInt,
+                     tvm::tir::StringImm(msg.str()), nop));
+      arg_value =
+          Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i));
+    } else if (dtype.is_int() || dtype.is_uint()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
-      seq_init.emplace_back(AssertStmt(type_index == kDLInt,
-                                       tvm::tir::StringImm(msg.str()), nop));
+      seq_init.emplace_back(
+          AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt ||
+                         type_index == ffi::TypeIndex::kTVMFFIBool,
+                     tvm::tir::StringImm(msg.str()), nop));
+      arg_value = f_load_arg_value(param.dtype(), i);
     } else {
-      ICHECK(t.is_float());
+      ICHECK(dtype.is_float());
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
-      seq_init.emplace_back(AssertStmt(type_index == kDLFloat,
-                                       tvm::tir::StringImm(msg.str()), nop));
+      seq_init.emplace_back(
+          AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
+                         type_index == ffi::TypeIndex::kTVMFFIInt ||
+                         type_index == ffi::TypeIndex::kTVMFFIBool,
+                     tvm::tir::StringImm(msg.str()), nop));
+      // use select so we can also handle int conversion to bool
+      arg_value = tir::Select(
+          type_index == ffi::TypeIndex::kTVMFFIFloat,
+          /* true_value = */ f_load_arg_value(param.dtype(), i),
+          /* false_value = */
+          Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i)));
     }
+    var_def.emplace_back(arg_value, param);
+    if (func_ptr->buffer_map.count(param)) {
+      // buffer binding now depends on type index
+      // if the index is Tensor handle, we need to offset to get the DLTensor*
+      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
+    }
   }
-  Array<Var> args{v_packed_args, buf_packed_arg_type_ids->data,
-                  v_num_packed_args, v_out_ret_value,
-                  v_out_ret_tcode, v_resource_handle};
+  // signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny*
+  // v_result)
+  ffi::Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args,
+                       v_result};
   // Arg definitions are defined before buffer binding to avoid the use before
   // def errors.
@@ -392,83 +387,57 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
   }
-  for (const auto &kv : buffer_def) {
-    binder.BindDLTensor(kv.second, device_type, device_id, kv.first,
-                        name_hint + "." + kv.first->name_hint);
-    arg_buffer_declarations.push_back(DeclBuffer(kv.second, nop));
+  for (const auto &[var, buffer] : buffer_def) {
+    binder.BindDLTensor(buffer, device_type, device_id, var,
+                        name_hint + "." + var->name_hint);
+    arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
   }
+  // reset global symbol to attach prefix
-  func =
-      WithAttrs(std::move(func),
-                {{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)},
-                 {tvm::attr::kTarget, target_host}});
-  Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
+  func = WithAttrs(
+      std::move(func),
+      {{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
+       {tvm::attr::kTarget, target_host},
+       {tvm::attr::kGlobalSymbol,
+        ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});
+  Stmt body = ReturnRewriter(v_result)(func_ptr->body);
   body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
   // Set device context
   if (vmap.count(device_id.get())) {
-    auto node = String("default");
+    ffi::Any node = ffi::String("default");
     seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
     seq_check.push_back(
        AttrStmt(node, tir::attr::device_type, device_type, nop));
     if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
-          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
+          Evaluate(Call(DataType::Int(32), tir::builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device),
                         device_type, device_id}));
      body = SeqStmt({set_device, body});
     }
   }
-  // (zhengju) For dynamic constraint, we need to check the buffer shape and
-  // dtype to make sure the buffer can be vectorized.
-  for (const auto &kv : buffer_def) {
-    if (disable_dynamic_tail_split) {
-      Optional<Integer> opt_dynamic_alignment =
-          ctxt->GetConfig(kDynamicAlignment, Optional<Integer>());
-      int dynamic_alignment = opt_dynamic_alignment.value_or(Integer(8))->value;
-      // The vectorize dimension will be the last dimension of the buffer
-      auto vectorize_dim = kv.second->shape[kv.second->shape.size() - 1];
-      auto shape_vectorize_expr = [&]() -> PrimExpr {
-        PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
-        result = result * vectorize_dim;
-        result = FloorMod(result, IntImm(result->dtype, dynamic_alignment));
-        return result;
-      }();
-      shape_checks.emplace_back(AssertStmt(
-          shape_vectorize_expr == 0,
-          tvm::tir::StringImm(
-              kv.second->name +
-              ": Vectorize dimension in buffer must be divisible by " +
-              std::to_string(dynamic_alignment)),
-          nop));
-    }
-  }
   // Return error code of zero on success
   body = SeqStmt({body, Evaluate(ret(Integer(0)))});
-  if (!disable_dynamic_tail_split) {
-    body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
-                      arg_buffer_declarations},
-                     body);
-  } else {
-    body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
-                      arg_buffer_declarations, shape_checks},
-                     body);
-  }
+  body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
+                    arg_buffer_declarations},
+                   body);
   func_ptr->body = body;
   func_ptr->params = args;
-  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
+  ffi::Array<Var> undefined = UndefinedVars(body, func_ptr->params);
   ICHECK_EQ(undefined.size(), 0)
<< "In PrimFunc " << name_hint << " variables " << undefined << "In PrimFunc " << name_hint << " variables " << undefined
<< " are used, but are not passed in as API arguments"; << " are used, but are not passed in as API arguments";
func_ptr->buffer_map = Map<Var, Buffer>(); func_ptr->buffer_map = ffi::Map<Var, Buffer>();
func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function. func_ptr->ret_type = PrimType(DataType::Int(32));
// return the function.
return func; return func;
} }
......
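The `shape_checks` block removed above amounted to a divisibility test on the innermost (vectorize) dimension of each buffer, built as `FloorMod(last_dim, dynamic_alignment) == 0` with a default alignment of 8. A minimal Python sketch of that condition, for illustration only (the function name is ours, not tilelang's API):

```python
def vectorize_dim_ok(shape, dynamic_alignment=8):
    """Check that the innermost (vectorize) dimension of a buffer shape
    is divisible by the configured alignment, as the removed shape_checks
    asserted via FloorMod(last_dim, dynamic_alignment) == 0."""
    last_dim = shape[-1]
    return last_dim % dynamic_alignment == 0

# A [16, 64] buffer passes with alignment 8; [16, 30] does not.
```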
@@ -240,37 +240,42 @@ public:
   simplifier.MarkBufferMapShapes(func);
   func.CopyOnWrite()->body = simplifier(func->body);
-  // Begin to remove useless var and buffer
-  // First get used buffers
-  simplifier.used_buffers_ = CollectUsedBuffers(func);
+  // Optionally remove unused buffer parameters
+  if (simplify_arguments) {
+    // First get used buffers
+    simplifier.used_buffers_ = CollectUsedBuffers(func);
 
-  bool param_updated = false;
-  Array<Var> new_params;
-  Map<Var, Buffer> new_buffer_map;
-  // Check whether each buffer is used
-  for (const auto &var : func->params) {
-    if (func->buffer_map.find(var) != func->buffer_map.end()) {
-      if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
-          simplifier.used_buffers_.end()) {
-        new_params.push_back(var);
-        new_buffer_map.Set(var, func->buffer_map[var]);
-      } else if (simplifier.used_in_buffer_def_.find(
-                     func->buffer_map[var]->data.get()) !=
-                 simplifier.used_in_buffer_def_.end()) {
-        new_params.push_back(var);
-        new_buffer_map.Set(var, func->buffer_map[var]);
-      } else {
-        param_updated = true;
-      }
-    }
-  }
+    bool param_updated = false;
+    Array<Var> new_params;
+    Map<Var, Buffer> new_buffer_map;
+    // Check whether each buffer is used
+    for (const auto &var : func->params) {
+      if (func->buffer_map.find(var) != func->buffer_map.end()) {
+        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
+            simplifier.used_buffers_.end()) {
+          new_params.push_back(var);
+          new_buffer_map.Set(var, func->buffer_map[var]);
+        } else if (simplifier.used_in_buffer_def_.find(
+                       func->buffer_map[var]->data.get()) !=
+                   simplifier.used_in_buffer_def_.end()) {
+          new_params.push_back(var);
+          new_buffer_map.Set(var, func->buffer_map[var]);
+        } else {
+          param_updated = true;
+        }
+      } else {
+        // Non-buffer parameters (e.g., scalars) are always retained
+        new_params.push_back(var);
+      }
+    }
-  if (param_updated) {
-    return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
-                    new_buffer_map, func->attrs, func->span);
-  } else {
-    return func;
-  }
+    if (param_updated) {
+      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
+                      new_buffer_map, func->attrs, func->span);
+    }
+  }
+  // Either no change to params or argument simplification disabled
+  return func;
 }
 
 private:
......
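The pruning rule above keeps a buffer parameter only if its buffer is still referenced after simplification (or its backing data var appears in another buffer's definition), and now always retains non-buffer scalar parameters. A standalone Python sketch of the same filtering rule, with plain lists, dicts, and sets standing in for TVM's `Array`, `Map`, and pointer sets:

```python
def prune_params(params, buffer_map, used_buffers, used_in_buffer_def):
    """Keep a parameter if it is a scalar (not in buffer_map), if its
    buffer is used, or if its backing data var feeds a buffer definition."""
    new_params, new_buffer_map = [], {}
    for var in params:
        if var not in buffer_map:
            # Non-buffer parameters (e.g., scalars) are always retained
            new_params.append(var)
            continue
        buf_name, data_var = buffer_map[var]
        if buf_name in used_buffers or data_var in used_in_buffer_def:
            new_params.append(var)
            new_buffer_map[var] = buffer_map[var]
    return new_params, new_buffer_map
```

For example, with params `["A", "B", "n"]` where only buffer `A` is used, `B` is dropped while the scalar `n` survives.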
@@ -13,7 +13,7 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
         shared_buf = T.alloc_shared([M, N], dtype)
         T.print(shared_buf)
 
-    jit_kernel = tilelang.compile(program, target="cuda")
+    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
     profiler = jit_kernel.get_profiler()
     profiler.run_once()
......
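Since this PR makes `tvm_ffi` the default execution backend, passing it explicitly as in the test above becomes optional. A sketch of the kind of default-then-validate resolution being standardized; the helper name and backend list here are illustrative assumptions, not tilelang's actual internals:

```python
SUPPORTED_BACKENDS = ("tvm_ffi", "cython", "ctypes")  # illustrative set


def resolve_execution_backend(requested=None):
    """Resolve a user-requested execution backend, defaulting to tvm_ffi."""
    backend = requested if requested is not None else "tvm_ffi"
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(f"unknown execution backend: {backend!r}")
    return backend
```

With this shape of logic, `tilelang.compile(program, target="cuda")` and an explicit `execution_backend="tvm_ffi"` behave identically.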