Unverified Commit 0814b171 authored by Lei Wang, committed by GitHub

[Language] Introduce `T.annotate_restrict_buffers` (#1428)

* [Enhancement] Introduce non-restrict parameter support in code generation

- Added a new PrimFunc-level attribute `tl.non_restrict_params` to specify handle Vars that should not be marked with the restrict qualifier during code generation.
- Updated `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to handle non-restrict parameters, ensuring proper treatment of overlapping buffer aliases.
- Implemented a new annotation function `annotate_restrict_buffers` that marks buffer parameters as non-restrict (see the usage sketch below).
- Enhanced the `SplitHostDevice` transformation to propagate non-restrict parameters from host to device functions.
- Added a new transform, `HoistNonRestrictParams`, to hoist block-level non-restrict annotations up to the PrimFunc attribute.
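
As a usage illustration, here is a minimal sketch that mirrors the CUDA test added in this PR (the kernel and variable names such as `add_one` are illustrative, not part of the change):

import tilelang
import tilelang.language as T

N = 128

@T.prim_func
def add_one(x: T.Tensor((N,), T.float32), y: T.Tensor((N,), T.float32)):
    with T.Kernel(N, threads=32) as pid:
        # x and y may alias, so opt them out of the restrict qualifier.
        T.annotate_restrict_buffers(x, y)
        y[pid] = x[pid] + 1.0

artifact = tilelang.lower(add_one, target="cuda")
# Without the annotation both kernel pointers carry __restrict__; with it,
# neither parameter is restrict-qualified in the generated signature.
print(artifact.kernel_source)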

* [Enhancement] Improve HoistNonRestrictParams transformation

- Updated the HoistNonRestrictParams function to recursively collect all `tl.non_restrict_params` annotations from nested blocks, so the annotation can be placed anywhere in the function body (see the sketch after this list).
- Introduced a new NonRestrictCollector class to manage the collection and deduplication of non-restrict parameters.
- Modified the SplitHostDevice transformation to remove the non-restrict attribute from the host-side PrimFunc after propagation to device kernels.
- Adjusted the LowerAndLegalize function to apply the HoistNonRestrictParams transformation directly, without wrapping it in exception handling.
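
To see where the annotation ends up after hoisting, here is a hedged inspection sketch. Running the pass standalone on the freshly parsed kernel, rather than through the full LowerAndLegalize pipeline, is an assumption made purely for illustration; `add_one` is the kernel from the sketch above.

import tvm
import tilelang.transform

# Wrap the annotated kernel in a module and apply only the hoisting pass.
mod = tvm.IRModule({"main": add_one})
mod = tilelang.transform.HoistNonRestrictParams()(mod)

# The printed PrimFunc should now carry the "tl.non_restrict_params" attribute
# listing the handle Vars of x and y. SplitHostDevice later copies it onto the
# device kernel and drops it from the host-side function.
print(mod["main"])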

* [Refactor] Simplify non-restrict parameter handling in code generation

- Removed unnecessary normalization logic and associated data structures from `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP`.
- Streamlined the handling of non-restrict parameters by directly inserting them into the `non_restrict` set, improving code clarity and maintainability.
- Updated the conditional checks to drop the now-redundant comparisons against normalized names, improving readability.

* [Dependency] Update TVM subproject to latest commit 68aa8461

- Updated the TVM subproject to the latest commit, ensuring compatibility with recent changes and improvements.
- Refactored non-restrict parameter handling in `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to enhance code clarity and maintainability.
- Adjusted the `SplitHostDevice` transformation to streamline the propagation of non-restrict parameters.

* fix
parent f4f87f46
@@ -28,6 +28,10 @@ static constexpr const char *kWarpSpecializationScope =
static constexpr const char *kCustomWarpSpecialization =
"kCustomWarpSpecialization";
static constexpr const char *kLocalVarInit = "tl.local_var_init";
// A PrimFunc-level attribute carrying a list of handle Vars
// that must NOT be marked with the restrict qualifier in codegen.
// Type: Array<tir::Var>
static constexpr const char *kNonRestrictParams = "tl.non_restrict_params";
} // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations =
...
@@ -29,6 +29,7 @@
#include <unordered_set>
#include <utility>
#include "../op/builtin.h"
#include "../support/ffi_aliases.h"
#include "support/str_escape.h"
#include "target/build_common.h"
@@ -260,6 +261,12 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
@@ -294,7 +301,7 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
}
}
-if (no_alias) {
+if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
...
@@ -3418,6 +3418,12 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
CodeGenC::PrintType(func->ret_type, os);
CodeGenC::PrintExtraAttrs(func, os);
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
func->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
// Read-only param indices attribute, if present.
std::unordered_set<int> ro_param_indices;
if (auto opt =
@@ -3461,7 +3467,7 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
}
}
-if (no_alias) {
+if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, os);
}
} else {
@@ -3497,6 +3503,12 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
// Read-only param indices attribute, if present.
std::unordered_set<int> ro_param_indices;
if (auto opt = f->GetAttr<ffi::Array<Integer>>("tl.readonly_param_indices")) {
@@ -3542,7 +3554,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
}
}
-if (no_alias) {
+if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
...
@@ -1322,6 +1322,12 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
ICHECK(global_symbol.has_value())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
@@ -1356,7 +1362,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
}
}
-if (no_alias) {
+if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
...
/*
 * Hoist tl.non_restrict_params block annotation(s) to a PrimFunc attribute.
 *
 * An earlier revision only looked at the root block. This version recursively
 * scans all blocks, unions any tl.non_restrict_params entries it finds,
 * merges them with any existing PrimFunc-level attribute, and writes the
 * deduplicated result back to the PrimFunc attrs. This makes annotation
 * placement within the function body flexible for frontends.
 */
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tvm::tir;
class NonRestrictCollector : public StmtVisitor {
public:
void Collect(const Stmt &stmt) { VisitStmt(stmt); }
Array<Var> Result() const {
Array<Var> out;
out.reserve(collected_.size());
for (const Var &v : collected_)
out.push_back(v);
return out;
}
private:
static std::string NormalizeName(const std::string &s) {
if (s.size() >= 8 && s.rfind("_handle") == s.size() - 7) {
return s.substr(0, s.size() - 7);
}
return s;
}
void MaybeInsert(const Var &v) {
if (!v.defined())
return;
const VarNode *p = v.get();
if (seen_ptr_.count(p))
return;
// Also dedup by normalized name to be robust w.r.t recreated Vars
std::string norm = NormalizeName(v->name_hint);
if (seen_name_.count(norm))
return;
seen_ptr_.insert(p);
seen_name_.insert(std::move(norm));
collected_.push_back(v);
}
void VisitStmt_(const BlockNode *op) final {
auto it = op->annotations.find(attr::kNonRestrictParams);
if (it != op->annotations.end()) {
if ((*it).second.as<ffi::ArrayObj>()) {
// Downcast directly to Array<Var> for convenience
Array<Var> vars = tvm::Downcast<Array<Var>>((*it).second);
for (const Var &v : vars) {
MaybeInsert(v);
}
}
}
// Recurse into child statements
StmtVisitor::VisitStmt_(op);
}
std::vector<Var> collected_;
std::unordered_set<const VarNode *> seen_ptr_;
std::unordered_set<std::string> seen_name_;
};
static PrimFunc HoistNonRestrictParams(PrimFunc f) {
if (!f.defined())
return f;
NonRestrictCollector collector;
collector.Collect(f->body);
Array<Var> from_blocks = collector.Result();
// Merge with any existing PrimFunc-level attribute if present
if (auto opt_existing = f->GetAttr<Array<Var>>(attr::kNonRestrictParams)) {
for (const Var &v : opt_existing.value()) {
// Rather than constructing another collector just to reuse its dedup logic,
// inline a simplified pointer-based dedup with a name-based fallback,
// mirroring MaybeInsert.
bool exists = false;
for (const Var &cur : from_blocks) {
if (cur.get() == v.get() || cur->name_hint == v->name_hint) {
exists = true;
break;
}
}
if (!exists)
from_blocks.push_back(v);
}
}
if (from_blocks.empty())
return f;
return WithAttr(std::move(f), attr::kNonRestrictParams,
std::move(from_blocks));
}
namespace transform {
tvm::transform::Pass HoistNonRestrictParams() {
auto pass_func = [](PrimFunc f, const IRModule &,
const tvm::transform::PassContext &) {
return tvm::tl::HoistNonRestrictParams(std::move(f));
};
return tvm::tir::transform::CreatePrimFuncPass(
pass_func, 0, "tl.HoistNonRestrictParams", {});
}
} // namespace transform
} // namespace tl
} // namespace tvm
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.HoistNonRestrictParams",
tvm::tl::transform::HoistNonRestrictParams);
}
@@ -33,6 +33,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "common/assume.h"
#include "tir/analysis/var_use_def_analysis.h"
#include "tvm/node/cast.h"
@@ -57,6 +58,12 @@ public:
std::function<GlobalVar()> var_supply)
: device_mod_(device_mod), var_supply_(std::move(var_supply)) {}
void SetNonRestrictParams(Optional<Array<tir::Var>> params) {
for (auto param : params.value()) {
non_restrict_params_.push_back(param);
}
}
tir::Stmt VisitStmt_(const tir::AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) {
found_device_region_ = true;
@@ -93,6 +100,7 @@ public:
private:
bool found_device_region_{false};
Array<tir::Var> non_restrict_params_;
Stmt wrapBodyWithHostSideAssumes(Stmt body) {
for (auto it = host_assumes_.rbegin(); it != host_assumes_.rend(); ++it) {
@@ -103,6 +111,7 @@
}
tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
auto [params, buffers_to_declare] =
[&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> {
tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{},
@@ -152,9 +161,11 @@
tir::PrimFunc device_func(params, body, kernel_ret_type);
device_func =
-WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
+WithAttrs(std::move(device_func),
+{{tvm::attr::kTarget, device_target},
{tir::attr::kNoAlias, true},
-{tir::attr::kIsGlobalFunc, true}});
+{tir::attr::kIsGlobalFunc, true},
+{tl::attr::kNonRestrictParams, non_restrict_params_}});
GlobalVar kernel_symbol_global = var_supply_();
(*device_mod_)->Add(kernel_symbol_global, device_func);
@@ -188,6 +199,13 @@ private:
tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod,
std::function<GlobalVar()> var_supply) {
HostDeviceSplitter splitter(device_mod, std::move(var_supply));
// Propagate non-restrict parameter list from host func to device kernels
if (auto opt = func->GetAttr<Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
splitter.SetNonRestrictParams(opt.value());
// Remove the attribute from host-side PrimFunc; it only matters for device
// codegen.
func = tvm::WithoutAttr(std::move(func), tl::attr::kNonRestrictParams);
}
if (auto body = splitter(func->body); !body.same_as(func->body)) {
func.CopyOnWrite()->body = body;
@@ -204,7 +222,6 @@ tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod,
}
}
}
return func;
}
@@ -235,7 +252,6 @@ tvm::transform::Pass SplitHostDevice() {
}
}
}
mod->Update(updates);
mod->Update(device_mod);
return tir::transform::ConvertSSA()(mod);
...
import tilelang
import tilelang.language as T
import tilelang.testing
def _get_sig_line(code: str) -> str:
# Find the kernel signature line in generated CUDA code
for line in code.splitlines():
line = line.strip()
if line.startswith('extern "C" __global__ void'):
return line
raise AssertionError("Kernel signature not found in generated code")
@tilelang.testing.requires_cuda
def test_cuda_restrict_default_has_restrict():
N = 128
@T.prim_func
def kernel(x: T.Tensor((N,), T.float32), y: T.Tensor((N,), T.float32)):
with T.Kernel(N, threads=32) as pid:
y[pid] = x[pid] + 1.0
artifact = tilelang.lower(kernel, target="cuda")
sig = _get_sig_line(artifact.kernel_source)
# By default, kNoAlias is set and both pointers are restrict-qualified
assert "__restrict__" in sig
@tilelang.testing.requires_cuda
def test_cuda_restrict_annotation_removes_restrict():
N = 128
@T.prim_func
def kernel_body_annot(x: T.Tensor((N,), T.float32), y: T.Tensor((N,), T.float32)):
# Explicitly mark buffers that may alias as non-restrict
with T.Kernel(N, threads=32) as pid:
T.annotate_restrict_buffers(x, y)
y[pid] = x[pid] + 1.0
art1 = tilelang.lower(kernel_body_annot, target="cuda")
sig1 = _get_sig_line(art1.kernel_source)
# No parameter should be emitted with __restrict__
assert "__restrict__" not in sig1
if __name__ == "__main__":
tilelang.testing.main()
@@ -175,6 +175,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# TODO(lei): return to tir pass when kSymbolicBound simplification
# is merged into tvm.
mod = tilelang.transform.Simplify()(mod)
# Hoist tl.non_restrict_params block annotations to PrimFunc attrs.
mod = tilelang.transform.HoistNonRestrictParams()(mod)
return mod
...
@@ -108,6 +108,7 @@ from .annotations import ( # noqa: F401
annotate_layout,
annotate_safe_value,
annotate_l2_hit_ratio,
annotate_restrict_buffers,
)
...
@@ -11,6 +11,7 @@ __all__ = [
"annotate_layout",
"annotate_safe_value",
"annotate_l2_hit_ratio",
"annotate_restrict_buffers",
]
@@ -51,3 +52,31 @@ def annotate_l2_hit_ratio(l2_hit_ratio_map: dict):
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = FloatImm("float32", float(hit_ratio))
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
def annotate_restrict_buffers(*buffers):
"""Mark the given buffer parameters as non-restrict.
This annotation tells codegen to omit the `__restrict__` qualifier for the
specified kernel buffer parameters. Use this when two (or more) buffers may
alias, for example overlapping slices from the same base tensor.
Example
-------
>>> @T.prim_func
... def aliasing_kernel(x: T.Tensor((N,), T.float32),
... y: T.Tensor((N,), T.float32)):
... T.annotate_restrict_buffers(x, y)
... with T.Kernel(N, threads=32) as pid:
... y[pid] = x[pid] + 1
"""
if not buffers:
return None
data_vars = []
for buf in buffers:
try:
data_vars.append(buf.data)
except Exception as e:
raise TypeError(f"annotate_restrict_buffers expects Buffer arguments, got {type(buf)}") from e
# Attach as a block attribute (a root block exists by default); the
# HoistNonRestrictParams pass later promotes it to the PrimFunc-level attribute.
return block_attr({"tl.non_restrict_params": data_vars})
@@ -435,6 +435,10 @@ def PlanAndUpdateBufferAllocationLocation():
return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore
def HoistNonRestrictParams():
"""Hoist tl.non_restrict_params block annotations to a PrimFunc attribute."""
return _ffi_api.HoistNonRestrictParams() # type: ignore
def StorageRewrite():
"""StorageRewrite
...