Unverified commit 00dd7388 authored by Lei Wang, committed by GitHub

[CUDA] Add read-only parameter annotation for CUDA codegen (#1416)

* [Enhancement] Add read-only parameter annotation for CUDA codegen

* Introduced the `AnnotateReadOnlyParams` transformation to annotate read-only handle parameters in PrimFuncs, enabling the generation of `const` qualifiers in CUDA codegen.
* Updated the `PrintFunctionSignature` and `AddFunction` methods to consume the new `tl.readonly_param_indices` attribute and emit `const` before the parameter type, so generated kernels can use read-only cache loads.
* Modified the optimization pipeline to run the new annotation step after `SplitHostDevice`, so device PrimFuncs carry the attribute before codegen (see the signature sketch below).
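
For illustration, the intended effect on an emitted kernel signature (kernel and parameter names here are hypothetical, mirroring the unit test added below; `x` is only read, `y` is written):

    // before: both pointers emitted as non-const
    extern "C" __global__ void main_kernel(float* __restrict__ x, float* __restrict__ y);
    // after: tl.readonly_param_indices = [0] makes x const-qualified
    extern "C" __global__ void main_kernel(const float* __restrict__ x, float* __restrict__ y);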

* lint fix

* [Dependency] Update apache-tvm-ffi version to >=0.1.3

* Updated the version of apache-tvm-ffi in pyproject.toml, requirements.txt, and requirements-dev.txt to ensure compatibility with the latest features and fixes.
* Made adjustments in CUDA and HIP template files to use `const` qualifiers for global pointer parameters, enhancing code safety and clarity (see the background sketch below).
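
Background on why the `const` qualifier helps (general CUDA behavior, not code from this commit; the kernel below is a hypothetical sketch):

    // With a pointer that is both const and __restrict__, nvcc can prove the
    // data is read-only for the lifetime of the kernel and route its loads
    // through the read-only data cache (the __ldg path, sm_35 and newer).
    __global__ void scale(const float* __restrict__ in, float* __restrict__ out, float s) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      out[i] = in[i] * s;
    }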

* lint fix

* [Enhancement] Refactor ReadWriteMarker for improved parameter handling

* Updated the ReadWriteMarker class to accept a set of parameter or data variables, enhancing its ability to track written variables.
* Introduced a new method, ResolveDataVarFromPtrArg, to resolve underlying buffer data from pointer-like arguments, improving accuracy in identifying written variables.
* Modified the MarkReadOnlyParams function to gather handle parameters and their corresponding buffer data variables, streamlining the process of determining read-only parameters.
* Enhanced the logic for identifying written variables to account for aliased data variables, ensuring comprehensive tracking of modifications (the recognized write patterns are sketched below).
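
Schematically, the marker treats a handle parameter as written when its buffer data var appears in any of these positions (pseudo-TIR for illustration only; the exact printed form varies with the TVM version):

    B[i] = ...                                         // BufferStore into B
    tvm_access_ptr(dtype, B_data, offset, extent, 2)   // rw_mask write bit set (or non-constant)
    call_extern("AtomicAdd", address_of(B[i]), ...)    // address_of(...) argument to a call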

* lint fix

* Update tma_load function to use const qualifier for global memory pointer

* Changed the parameter type of gmem_ptr in the tma_load function from void* to void const* to enhance type safety and clarity in memory operations.
* This modification ensures that the function correctly handles read-only global memory pointers, aligning with best practices in CUDA programming.
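* For context: the underlying PTX instruction (cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes, shown in the template diff below) only reads through the global pointer, so `void const *` matches its semantics; the inline-asm "l" constraint takes the 64-bit address unchanged.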

* Remove commented-out code and reorder transformations in OptimizeForTarget function for clarity

* Refactor buffer marking logic in annotate_read_only_params.cc to improve accuracy in identifying written variables.
parent 3546e2ee
......@@ -7,8 +7,6 @@ import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
# tilelang.disable_cache()
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
......
......@@ -31,7 +31,7 @@ dependencies = [
# Extra constraint to tvm-ffi for abi issue,
# should be removed after our tvm's update.
# See discussion in tilelang#1373 and apache/tvm-ffi#307
"apache-tvm-ffi>=0.1.2",
"apache-tvm-ffi>=0.1.3",
# torch-c-dlpack-ext provides prebuilt torch extensions.
# Without it, TVM FFI may require JIT compilation on first import.
"torch-c-dlpack-ext",
......
# Requirements to run local build with `--no-build-isolation` or other developments
apache-tvm-ffi>=0.1.2
apache-tvm-ffi>=0.1.3
build
cmake>=3.26
cython>=3.0.0
......
......@@ -3246,6 +3246,14 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
CodeGenC::PrintType(func->ret_type, os);
CodeGenC::PrintExtraAttrs(func, os);
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
// Read-only param indices attribute, if present.
std::unordered_set<int> ro_param_indices;
if (auto opt =
func->GetAttr<ffi::Array<Integer>>("tl.readonly_param_indices")) {
for (const auto &idx : opt.value()) {
ro_param_indices.insert(static_cast<int>(Downcast<Integer>(idx)->value));
}
}
os << " " << function_name << "(";
for (size_t i = 0; i < func->params.size(); ++i) {
tir::Var v = func->params[i];
......@@ -3270,7 +3278,10 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
// If marked read-only, emit const qualifier before type.
if (ro_param_indices.count(static_cast<int>(i))) {
os << "const ";
}
CodeGenC::PrintType(GetType(v), os);
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
......@@ -3314,6 +3325,13 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
// Read-only param indices attribute, if present.
std::unordered_set<int> ro_param_indices;
if (auto opt = f->GetAttr<ffi::Array<Integer>>("tl.readonly_param_indices")) {
for (const auto &idx : opt.value()) {
ro_param_indices.insert(static_cast<int>(Downcast<Integer>(idx)->value));
}
}
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
......@@ -3341,7 +3359,10 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
// If marked read-only, emit const qualifier before type.
if (ro_param_indices.count(static_cast<int>(i))) {
stream << "const ";
}
CodeGenC::PrintType(GetType(v), stream);
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
......
......@@ -26,7 +26,8 @@ template <int N> TL_DEVICE void cp_async_wait() {
}
template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
TL_DEVICE void cp_async_gs(void const *const smem_addr,
void const *global_ptr) {
static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
......@@ -37,7 +38,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N));
"l"((void const *)(global_ptr)), "n"(N));
} else {
asm volatile(
#if TL_ENABLE_L2_PREFETCH
......@@ -46,13 +47,13 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
"cp.async.ca.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N));
"l"((void const *)(global_ptr)), "n"(N));
}
}
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
void *global_ptr, bool cond) {
void const *global_ptr, bool cond) {
static_assert(N == 16 || N == 8 || N == 4);
int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr);
......@@ -64,7 +65,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
"cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
"l"((void const *)(global_ptr)), "n"(N), "r"(bytes));
} else {
asm volatile(
#if TL_ENABLE_L2_PREFETCH
......@@ -73,7 +74,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
"cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
"l"((void const *)(global_ptr)), "n"(N), "r"(bytes));
}
}
......
......@@ -15,14 +15,14 @@ enum class CacheHintSm90 : uint64_t {
};
template <typename BarrierType = uint64_t>
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, BarrierType &smem_mbar,
uint32_t size) {
TL_DEVICE void tma_load(void *smem_ptr, void const *gmem_ptr,
BarrierType &smem_mbar, uint32_t size) {
uint32_t smem_int_mbar =
smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
"bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr),
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar)
"l"((void const *)gmem_ptr), "r"(size), "r"(smem_int_mbar)
:);
}
......
......@@ -73,33 +73,35 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
}
template <int N>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void const *global_base_ptr) {
if constexpr (N == 16) {
*(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
*(uint4 *)lds_base_ptr = *(const uint4 *)global_base_ptr;
} else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
*(uint2 *)lds_base_ptr = *(const uint2 *)global_base_ptr;
} else if constexpr (N == 4) {
async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
make_wave_buffer_resource(((const int32_t *)global_base_ptr) -
threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
}
}
template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
void *global_base_ptr, bool cond) {
void const *global_base_ptr, bool cond) {
if constexpr (N == 16) {
*(uint4 *)lds_base_ptr =
cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
cond ? *(const uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
} else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr =
cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
cond ? *(const uint2 *)global_base_ptr : make_uint2(0, 0);
} else {
if (cond) {
async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
make_wave_buffer_resource(((const int32_t *)global_base_ptr) -
threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
} else {
*(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0);
......
/*!
* \file annotate_read_only_params.cc
* \brief Annotate PrimFunc parameters that are read-only (never written).
*/
#include <string>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
/*!
* \brief A simple visitor that marks handle parameters as written when they
* appear on the LHS of a BufferStore or in a tvm_access_ptr with write
* flag.
*/
class ReadWriteMarker : public StmtExprVisitor {
public:
explicit ReadWriteMarker(
const std::unordered_set<const VarNode *> &param_or_data_vars)
: param_or_data_vars_(param_or_data_vars) {}
const std::unordered_set<const VarNode *> &written() const {
return written_;
}
// Try to resolve the underlying buffer data Var from a pointer-like
// argument. Supports:
// - address_of(BufferLoad(...)) -> returns buffer->data
// - BufferLoad(...) -> returns buffer->data
// Otherwise returns nullptr.
const VarNode *ResolveDataVarFromPtrArg(const PrimExpr &arg) const {
if (const auto *call = arg.as<CallNode>()) {
if (call->op.same_as(builtin::address_of())) {
if (call->args.size() == 1U) {
if (const auto *load = call->args[0].as<BufferLoadNode>()) {
return load->buffer->data.get();
}
}
}
} else if (const auto *load = arg.as<BufferLoadNode>()) {
return load->buffer->data.get();
}
return nullptr;
}
void VisitStmt_(const BufferStoreNode *op) final {
const VarNode *data = op->buffer->data.get();
if (param_or_data_vars_.count(data)) {
written_.insert(data);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final {
// Detect tvm_access_ptr writes. Be conservative if rw_mask is non-constant.
if (op->op.same_as(builtin::tvm_access_ptr())) {
if (op->args.size() == 5U) {
if (const VarNode *buf = op->args[1].as<VarNode>()) {
const IntImmNode *flag = op->args[4].as<IntImmNode>();
bool maybe_write = true; // default conservative
if (flag) {
maybe_write = (flag->value & 2) != 0; // write bit set
}
if (maybe_write && param_or_data_vars_.count(buf)) {
written_.insert(buf);
}
}
}
} else {
// Generic fallback: mark buffers that appear as
// address_of(BufferLoad(...)) in call arguments as written. This matches
// patterns like
// tl.tma_store(address_of(smem[..]), address_of(gmem[..]), ...)
// call_extern("AtomicAdd*", address_of(gmem[..]), ...)
// and avoids over-marking plain BufferLoad used for reads.
for (const PrimExpr &a : op->args) {
if (const auto *c = a.as<CallNode>()) {
if (c->op.same_as(builtin::address_of()) && c->args.size() == 1U) {
if (const auto *bl = c->args[0].as<BufferLoadNode>()) {
const VarNode *data = bl->buffer->data.get();
if (param_or_data_vars_.count(data)) {
written_.insert(data);
}
}
}
}
}
}
StmtExprVisitor::VisitExpr_(op);
}
private:
std::unordered_set<const VarNode *> param_or_data_vars_;
std::unordered_set<const VarNode *> written_;
};
/*!
* \brief Annotate PrimFunc with indices of read-only handle parameters.
*
* Adds an Array<Integer> attribute "tl.readonly_param_indices" that lists
* parameter indices which correspond to handle parameters that are never
* written inside the function body. This can be used by codegen to emit
* `const` qualifiers to enable read-only caching (e.g., __ldg on CUDA).
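 *
 * For example, for a PrimFunc main(x, y) where only y is ever stored to (as
 * in the unit test added by this commit), the emitted attribute is
 * tl.readonly_param_indices = [0].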
*/
static tir::PrimFunc MarkReadOnlyParams(tir::PrimFunc f) {
// Gather handle params and their corresponding buffer data vars (aliases).
std::unordered_set<const VarNode *> param_or_data_vars;
// Map back from data var to parameter index for result attribution.
std::unordered_map<const VarNode *, size_t> data_var_to_param_idx;
for (size_t i = 0; i < f->params.size(); ++i) {
const Var &p = f->params[i];
if (!p->dtype.is_handle())
continue;
param_or_data_vars.insert(p.get());
// If there is a buffer_map entry for this param, include its data var too.
if (auto opt = f->buffer_map.Get(p)) {
const VarNode *data = opt.value()->data.get();
param_or_data_vars.insert(data);
data_var_to_param_idx[data] = i;
}
}
if (param_or_data_vars.empty())
return f;
ReadWriteMarker marker(param_or_data_vars);
marker(f->body);
// Determine read-only parameter indices among all params (handle only)
Array<Integer> readonly_indices;
for (size_t i = 0; i < f->params.size(); ++i) {
const Var &v = f->params[i];
if (!v->dtype.is_handle())
continue;
bool is_written = false;
// Direct param var written?
if (marker.written().count(v.get())) {
is_written = true;
} else {
// Or any aliased data var written?
if (auto opt = f->buffer_map.Get(v)) {
if (marker.written().count(opt.value()->data.get())) {
is_written = true;
}
}
}
if (!is_written) {
readonly_indices.push_back(Integer(static_cast<int>(i)));
}
}
if (!readonly_indices.empty()) {
Map<String, Any> attrs;
attrs.Set(String("tl.readonly_param_indices"), readonly_indices);
f = WithAttrs(std::move(f), attrs);
}
return f;
}
namespace transform {
using namespace tir::transform;
Pass AnnotateReadOnlyParams() {
auto pass_func = [](PrimFunc f, const IRModule &m,
const tvm::transform::PassContext &ctx) {
return MarkReadOnlyParams(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateReadOnlyParams", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateReadOnlyParams",
AnnotateReadOnlyParams);
}
} // namespace transform
} // namespace tl
} // namespace tvm
import tilelang.language as T
from tilelang.engine.lower import lower
from tilelang.jit.adapter.utils import match_declare_kernel
def _simple_add_kernel():
@T.prim_func
def main(
x: T.Tensor((128,), "float32"),
y: T.Tensor((128,), "float32"),
):
# One-dimensional kernel; writes y from x without modifying x
with T.Kernel(128, threads=32) as pid:
y[pid] = x[pid] + 1.0
return main
def test_codegen_emits_const_for_readonly_params():
# Lower without device compilation to retrieve CUDA source reliably
func = _simple_add_kernel()
artifact = lower(func, target="cuda", enable_device_compile=False)
src = artifact.kernel_source
print(src)
assert 'extern "C" __global__' in src
# Extract kernel signature and check qualifiers
lparen = match_declare_kernel(src)
rparen = src.find(")", lparen)
assert rparen != -1
signature = src[lparen:rparen]
# x is read-only: should be `const` and `__restrict__`
assert "const float* __restrict__" in signature
# y is written: must not be const, but still `__restrict__` due to noalias
# We ensure there is a non-const float* parameter with __restrict__ as well
assert "const float* __restrict__ x" in src or "const float *__restrict__ x" in src
assert " float* __restrict__ y" in src or " float *__restrict__ y" in src
# Also validate the function attribute carries read-only param indices
# Expect only the first handle parameter (x) to be marked read-only
device_mod = artifact.device_mod
prim_funcs = [f for f in device_mod.functions.values() if hasattr(f, "attrs")]
assert prim_funcs, "No PrimFunc found in device module"
pf = prim_funcs[0]
ro = pf.attrs.get("tl.readonly_param_indices")
assert ro is not None, "Expected tl.readonly_param_indices to be present"
ro_list = [int(i) for i in ro]
assert 0 in ro_list and 1 not in ro_list
if __name__ == "__main__":
test_codegen_emits_const_for_readonly_params()
......@@ -250,6 +250,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.ThreadSync("global")(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tilelang.transform.SplitHostDevice()(mod)
mod = tilelang.transform.AnnotateReadOnlyParams()(mod)
# MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
......
......@@ -303,6 +303,21 @@ def SplitHostDevice():
return _ffi_api.SplitHostDevice() # type: ignore
def AnnotateReadOnlyParams():
"""Annotate read-only handle parameters for PrimFuncs.
Adds attribute `tl.readonly_param_indices` listing param indices that are
never written, enabling CUDA codegen to emit `const` qualifiers to unlock
read-only cache loads.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateReadOnlyParams() # type: ignore
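# Usage sketch (mirrors the call added to OptimizeForTarget above; the pass is
# applied after SplitHostDevice so each device PrimFunc carries the attribute):
#   mod = tilelang.transform.AnnotateReadOnlyParams()(mod)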
def VectorizeLoop(enable_vectorize: bool = True):
"""VectorizeLoop
......