Unverified Commit 74da3696 authored by Lei Wang, committed by GitHub

[FFI] Use tvm ffi as the default execution backend (#1259)

* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity (see the call-sequence sketch after this list).
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.
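
For reference, a minimal sketch of the host-side call sequence the updated codegen emits around `TVMFFIAny` (mirroring `CodeGenCHost::PrintCallPacked` in the diff below). The handle and argument names here are placeholders, not identifiers introduced by this change:

```cpp
#include <tvm/ffi/c_api.h>

// Illustrative only: `func_handle` stands in for the cached packed-function
// pointer and `packed_args` for the on-stack TVMFFIAny argument array.
static int call_packed_sketch(void *func_handle, TVMFFIAny *packed_args,
                              int num_args) {
  TVMFFIAny result;
  // The return slot must start out as "None" with a fully zeroed payload.
  result.type_index = kTVMFFINone;
  result.zero_padding = 0;
  result.v_int64 = 0;
  if (TVMFFIFunctionCall(func_handle, packed_args, num_args, &result) != 0) {
    return -1;  // propagate the error, like the generated `return -1;`
  }
  return 0;
}
```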

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.

* [Refactor] Clean up imports and improve code formatting

* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.

* Update execution backend options and improve resolution logic

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target.
- Updated documentation to reflect changes in execution backend options and their defaults.

* lint fix

* fix

* Enhance argument handling in CUDA and HIP runtime modules

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers (see the sketch after this list).
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.
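
As a rough illustration (not emitted verbatim by this change), the refactored `BindDLTensor` lowers to host checks of the following shape, where every field load and assert short-circuits on a cached null flag; `A_handle` and the expected rank of 2 are placeholder names/values:

```cpp
#include <dlpack/dlpack.h>

// Sketch of the NULL-guarded binding pattern for an optional DLTensor* argument.
int bind_optional_tensor_sketch(void *A_handle) {
  bool A_is_null = (A_handle == nullptr);
  // Field loads are wrapped in if_then_else so a NULL handle is never dereferenced.
  int32_t A_ndim = A_is_null ? 0 : static_cast<DLTensor *>(A_handle)->ndim;
  // Consistency asserts take the form Or(is_null, cond): they pass trivially
  // when the optional tensor was not provided.
  if (!(A_is_null || A_ndim == 2)) {
    return -1;  // "A.ndim is expected to equal 2, but got mismatched ndim"
  }
  return 0;
}
```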

* lint fix

* lint fix

* lint fix

* lint fix

* minor fix

* fix

* recover check

* Refactor argument binding and validation in `arg_binder.cc`

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.

* lint fix

* stride fix

* minor fix

* fix

* lint fix

* lint fix

* Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage (a standalone driver-API sketch follows this list).
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.
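
The new FFI helpers wrap the standard CUDA driver calls for persisting L2 windows. A standalone sketch of the set/reset sequence they perform (error handling elided; function names here are illustrative, not the registered FFI symbols, and the real helpers additionally save and restore the previous `CU_LIMIT_PERSISTING_L2_CACHE_SIZE`):

```cpp
#include <cuda.h>
#include <cstddef>
#include <cstring>

void set_l2_persisting_window(CUstream stream, void *base_ptr,
                              size_t num_bytes, float hit_ratio = 0.8f) {
  // Grow the persisting L2 carve-out, clamped to the device maximum.
  CUdevice dev;
  cuCtxGetDevice(&dev);
  int max_persisting = 0;
  cuDeviceGetAttribute(&max_persisting,
                       CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE, dev);
  size_t limit = num_bytes;
  if (max_persisting > 0 && limit > static_cast<size_t>(max_persisting))
    limit = static_cast<size_t>(max_persisting);
  cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, limit);

  // Mark [base_ptr, base_ptr + limit) as persisting for this stream.
  CUstreamAttrValue attr;
  std::memset(&attr, 0, sizeof(attr));
  attr.accessPolicyWindow.base_ptr = base_ptr;
  attr.accessPolicyWindow.num_bytes = limit;
  attr.accessPolicyWindow.hitRatio = hit_ratio;
  attr.accessPolicyWindow.hitProp = CU_ACCESS_PROPERTY_PERSISTING;
  attr.accessPolicyWindow.missProp = CU_ACCESS_PROPERTY_STREAMING;
  cuStreamSetAttribute(stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, &attr);
}

void reset_l2_persisting_window(CUstream stream) {
  // A zero-sized window disables the policy; then flush persisting lines.
  CUstreamAttrValue attr;
  std::memset(&attr, 0, sizeof(attr));
  attr.accessPolicyWindow.num_bytes = 0;
  cuStreamSetAttribute(stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, &attr);
  cuCtxResetPersistingL2Cache();
}
```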

* check with symbolic

* support null ptr

* Update CMakeLists and lower.py for code generation and subproject status

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility.
- Marked the TVM subproject as dirty to indicate local modifications.

* lint fix

* Update comments for clarity in quickstart.py
parent 921b96a3
Subproject commit 093b2cdb2187140b197336496d65d61ace89e8ff
Subproject commit f4105f89a646622acc9818584d1d91e2ca3f533d
......@@ -138,6 +138,7 @@ file(GLOB TILE_LANG_SRCS
src/transform/*.cc
src/op/*.cc
src/target/utils.cc
src/target/codegen_c_host.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# intrin_rule doesn't have system dependency
......
......@@ -166,7 +166,6 @@ def main():
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......
......@@ -468,7 +468,6 @@ def run_test(
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw,
block_DK, block_DV, threads, num_stages)
print(kernel.get_kernel_source())
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g:
......
......@@ -117,6 +117,7 @@ def test_example_chunk_o_bwd_compilation():
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
W) # noqa: F841
if use_g:
......
......@@ -55,10 +55,9 @@ block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
# Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
# Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
......
......@@ -104,6 +104,7 @@ tilelang = "tilelang"
# TVM
"tilelang/3rdparty/tvm/src" = "3rdparty/tvm/src"
"tilelang/3rdparty/tvm/python" = "3rdparty/tvm/python"
"tilelang/3rdparty/tvm/include" = "3rdparty/tvm/include"
"tilelang/3rdparty/tvm/version.py" = "3rdparty/tvm/version.py"
# CUTLASS
"tilelang/3rdparty/cutlass/include" = "3rdparty/cutlass/include"
......
......@@ -13,6 +13,12 @@
namespace tvm {
namespace tl {
#if 1
// Thread-local storage for restoring the L2 persisting cache limit
static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
#endif
#if (CUDA_MAJOR_VERSION >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss;
......@@ -91,19 +97,21 @@ struct TensorMapArgs {
// set device api
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle,
T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n'
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
// Register using the canonical names defined in runtime.h
refl::GlobalDef().def_packed(
tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< '\n'
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
}
struct TensorMapIm2ColArgs {
......@@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
......@@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() {
#endif // (CUDA_MAJOR_VERSION >= 12)
//
// CUDA L2 Persisting Cache Access Policy Window helpers.
// Exposed as TVM FFI packed functions similar to TMA initialization.
//
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
// Set stream access policy window and adjust persisting L2 cache size
// Args:
// [0]: void* base_ptr (required)
// [1]: int64 num_bytes (required)
// [2]: float hit_ratio (optional, default 0.8)
// [3]: void* stream (optional, default 0 => default stream)
// [4]: int64 l2_limit_bytes (optional, default = num_bytes)
refl::GlobalDef().def_packed(
tl::tvm_cuda_stream_set_access_policy_window,
[](PackedArgs args, Any *ret) {
ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes";
void *base_ptr = args[0].cast<void *>();
size_t num_bytes = static_cast<size_t>(args[1].cast<int64_t>());
float hit_ratio = 0.8f;
if (args.size() >= 3) {
// Accept double/float
hit_ratio = static_cast<float>(args[2].cast<double>());
}
CUstream stream = nullptr;
if (args.size() >= 4) {
stream = reinterpret_cast<CUstream>(args[3].cast<void *>());
}
size_t l2_limit_bytes = num_bytes;
if (args.size() >= 5) {
l2_limit_bytes = static_cast<size_t>(args[4].cast<int64_t>());
}
// Clamp requested limit to device capability
CUdevice device;
CUresult result = cuCtxGetDevice(&device);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to get current CUDA device: " << result;
}
int max_persisting = 0;
result = cuDeviceGetAttribute(
&max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE,
device);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: "
<< result;
}
if (max_persisting > 0 &&
l2_limit_bytes > static_cast<size_t>(max_persisting)) {
l2_limit_bytes = static_cast<size_t>(max_persisting);
}
// Save current limit to restore later
size_t init_persisting_l2_cache_size = 0;
result = cuCtxGetLimit(&init_persisting_l2_cache_size,
CU_LIMIT_PERSISTING_L2_CACHE_SIZE);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to get current persisting L2 cache size limit: "
<< result;
}
__tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size;
__tl_prev_persisting_l2_cache_saved = true;
// Set new limit
result =
cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to set persisting L2 cache size limit: "
<< result;
}
// Apply access policy window to stream
CUstreamAttrValue stream_attribute;
memset(&stream_attribute, 0, sizeof(stream_attribute));
stream_attribute.accessPolicyWindow.base_ptr = base_ptr;
stream_attribute.accessPolicyWindow.num_bytes = l2_limit_bytes;
stream_attribute.accessPolicyWindow.hitRatio = hit_ratio;
stream_attribute.accessPolicyWindow.hitProp =
CU_ACCESS_PROPERTY_PERSISTING;
stream_attribute.accessPolicyWindow.missProp =
CU_ACCESS_PROPERTY_STREAMING;
result = cuStreamSetAttribute(stream,
CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
&stream_attribute);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to set stream access policy window: " << result;
}
*ret = static_cast<int>(result);
});
// Reset stream access policy window and restore the previous L2 cache size
// Args:
// [0]: void* stream (optional, default 0)
refl::GlobalDef().def_packed(
tl::tvm_cuda_stream_reset_access_policy_window,
[](PackedArgs args, Any *ret) {
CUstream stream = nullptr;
if (args.size() >= 1) {
stream = reinterpret_cast<CUstream>(args[0].cast<void *>());
}
CUstreamAttrValue stream_attribute;
memset(&stream_attribute, 0, sizeof(stream_attribute));
// num_bytes = 0 disables the access policy window on the stream
stream_attribute.accessPolicyWindow.num_bytes = 0;
CUresult result = cuStreamSetAttribute(
stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
&stream_attribute);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to reset stream access policy window: "
<< result;
}
result = cuCtxResetPersistingL2Cache();
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result;
}
if (__tl_prev_persisting_l2_cache_saved) {
result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE,
__tl_prev_persisting_l2_cache_size);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to restore persisting L2 cache size limit: "
<< result;
}
__tl_prev_persisting_l2_cache_saved = false;
}
*ret = static_cast<int>(result);
});
}
} // namespace tl
} // namespace tvm
......@@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled =
constexpr const char *tvm_tensormap_create_im2col =
"__tvm_tensormap_create_im2col";
#endif // (CUDA_MAJOR_VERSION >= 12)
// CUDA stream access policy window helpers
constexpr const char *tvm_cuda_stream_set_access_policy_window =
"__tvm_cuda_stream_set_access_policy_window";
constexpr const char *tvm_cuda_stream_reset_access_policy_window =
"__tvm_cuda_stream_reset_access_policy_window";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
#endif // TVM_TL_RUNTIME_RUNTIME_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c_host.cc
*/
#include "codegen_c_host.h"
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
// For escaping strings embedded into generated C sources
#include "support/str_escape.h"
namespace tvm {
namespace tl {
CodeGenCHost::CodeGenCHost() {
module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx);
}
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts,
bool emit_fwd_func_decl, std::string target_str,
const std::unordered_set<std::string> &devices) {
emit_asserts_ = emit_asserts;
emit_fwd_func_decl_ = emit_fwd_func_decl;
declared_globals_.clear();
decl_stream << "// tilelang target: " << target_str << "\n";
decl_stream << "#define TVM_EXPORTS\n";
decl_stream << "#include \"tvm/runtime/base.h\"\n";
decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
decl_stream << "#include \"tvm/ffi/c_api.h\"\n";
decl_stream << "#include <math.h>\n";
// snprintf for richer assert messages with actual values
decl_stream << "#include <stdio.h>\n";
decl_stream << "#include <stdbool.h>\n";
CodeGenCHost::InitGlobalContext();
tvm::codegen::CodeGenC::Init(output_ssa);
}
void CodeGenCHost::InitGlobalContext() {
decl_stream << "void* " << tvm::ffi::symbol::tvm_ffi_library_ctx
<< " = NULL;\n";
}
void CodeGenCHost::DefineModuleName() {
decl_stream << "void* " << module_name_ << " = NULL;\n";
}
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &func) {
return AddFunction(gvar, func, /*emit_fwd_func_decl=*/false);
}
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &func,
bool emit_fwd_func_decl) {
auto global_symbol =
func->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol);
if (global_symbol) {
function_names_.push_back(global_symbol.value());
}
emit_fwd_func_decl_ = emit_fwd_func_decl;
tvm::codegen::CodeGenC::AddFunction(gvar, func);
if (func->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc) && !has_main_func_) {
ICHECK(global_symbol.has_value())
<< "CodeGenCHost: The entry func must have the global_symbol "
"attribute, "
<< "but function " << gvar << " only has attributes " << func->attrs;
function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(func->ret_type, stream);
stream << " " << tvm::ffi::symbol::tvm_ffi_main
<< "(void* self, void* args,int num_args, void* result) {\n";
stream << " return " << static_cast<std::string>(global_symbol.value())
<< "(self, args, num_args, result);\n";
stream << "}\n";
has_main_func_ = true;
}
}
void CodeGenCHost::GenerateForwardFunctionDeclarations(
tvm::ffi::String global_symbol, const tvm::ffi::Array<tvm::Type> &arg_types,
const tvm::Type &ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
for (auto &func_already_defined : GetFunctionNames()) {
if (global_symbol == func_already_defined) {
return;
}
}
this->PrintFuncPrefix(fwd_decl_stream);
this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 0; i < arg_types.size(); ++i) {
if (i > 0) {
fwd_decl_stream << ", ";
}
tvm::codegen::CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
void CodeGenCHost::PrintFuncPrefix(std::ostream &os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n";
}
void CodeGenCHost::PrintType(tvm::DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK_EQ(lanes, 1) << "does not support vector types";
os << "void*";
return;
}
if (t.is_void()) {
os << "void";
return;
}
if (t == tvm::DataType::Bool()) {
os << "bool";
return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16:
os << "half";
break;
case 32:
os << "float";
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
}
if (!fail && lanes == 1)
return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes;
return;
}
}
if (t.is_bfloat16()) {
os << "__bf16";
return;
}
if (t.is_int() || t.is_uint()) {
if (t.is_uint()) {
os << 'u';
}
switch (t.bits()) {
case 8:
os << "int8_t";
break;
case 16:
os << "int16_t";
break;
case 32:
os << "int32_t";
break;
case 64:
os << "int64_t";
break;
case 1:
os << "int32_t";
break;
default:
fail = true;
break;
}
if (!fail && lanes == 1)
return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes;
return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
void CodeGenCHost::VisitExpr_(const tvm::tir::BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
os << "((";
PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << v;
}
os << "))";
}
void CodeGenCHost::PrintGetFuncFromBackend(
const std::string &func_name, const std::string &packed_func_name) {
this->PrintIndent();
this->stream << "if (" << packed_func_name << " == NULL) {\n";
int packed_func_if_scope = this->BeginScope();
this->PrintIndent();
this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \""
<< func_name << "\""
<< ", &" << packed_func_name << ") != 0) {\n";
int get_func_env_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(get_func_env_scope);
this->PrintIndent();
this->stream << "}\n";
this->EndScope(packed_func_if_scope);
this->PrintIndent();
this->stream << "}\n";
}
void CodeGenCHost::PrintCallPacked(const tvm::tir::CallNode *op) {
using namespace tvm::tir;
const StringImmNode *func_name = op->args[0].as<StringImmNode>();
ICHECK(func_name != nullptr)
<< "tvm_call_[c]packed_lowered expects first argument as function name";
int64_t begin = op->args[2].as<IntImmNode>()->value;
int64_t end = op->args[3].as<IntImmNode>()->value;
int64_t num_args = end - begin;
ICHECK_GE(num_args, 0);
std::string packed_func_name;
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
packed_func_name = GetPackedName(op);
this->PrintGetFuncFromBackend(func_name->value, packed_func_name);
} else {
// directly use the original symbol
ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered()));
packed_func_name =
tvm::ffi::symbol::tvm_ffi_symbol_prefix + func_name->value;
}
std::string args_stack = PrintExpr(op->args[1]);
this->PrintIndent();
std::string result = name_supply_->FreshName("result");
this->stream << "TVMFFIAny " << result << ";\n";
this->PrintIndent();
// must make sure type_index is set to none
this->stream << result << ".type_index = kTVMFFINone;\n";
this->PrintIndent();
this->stream << result << ".zero_padding = 0;\n";
this->PrintIndent();
this->stream << result << ".v_int64 = 0;\n";
this->PrintIndent();
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", ";
} else {
this->stream << "if (" << packed_func_name << "(NULL, ";
}
this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", "
<< "&" << result << ") != 0) {\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(func_call_scope);
this->PrintIndent();
this->stream << "}\n";
}
std::string CodeGenCHost::GetPackedName(const tvm::tir::CallNode *op) {
using namespace tvm::tir;
const StringImmNode *s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr)
<< "tvm_call_packed_lowered expects first argument as function name";
std::string func_name = s->value;
std::string packed_func_name = func_name + "_packed";
std::string unique_name;
auto it = declared_globals_.find(packed_func_name);
if (it != declared_globals_.end()) {
unique_name = it->second;
} else {
unique_name = name_supply_->FreshName(packed_func_name);
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
return unique_name;
}
void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op,
std::ostream &os) { // NOLINT(*)
using namespace tvm::tir;
if (op->op.same_as(builtin::tvm_stack_alloca())) {
std::string stack_name = name_supply_->FreshName("stack");
const std::string &type = op->args[0].as<StringImmNode>()->value;
const IntImmNode *num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr);
static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMFFIAny);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit;
} else if (type == "tvm_ffi_any") {
size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
this->PrintIndent();
this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
} else {
tvm::codegen::CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
using namespace tvm::tir;
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
{
// Prepare the base error message
const auto *msg_node = op->message.as<StringImmNode>();
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
const std::string &raw_msg = msg_node->value;
const std::string esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);
// If the assertion condition contains any equality checks anywhere
// in a composite boolean expression, append the actual LHS/RHS values
// Collect all EQ nodes within the condition (including inside And/Or/Not)
std::vector<const EQNode *> eq_nodes;
{
std::vector<PrimExpr> stk;
stk.push_back(op->condition);
while (!stk.empty()) {
PrimExpr cur = stk.back();
stk.pop_back();
if (const auto *eq = cur.as<EQNode>()) {
eq_nodes.push_back(eq);
continue;
}
if (const auto *an = cur.as<AndNode>()) {
stk.push_back(an->a);
stk.push_back(an->b);
continue;
}
if (const auto *on = cur.as<OrNode>()) {
stk.push_back(on->a);
stk.push_back(on->b);
continue;
}
if (const auto *nn = cur.as<NotNode>()) {
stk.push_back(nn->a);
continue;
}
}
}
if (!eq_nodes.empty()) {
// Build a single detailed message that includes all LHS/RHS pairs
PrintIndent();
stream << "char __tvm_assert_msg_buf[1024];\n";
PrintIndent();
stream << "int __tvm_assert_msg_len = snprintf(__tvm_assert_msg_buf, "
"sizeof(__tvm_assert_msg_buf), \"%s\", \""
<< esc_msg << "\");\n";
auto escape_for_printf_literal = [&](const std::string &s) {
std::string out;
out.reserve(s.size());
for (char c : s) {
if (c == '%') {
out += "%%";
} else if (c == '"') {
out += "\\\"";
} else if (c == '\\') {
out += "\\\\";
} else {
out.push_back(c);
}
}
return out;
};
for (const auto *eq : eq_nodes) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
std::string lhs_disp = escape_for_printf_literal(lhs);
std::string rhs_disp = escape_for_printf_literal(rhs);
PrintIndent();
stream << "__tvm_assert_msg_len += snprintf(__tvm_assert_msg_buf + "
"__tvm_assert_msg_len, "
"sizeof(__tvm_assert_msg_buf) - __tvm_assert_msg_len, \"; ("
<< lhs_disp << " == " << rhs_disp
<< ") got: %lld, expected: %lld\", (long long)(" << lhs
<< "), (long long)(" << rhs << "));\n";
}
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
// Fallback: just emit the base message
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg
<< "\");\n";
}
}
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
PrintIndent();
stream << "}\n";
}
this->PrintStmt(op->body);
}
void CodeGenCHost::VisitExpr_(const tvm::tir::MinNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
void CodeGenCHost::VisitExpr_(const tvm::tir::MaxNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
template <typename T>
inline void CodeGenCHost::PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os) { // NOLINT(*)
std::ostringstream temp_a;
VisitExpr(op->a, temp_a);
std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
std::ostringstream temp_b;
VisitExpr(op->b, temp_b);
std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());
os << "((" << a_id << ") " << compare << " (" << b_id << ") "
<< "? (" << a_id << ") : (" << b_id << "))";
}
} // namespace tl
} // namespace tvm
namespace tvm {
namespace tl {
using tvm::codegen::CodeGenSourceBase;
using tvm::codegen::CSourceModuleCreate;
using tvm::ffi::Array;
using tvm::ffi::Map;
using tvm::ffi::Module;
using tvm::ffi::String;
// Build function that mirrors TVM's host C codegen, registered under a
// TileLang-specific name.
::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod,
::tvm::Target target) {
bool output_ssa = false;
bool emit_asserts = true;
bool emit_fwd_func_decl = true;
std::unordered_set<std::string> devices;
if (mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>(
"device_contexts") != nullptr) {
::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String> device_contexts =
mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>(
"device_contexts")
.value();
for (auto const &context : device_contexts) {
devices.insert(context.second.data());
}
}
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
cg.SetConstantsByteAlignment(
target->GetAttr<::tvm::Integer>("constants-byte-alignment").value_or(16));
auto is_aot_executor_fn = [](::tvm::tir::PrimFunc const &func) -> bool {
return func->GetAttr<::tvm::Bool>("runner_function", ::tvm::Bool(false))
.value();
};
std::vector<std::pair<::tvm::GlobalVar, ::tvm::tir::PrimFunc>> funcs;
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>())
<< "CodegenCHost: Can only take PrimFunc";
auto prim_func = ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func);
funcs.push_back({gvar, prim_func});
}
auto sort_key = [&is_aot_executor_fn](const auto &kv) {
return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
};
std::sort(funcs.begin(), funcs.end(),
[&sort_key](const auto &kv_a, const auto &kv_b) {
return sort_key(kv_a) < sort_key(kv_b);
});
for (const auto &[gvar, prim_func] : funcs) {
cg.DeclareFunction(gvar, prim_func);
}
for (const auto &[gvar, prim_func] : funcs) {
cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
}
std::string code = cg.Finish();
return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangCHost);
}
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c_host.h
* \brief Generate C host code (TileLang copy).
*/
#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "target/source/codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"
namespace tvm {
namespace tl {
// TileLang copy of TVM's CodeGenCHost, under the tl namespace.
// Inherits from tvm::codegen::CodeGenC.
class CodeGenCHost : public tvm::codegen::CodeGenC {
public:
CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str,
const std::unordered_set<std::string> &devices);
void InitGlobalContext();
void AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &f) override;
void AddFunction(const tvm::GlobalVar &gvar, const tvm::tir::PrimFunc &f,
bool emit_fwd_func_decl);
/*!
* \brief Add functions from the (unordered) range to the current module in a
* deterministic order. This helps with debugging.
*
* \param functions A vector of unordered range of current module.
*/
void AddFunctionsOrdered(
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();
using tvm::codegen::CodeGenC::PrintType;
void PrintType(tvm::DataType t, std::ostream &os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream &os) final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const tvm::tir::BroadcastNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const tvm::tir::CallNode *op,
std::ostream &os) override; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
void VisitExpr_(const tvm::tir::MinNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const tvm::tir::MaxNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitStmt_(const tvm::tir::AssertStmtNode *op) final; // NOLINT(*)
void GenerateForwardFunctionDeclarations(
tvm::ffi::String global_symbol,
const tvm::ffi::Array<tvm::Type> &arg_types,
const tvm::Type &ret_type) override;
tvm::ffi::Array<tvm::ffi::String> GetFunctionNames() {
return function_names_;
}
private:
std::string module_name_;
/* \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/* \brief names of the functions declared in this module */
tvm::ffi::Array<tvm::ffi::String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forward function declarations in the resulting C
* code */
bool emit_fwd_func_decl_;
/*! \brief whether to generate the entry function if encountered */
bool has_main_func_ = false;
std::string GetPackedName(const tvm::tir::CallNode *op);
void PrintGetFuncFromBackend(const std::string &func_name,
const std::string &packed_func_name);
void PrintCallPacked(const tvm::tir::CallNode *op);
/*!
* \brief Print ternary conditional operator implementing binary `op`
* Forces the operands to be in SSA form.
* \param op binary operator being expressed
* \param compare string representation of comparison operator
* \param os stream reference to print into
*/
template <typename T>
inline void PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os); // NOLINT(*)
};
} // namespace tl
} // namespace tvm
#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
......@@ -203,12 +203,12 @@ void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name,
this->PrintIndent();
std::string ret_val = name_supply_->FreshName("ret_val");
std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->stream << "TVMFFIAny " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
<< "(TVMValue*) stack_value"
<< "(TVMFFIAny*) stack_value"
<< ", "
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
......@@ -228,13 +228,13 @@ void CodeGenTileLangCPP::PrintFuncCallC(
this->PrintIndent();
std::string ret_val = name_supply_->FreshName("ret_val");
std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->stream << "TVMFFIAny " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (" << packed_func_name << "( "
<< "(TVMValue*) stack_value "
<< "(TVMFFIAny*) stack_value "
<< ", "
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
......
......@@ -24,7 +24,11 @@ ExtractFuncInfo(const IRModule &mod) {
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
DataType dtype = f->params[i].dtype();
// Device runtime cannot directly take bool arguments, map to int32.
if (dtype.is_bool())
dtype = DataType::Int(32);
info.arg_types.push_back(dtype);
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
......
......@@ -35,7 +35,11 @@ ExtractFuncInfo(const IRModule &mod) {
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
DataType dtype = f->params[i].dtype();
// Device runtime cannot directly take bool arguments, map to int32.
if (dtype.is_bool())
dtype = DataType::Int(32);
info.arg_types.push_back(dtype);
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
......
......@@ -51,6 +51,43 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond,
}
}
bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets,
const PrimExpr &nullable_guard) {
// Currently only used in BindDLTensor, nullable_guard is already a defined
// bool, so use it directly.
auto MakeGuarded = [&](PrimExpr basic) -> PrimExpr {
// is_null || basic
return Or(nullable_guard, basic);
};
ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value;
if (const VarNode *v = arg.as<VarNode>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
// First time binding: identical behavior as Bind_
Var v_arg = Downcast<Var>(arg);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
} else {
(*def_map_)[v] = value;
}
return true;
} else {
// Second or later binding: add is_null short-circuit
PrimExpr cond = MakeGuarded(it->second == value);
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_);
}
} else {
// For non-Var expressions, also add is_null short-circuit
PrimExpr cond = MakeGuarded(arg == value);
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_);
}
return false;
}
bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets) {
ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value;
......@@ -96,8 +133,30 @@ void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value,
const std::string &arg_name, bool fuzzy_match) {
ICHECK_EQ(arg.scope(), value.scope())
<< "Argument " << arg_name << " Buffer bind scope mismatch";
ICHECK_EQ(arg->dtype, value->dtype)
<< "Argument " << arg_name << " Buffer bind data type mismatch";
// Relax dtype check to allow FP8 E4M3 variants to bind together.
auto dtype_compatible = [](DataType expected, DataType provided) -> bool {
if (expected == provided)
return true;
// If expected is float8_e4m3, allow float8_e4m3fn/float8_e4m3fnuz as well.
if (expected.is_float8_e4m3()) {
return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
provided.is_float8_e4m3fnuz();
}
// If expected is float8_e5m2, allow float8_e5m2fnuz as well.
if (expected.is_float8_e5m2()) {
return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
}
// If expected is bool, allow binding from int8/uint8 with same lanes.
if (expected.is_bool()) {
bool is_i8 = provided.is_int() && provided.bits() == 8;
bool is_u8 = provided.is_uint() && provided.bits() == 8;
return (is_i8 || is_u8) && expected.lanes() == provided.lanes();
}
return false;
};
ICHECK(dtype_compatible(arg->dtype, value->dtype))
<< "Argument " << arg_name << " Buffer bind data type mismatch: expected "
<< arg->dtype << ", got " << value->dtype;
if (value->data_alignment % arg->data_alignment != 0) {
LOG(WARNING) << "Trying to bind buffer to another one with lower alignment "
"requirement "
......@@ -167,10 +226,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
init_nest_.emplace_back(AssertStmt(
!Call(DataType::Bool(), builtin::isnullptr(), {handle}),
StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"),
nop));
// Allow NULL DLTensor* for optional inputs. When the handle is NULL,
// avoid dereferencing it by using expression-level conditionals and
// short-circuiting guards in asserts. Cache the null check in a Let-bound
// boolean so codegen does not repeat `(handle == NULL)` everywhere.
Var is_null_var(arg_name + "_is_null", DataType::Bool());
init_nest_.emplace_back(
LetStmt(is_null_var,
Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
const PrimExpr &is_null = is_null_var;
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
......@@ -193,25 +257,91 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr a_ndim =
make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
// Note: We cannot embed runtime values into the message string.
// Keep message human-friendly without printing TIR exprs.
ndim_err_msg << arg_name << ".ndim is expected to equal "
<< buffer->shape.size();
<< buffer->shape.size() << ", but got mismatched ndim";
auto msg = StringImm(ndim_err_msg.str());
init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
// Only check ndim when handle is non-NULL (using short-circuit OR)
v_ndim = tvm::if_then_else(Not(is_null), v_ndim, make_zero(tvm_ndim_type));
init_nest_.emplace_back(AssertStmt(Or(is_null, a_ndim == v_ndim), msg, nop));
// type checks
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
PrimExpr cond =
(TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) ==
IntImm(DataType::UInt(8), buffer->dtype.code()) &&
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) ==
IntImm(DataType::UInt(8), buffer->dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
// Avoid dumping TIR expressions in error text; just state mismatch.
// Include expected dtype triplet for clarity.
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype
<< ", but got incompatible dtype";
// Guard all dtype field loads by `is_null` using if_then_else
PrimExpr v_type_code = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode),
IntImm(DataType::UInt(8), buffer->dtype.code()));
PrimExpr v_type_bits = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits),
IntImm(DataType::UInt(8), buffer->dtype.bits()));
PrimExpr v_type_lanes = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes),
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code());
PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits());
PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes());
PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
// Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
if (buffer->dtype.is_float8_e4m3()) {
PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
PrimExpr code_e4m3fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
PrimExpr code_match =
(v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
v_type_code == code_e4m3fnuz);
cond = cond || (code_match && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
}
// Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
if (buffer->dtype.is_float8_e5m2()) {
PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
PrimExpr code_e5m2fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
PrimExpr code_match =
(v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
cond = cond || (code_match && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
}
// Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
if (buffer->dtype.is_bool()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
PrimExpr uint8_ok =
(v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
// Some frontends may tag bool tensors as kDLBool(code=6), commonly with
// bits=8 or bits=1.
PrimExpr kdlbool8_ok =
(v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
PrimExpr kdlbool1_ok =
(v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
// Also accept any dtype whose bitwidth=1, regardless of code, to be
// defensive.
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
}
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
// Only check dtype when handle is non-NULL (short-circuit)
asserts_.emplace_back(AssertStmt(Or(is_null, cond), type_msg, nop));
}
// shape field
......@@ -220,32 +350,70 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
tvm_shape_type, shape_handle_name());
Var v_shape(shape_handle_name(), DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt(
buf_shape->data,
TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop));
// Use if_then_else for NULL guard on the shape pointer itself, avoiding
// dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
init_nest_.emplace_back(
LetStmt(buf_shape->data,
tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
// These packed-bit dtype shapes were not bound in the original
// implementation, so we just use them as is.
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
break;
}
Bind_(buffer->shape[k],
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})),
shape_element_name(k), true);
// The "real" runtime shape value read from DLTensor
PrimExpr raw_shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
// Bind to the value of the symbolic dimension (e.g., m) in TIR, with an
// is_null guard:
// handle is NULL → use 0, placeholder but no dereference
// handle non-NULL → actually read from DLTensor's shape array
PrimExpr bound_shape_val = tvm::if_then_else(
is_null, make_zero(buffer->shape[k].dtype()), raw_shape_val);
// When first encountering a Var (e.g., m), this will generate:
// Let(m, bound_shape_val, ...)
// Constant dimensions will only generate consistency assertions.
BindNullable(buffer->shape[k], bound_shape_val, shape_element_name(k), true,
is_null);
// Keep an explicit "consistency check": when non-NULL, the symbolic
// dimension must equal the DLTensor's shape.
Stmt shape_check = AssertStmt(
Or(is_null, buffer->shape[k] == raw_shape_val),
StringImm(shape_element_name(k) + " mismatch with DLTensor shape"),
Evaluate(0));
asserts_.emplace_back(shape_check);
}
// strides field
Buffer buf_strides =
decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
tvm_shape_type, arg_name + ".strides");
def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(LetStmt(
buf_strides->data,
TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
init_nest_.emplace_back(
LetStmt(buf_strides->data,
tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
PrimExpr v_strides_is_null =
Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
if (buffer->strides.empty()) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
......@@ -253,13 +421,16 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue =
cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr svalue = cast(
stype, BufferLoad(buf_strides,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
std::ostringstream stride_err_msg;
stride_err_msg << stride_handle_name() << ": expected to be compact array";
stride_err_msg
<< stride_handle_name()
<< ": expected to be compact array, but got non-compact strides";
if (!conds.empty()) {
auto stride_msg = StringImm(stride_err_msg.str());
Stmt check =
......@@ -267,6 +438,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
// Only check when strides array is actually present at runtime
check = IfThenElse(Not(v_strides_is_null), check);
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
......@@ -277,13 +449,27 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
BufferLoad(buf_strides,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape);
PrimExpr value = tvm::if_then_else(
PrimExpr core_value = tvm::if_then_else(
v_strides_is_null, stride_from_shape_cast, explicit_stride);
value = tvm::if_then_else(buffer->shape[k] == 1, make_zero(stride_dtype),
value);
Bind_(buffer->strides[k], value, stride_element_name(k), true);
core_value = tvm::if_then_else(buffer->shape[k] == 1,
make_zero(stride_dtype), core_value);
// Bind like shape: define var when needed, and only assert when non-NULL
PrimExpr bound_stride_val =
tvm::if_then_else(is_null, make_zero(stride_dtype), core_value);
BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k),
true, is_null);
Stmt stride_check = AssertStmt(
Or(is_null, buffer->strides[k] == core_value),
StringImm(stride_element_name(k) + " mismatch with DLTensor strides"),
Evaluate(0));
asserts_.emplace_back(stride_check);
PrimExpr shape_extent = cast(stride_dtype, buffer->shape[k]);
stride_from_shape =
analyzer_.Simplify(stride_from_shape_cast * shape_extent);
......@@ -291,7 +477,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
} else {
PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1);
for (int k = buffer->strides.size() - 1; k >= 0; k--) {
for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0; --k) {
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
......@@ -300,75 +486,127 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape);
Bind_(buffer->strides[k],
tvm::if_then_else(v_strides_is_null, stride_from_shape_cast,
explicit_stride),
stride_element_name(k), true);
PrimExpr core_value = tvm::if_then_else(
v_strides_is_null, stride_from_shape_cast, explicit_stride);
PrimExpr bound_stride_val =
tvm::if_then_else(is_null, make_zero(stride_dtype), core_value);
BindNullable(buffer->strides[k], bound_stride_val, stride_element_name(k),
true, is_null);
Stmt stride_check = AssertStmt(
Or(is_null, buffer->strides[k] == core_value),
StringImm(stride_element_name(k) + " mismatch with DLTensor strides"),
Evaluate(0));
asserts_.emplace_back(stride_check);
stride_from_shape =
analyzer_.Simplify(stride_from_shape_cast * shape_stride);
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
arg_name + ".byte_offset", true);
// Constant elem_offset: only need consistency check, no need for additional
// Var binding.
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_byte_offset =
make_const(DataType::UInt(64), const_offset->value * data_bytes);
Stmt byte_off_check =
AssertStmt(Or(is_null, expect_byte_offset == actual_byte_offset),
StringImm(arg_name + ".byte_offset mismatch"), nop);
asserts_.emplace_back(byte_off_check);
} else {
if (Bind_(buffer->elem_offset,
cast(buffer->elem_offset.dtype(),
(TVMArrayGet(DataType::UInt(64), handle,
builtin::kArrByteOffset) /
make_const(DataType::UInt(64), data_bytes))),
arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_elem_off =
cast(buffer->elem_offset.dtype(),
(actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
// Like shape/stride, do NULL-safe binding for elem_offset:
// handle is NULL → 0
// handle non-NULL → actual_byte_offset / data_bytes
PrimExpr bound_elem_off = tvm::if_then_else(
is_null, make_zero(buffer->elem_offset.dtype()), expect_elem_off);
BindNullable(buffer->elem_offset, bound_elem_off, arg_name + ".elem_offset",
true, is_null);
// Strict consistency check for non-NULL case
Stmt elem_off_check =
AssertStmt(Or(is_null, buffer->elem_offset == expect_elem_off),
StringImm(arg_name + ".elem_offset mismatch"), nop);
asserts_.emplace_back(elem_off_check);
if (buffer->offset_factor > 1) {
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
Stmt off_factor_check =
AssertStmt(Or(is_null, truncmod(offset, factor) == zero),
StringImm(arg_name + ".elem_offset factor mismatch"), nop);
asserts_.emplace_back(off_factor_check);
}
}
// device info.
Bind_(device_type,
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
arg_name + ".device_type", true);
Bind_(device_id,
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
arg_name + ".device_id", true);
// Define device_id from handle when available (so later passes can use it)
PrimExpr actual_dev_type = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
make_zero(DataType::Int(32)));
PrimExpr actual_dev_id = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
make_zero(DataType::Int(32)));
// Bind device_id to a safe expression (0 when NULL handle)
BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
is_null);
// Check device_type consistency (device_id equality is implicitly ensured by
// binding above)
init_nest_.emplace_back(
AssertStmt(Or(is_null, device_type == actual_dev_type),
StringImm(arg_name + ".device_type mismatch"), nop));
// Data field. Because the validation of the data field may depend
// on a dynamic size defined by the other DLTensor* parameters, this
// field must be generated last.
if (Bind_(buffer->data,
TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
arg_name + ".data", true)) {
// Bind data pointer using expression-level guard to avoid deref on NULL.
{
Var vptr(buffer->data);
PrimExpr data_ptr = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
make_zero(DataType::Handle()));
BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null);
// Check if the data pointer is NULL. This check is skipped for
// size-0 arrays, since CUDA provides a NULL pointer for size-zero
// allocations, and also skipped when the handle itself is NULL.
auto alloc_size = [&]() -> PrimExpr {
PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
for (const auto &dim : buffer->shape) {
product *= dim;
}
return product;
}();
asserts_.emplace_back(AssertStmt(
Or(is_null, (alloc_size == 0) ||
!Call(DataType::Bool(), builtin::isnullptr(), {vptr})),
StringImm(arg_name +
" is expected to have non-NULL data pointer, but got NULL"),
nop));
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
init_nest_.emplace_back(
AttrStmt(vptr, tir::attr::storage_alignment,
IntImm(DataType::Int(32), buffer->data_alignment), nop));
}
}
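// Illustrative sketch (not part of the patch): the NULL-safe pattern used
// above guards every field read on a possibly-NULL DLTensor* handle, e.g.
//
//   PrimExpr is_null =
//       Call(DataType::Bool(), builtin::isnullptr(), {handle});
//   PrimExpr safe_field = tvm::if_then_else(
//       Not(is_null),
//       TVMArrayGet(DataType::Int(32), handle, builtin::kArrNDim),
//       make_zero(DataType::Int(32)));
//
// The bound value falls back to zero instead of dereferencing NULL, and the
// strict consistency checks are wrapped in Or(is_null, ...) so they are
// skipped for absent tensors.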
......
......@@ -154,6 +154,10 @@ public:
return def_handle_dtype_;
}
bool BindNullable(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets,
const PrimExpr &nullable_guard);
private:
// Internal bind function
bool Bind_(const PrimExpr &arg, const PrimExpr &value,
......
......@@ -26,10 +26,13 @@ public:
LowerHopperIntrin substituter(disable_shuffle_elect);
fptr->body = substituter.VisitStmt(f->body);
Map<Var, Array<PrimExpr>> init_desc_arg_map;
// Collect prologue/epilogue statements for host-side setup/teardown
Array<Stmt> prologue_stmts;
Array<Stmt> epilogue_stmts;
for (const auto &[call, var] : substituter.desc_map_) {
// Should allocate 128 bytes for TensorMap on stack
Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
{StringImm("arg_value"), 16});
{StringImm("tvm_ffi_any"), 16});
Array<PrimExpr> init_desc_args;
if (call->op.same_as(create_tma_descriptor())) {
init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
......@@ -44,11 +47,66 @@ public:
// add to function attribute
Call init_desc =
Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
// Accumulate TMA descriptor init into prologue
prologue_stmts.push_back(LetStmt(var, alloc_desc, Evaluate(init_desc)));
init_desc_arg_map.Set(var, init_desc_args);
}
f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
// Additionally, if L2 persistent cache annotations were lowered earlier,
// materialize TVM FFI calls to set the stream access policy window.
if (f->attrs.defined() && f->attrs->dict.count("l2_persistent_map")) {
auto l2_map =
f->GetAttr<Map<String, Array<PrimExpr>>>("l2_persistent_map");
if (l2_map.defined()) {
// Build a lookup from buffer name to Buffer object
std::unordered_map<std::string, Buffer> name2buf;
for (const auto &kv : f->buffer_map) {
name2buf.emplace(kv.second->name, kv.second);
}
for (const auto &kv : l2_map.value()) {
const std::string buf_name = kv.first;
const Array<PrimExpr> &args = kv.second;
if (name2buf.count(buf_name) == 0) {
continue;
}
const Buffer &buf = name2buf.at(buf_name);
// Build base pointer expression (read access)
PrimExpr base_ptr = buf.access_ptr(1);
// Args packed: func_name, base_ptr, num_bytes, hit_ratio
Array<PrimExpr> packed_args;
packed_args.push_back(
StringImm(tvm_cuda_stream_set_access_policy_window));
packed_args.push_back(base_ptr);
// size_in_bytes (args[1]) then hit_ratio (args[0])
ICHECK_GE(args.size(), 2);
packed_args.push_back(args[1]);
packed_args.push_back(args[0]);
prologue_stmts.push_back(Evaluate(Call(
DataType::Int(32), builtin::tvm_call_packed(), packed_args)));
}
// Add a single epilogue call to reset the access policy window and
// restore L2 limit
Array<PrimExpr> reset_args;
reset_args.push_back(
StringImm(tvm_cuda_stream_reset_access_policy_window));
epilogue_stmts.push_back(Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), reset_args)));
}
}
// Stitch prologue statements before the original body
if (!prologue_stmts.empty()) {
// Chain the Let/Evaluate statements sequentially
Stmt seq = prologue_stmts.size() == 1 ? prologue_stmts[0]
: SeqStmt(prologue_stmts);
fptr->body = SeqStmt({seq, fptr->body});
}
if (!epilogue_stmts.empty()) {
Stmt seq_end = epilogue_stmts.size() == 1 ? epilogue_stmts[0]
: SeqStmt(epilogue_stmts);
fptr->body = SeqStmt({fptr->body, seq_end});
}
return f;
}
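// Rough sketch of the stitched host body this pass produces (illustrative
// only; `desc` and the window parameters are placeholders):
//
//   SeqStmt({
//     LetStmt(desc, tvm_stack_alloca("tvm_ffi_any", 16),
//             Evaluate(call_packed(tvm_tensormap_create_tiled, ...))),
//     Evaluate(call_packed(tvm_cuda_stream_set_access_policy_window,
//                          base_ptr, size_in_bytes, hit_ratio)),
//     <original kernel body>,
//     Evaluate(call_packed(tvm_cuda_stream_reset_access_policy_window)),
//   })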
......
......@@ -20,6 +20,7 @@
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
......@@ -32,6 +33,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -43,13 +45,11 @@ namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
static constexpr const char *kDeviceContextVar = "device_api_context";
namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var) : ret_var_(ret_var) {}
Stmt VisitStmt_(const ForNode *node) override {
if (node->kind == ForKind::kParallel)
......@@ -79,8 +79,6 @@ private:
struct ConvertedInfo {
int type_index{-1};
PrimExpr expr;
};
ConvertedInfo ConvertForFFI(const PrimExpr &val) {
......@@ -88,7 +86,11 @@ private:
// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_bool()) {
info.type_index = ffi::TypeIndex::kTVMFFIBool;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_int() || dtype.is_uint()) {
info.type_index = ffi::TypeIndex::kTVMFFIInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
......@@ -101,56 +103,39 @@ private:
LOG(FATAL) << "data type " << dtype << " not supported yet";
}
return info;
}
Stmt WriteToOut(PrimExpr val) {
auto info = ConvertForFFI(val);
Stmt store_tindex = tir::Evaluate(
tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex),
IntImm(DataType::Int(32), info.type_index)}));
Stmt store_zero_padding = tir::Evaluate(tir::Call(
DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding),
IntImm(DataType::Int(32), 0)}));
Stmt store_val = tir::Evaluate(tir::Call(
DataType::Int(32), tir::builtin::tvm_struct_set(),
{ret_var_, IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue),
info.expr}));
Stmt ret_zero = Evaluate(tvm::ret(0));
return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero});
}
Var ret_var_;
int in_parallel_{0};
};
class SubroutineCallRewriter : public StmtExprMutator {
public:
static ffi::Optional<Stmt>
Apply(const ffi::Map<GlobalVar, ffi::String> &packed_func_methods,
Stmt stmt) {
SubroutineCallRewriter rewriter(packed_func_methods);
stmt = rewriter.VisitStmt(stmt);
if (rewriter.made_change_) {
......@@ -162,16 +147,16 @@ public:
private:
explicit SubroutineCallRewriter(
const ffi::Map<GlobalVar, ffi::String> &packed_func_methods)
: packed_func_methods(packed_func_methods) {}
PrimExpr VisitExpr_(const CallNode *op) override {
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = ffi::GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) {
ffi::Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value()));
for (auto arg : node->args) {
cpacked_args.push_back(arg);
......@@ -187,19 +172,18 @@ private:
return node;
}
const ffi::Map<GlobalVar, ffi::String> &packed_func_methods;
bool made_change_{false};
};
} // namespace
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
}
......@@ -254,21 +238,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}
auto *func_ptr = func.CopyOnWrite();
// set the global symbol to the packed function name
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
// Data field definitions
// The packed fields
Var v_self_handle("self_handle", DataType::Handle());
Var v_packed_args("args", DataType::Handle());
Var v_num_packed_args("num_args", DataType::Int(32));
// The arguments of the function.
Var v_result("result", PointerType(PrimType(DataType::Void())));
// The device context
Var device_id("dev_id");
......@@ -278,37 +257,24 @@ PrimFunc MakePackedAPI(PrimFunc func) {
std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
std::unordered_map<const VarNode *, PrimExpr> vmap;
ArgBinder binder(&vmap);
std::vector<Stmt> shape_checks;
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
bool disable_dynamic_tail_split =
ctxt->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
// ---------------------------
// local function definitions
// load i-th argument as type t
auto f_load_arg_value = [&](DataType arg_type, int i) {
ffi::Array<PrimExpr> call_args{
v_packed_args, IntImm(DataType::Int(32), i),
IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
// load 64 bit version
DataType api_type = APIType(arg_type);
PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
// cast to the target version.
if (api_type != arg_type) {
res = Cast(arg_type, res);
}
return res;
};
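// For reference (illustrative only): loading arg i as, say, int32 through
// f_load_arg_value roughly expands to
//
//   Cast(Int(32), tvm_struct_get(args, i, kTVMFFIAnyUnionValue))
//
// i.e. the 64-bit union slot of packed_args[i] is read via tvm_struct_get
// and then narrowed to the parameter's dtype.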
// Assert correct type codes for each argument. This must be done
// *before* any initialization steps produced by
// `binder.BindDLTensor()`. The validity of those initialization
......@@ -321,12 +287,10 @@ PrimFunc MakePackedAPI(PrimFunc func) {
return error_message.str();
}()));
if (num_args > 0) {
seq_init.push_back(
MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL"));
}
// Need to delay binding of the buffers, in case some arguments also
// appear in the buffer.
......@@ -335,26 +299,17 @@ PrimFunc MakePackedAPI(PrimFunc func) {
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
PrimExpr arg_value;
// type index checks
Var type_index(param->name_hint + ".type_index", DataType::Int(32));
seq_init.push_back(LetStmt(
type_index,
tir::Call(DataType::Int(32), builtin::tvm_struct_get(),
{v_packed_args, IntImm(DataType::Int(32), i),
IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}),
nop));
DataType dtype = param.dtype();
if (dtype.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_init.emplace_back(
......@@ -363,23 +318,63 @@ PrimFunc MakePackedAPI(PrimFunc func) {
type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
tvm::tir::StringImm(msg.str()), nop));
// if type_index is Tensor, we need to add the offset of the DLTensor
// header, which equals sizeof(TVMFFIObject); this ensures that T.handle
// always shows up as a DLTensor*
const int64_t object_cell_offset = sizeof(TVMFFIObject);
static_assert(object_cell_offset == 24);
arg_value = f_load_arg_value(param.dtype(), i);
PrimExpr handle_from_tensor =
Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
{arg_value, IntImm(DataType::Int(32), object_cell_offset)});
arg_value = Select(type_index == ffi::TypeIndex::kTVMFFITensor,
handle_from_tensor, arg_value);
} else if (dtype.is_bool()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be boolean";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool ||
type_index == ffi::TypeIndex::kTVMFFIInt,
tvm::tir::StringImm(msg.str()), nop));
arg_value =
Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i));
} else if (dtype.is_int() || dtype.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt ||
type_index == ffi::TypeIndex::kTVMFFIBool,
tvm::tir::StringImm(msg.str()), nop));
arg_value = f_load_arg_value(param.dtype(), i);
} else {
ICHECK(dtype.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
type_index == ffi::TypeIndex::kTVMFFIInt ||
type_index == ffi::TypeIndex::kTVMFFIBool,
tvm::tir::StringImm(msg.str()), nop));
// use select so we can also handle int conversion to bool
arg_value = tir::Select(
type_index == ffi::TypeIndex::kTVMFFIFloat,
/* true_value = */ f_load_arg_value(param.dtype(), i),
/* false_value = */
Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i)));
}
var_def.emplace_back(arg_value, param);
if (func_ptr->buffer_map.count(param)) {
// buffer binding now depends on type index
// if the index is Tensor handle, we need to offset to get the DLTensor*
buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
}
}
// signature: (void* handle, TVMFFIAny* packed_args, int num_args, TVMFFIAny*
// v_result)
ffi::Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args,
v_result};
// Arg definitions are defined before buffer binding to avoid the use before
// def errors.
......@@ -392,83 +387,57 @@ PrimFunc MakePackedAPI(PrimFunc func) {
binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
}
for (const auto &[var, buffer] : buffer_def) {
binder.BindDLTensor(buffer, device_type, device_id, var,
name_hint + "." + var->name_hint);
arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
}
// reset global symbol to attach prefix
func = WithAttrs(
std::move(func),
{{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host},
{tvm::attr::kGlobalSymbol,
ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});
Stmt body = ReturnRewriter(v_result)(func_ptr->body);
body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::compute_scope,
StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
auto node = String("default");
ffi::Any node = ffi::String("default");
seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop));
if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
Stmt set_device =
Evaluate(Call(DataType::Int(32), tir::builtin::tvm_call_packed(),
{StringImm(runtime::symbol::tvm_set_device),
device_type, device_id}));
body = SeqStmt({set_device, body});
}
}
// (zhengju) For dynamic constraint, we need to check the buffer shape and
// dtype to make sure the buffer can be vectorized.
for (const auto &kv : buffer_def) {
if (disable_dynamic_tail_split) {
Optional<Integer> opt_dynamic_alignment =
ctxt->GetConfig(kDynamicAlignment, Optional<Integer>());
int dynamic_alignment = opt_dynamic_alignment.value_or(Integer(8))->value;
// The vectorize dimension will be the last dimension of the buffer
auto vectorize_dim = kv.second->shape[kv.second->shape.size() - 1];
auto shape_vectorize_expr = [&]() -> PrimExpr {
PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
result = result * vectorize_dim;
result = FloorMod(result, IntImm(result->dtype, dynamic_alignment));
return result;
}();
shape_checks.emplace_back(AssertStmt(
shape_vectorize_expr == 0,
tvm::tir::StringImm(
kv.second->name +
": Vectorize dimension in buffer must be divisible by " +
std::to_string(dynamic_alignment)),
nop));
}
}
// Return error code of zero on success
body = SeqStmt({body, Evaluate(ret(Integer(0)))});
if (!disable_dynamic_tail_split) {
body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
arg_buffer_declarations},
body);
} else {
body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
arg_buffer_declarations, shape_checks},
body);
}
body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
arg_buffer_declarations},
body);
func_ptr->body = body;
func_ptr->params = args;
ffi::Array<Var> undefined = UndefinedVars(body, func_ptr->params);
ICHECK_EQ(undefined.size(), 0)
<< "In PrimFunc " << name_hint << " variables " << undefined
<< " are used, but are not passed in as API arguments";
func_ptr->buffer_map = ffi::Map<Var, Buffer>();
func_ptr->ret_type = PrimType(DataType::Int(32));
// return the function.
return func;
}
......
......@@ -240,37 +240,42 @@ public:
simplifier.MarkBufferMapShapes(func);
func.CopyOnWrite()->body = simplifier(func->body);
// Begin to remove useless var and buffer
// Optionally remove unused buffer parameters
if (simplify_arguments) {
// First get used buffers
simplifier.used_buffers_ = CollectUsedBuffers(func);
bool param_updated = false;
Array<Var> new_params;
Map<Var, Buffer> new_buffer_map;
// Check whether each buffer is used
for (const auto &var : func->params) {
if (func->buffer_map.find(var) != func->buffer_map.end()) {
if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
simplifier.used_buffers_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else if (simplifier.used_in_buffer_def_.find(
func->buffer_map[var]->data.get()) !=
simplifier.used_in_buffer_def_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else {
param_updated = true;
}
} else {
// Non-buffer parameters (e.g., scalars) are always retained
new_params.push_back(var);
}
}
if (param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
new_buffer_map, func->attrs, func->span);
}
}
// Either no change to params or argument simplification disabled
return func;
}
private:
......
......@@ -13,7 +13,7 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
shared_buf = T.alloc_shared([M, N], dtype)
T.print(shared_buf)
jit_kernel = tilelang.compile(program, target="cuda")
jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
profiler = jit_kernel.get_profiler()
profiler.run_once()
......