"vscode:/vscode.git/clone" did not exist on "e9eaa00dcdfe2b4649016af8506c2ae7858432bd"
Unverified Commit bddb125e authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Support tilelang `alloc_var(dtype, init=x)` (#1092)

* - carry existing local-var initializer map into OpaqueBlockLower, reattach it to
    generated Allocates and the PrimFunc attrs
  - thread the map through FlattenBuffer and StorageRewrite so flattened/merged
    allocations keep their tl.local_var_init annotations
  - teach annotation handling to accept scalar initializers, resolve buffers, and merge
    with existing state

* lint fix

* enhance

* lint fix

* lint fix
parent cdc67fc4
......@@ -27,6 +27,7 @@ static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
static constexpr const char *kCustomWarpSpecialization =
"kCustomWarpSpecialization";
static constexpr const char *kLocalVarInit = "tl.local_var_init";
} // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations =
......
......@@ -2201,8 +2201,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
} else if (scope == "local") {
stream << ' ' << vid << '[' << constant_size << "];\n";
} else if (scope == "local.var") {
stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
<< ";\n";
PrimExpr init = tir::make_const(op->dtype, 0);
auto init_it = op->annotations.find(tl::attr::kLocalVarInit);
if (init_it != op->annotations.end()) {
PrimExpr user_init = Downcast<PrimExpr>((*init_it).second);
if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) {
user_init = tir::Cast(op->dtype, user_init);
}
init = user_init;
}
stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
} else if (scope != "local.descriptor") {
ICHECK(false) << "Unsupported scope: " << scope;
}
......
......@@ -25,6 +25,7 @@
#include "tir/transforms/ir_utils.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/stmt_functor.h>
......@@ -32,6 +33,8 @@
#include <utility>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
......@@ -46,6 +49,10 @@ public:
static PrimFunc Flatten(PrimFunc func) {
arith::Analyzer ana;
auto pass = BufferFlattener(&ana);
if (auto init_map =
func->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
pass.local_var_init_map_ = init_map.value();
}
auto writer = func.CopyOnWrite();
pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
......@@ -198,6 +205,13 @@ private:
if (!new_extents.same_as(alloc->extents)) {
alloc.CopyOnWrite()->extents = new_extents;
}
if (!local_var_init_map_.empty()) {
auto init_it = local_var_init_map_.find(alloc->buffer_var);
if (init_it != local_var_init_map_.end()) {
const PrimExpr &init = (*init_it).second;
alloc.CopyOnWrite()->annotations.Set(tl::attr::kLocalVarInit, init);
}
}
return std::move(alloc);
}
......@@ -354,6 +368,9 @@ private:
/*! \brief The updated external buffer map. */
Map<Var, Buffer> updated_extern_buffer_map_;
/*! \brief Local var initializers preserved from block annotations. */
Map<Var, PrimExpr> local_var_init_map_;
};
PrimFunc FlattenBufferRewriter(PrimFunc f) {
......
......@@ -22,11 +22,14 @@
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <string>
#include <utility>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
......@@ -39,10 +42,20 @@ using namespace tir::attr;
*/
class OpaqueBlockLower : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt body) {
static PrimFunc Rewrite(PrimFunc f) {
auto fptr = f.CopyOnWrite();
OpaqueBlockLower lower;
lower.storage_align_ = CollectStorageAlignAnnotation(body);
return lower(std::move(body));
if (auto existing =
fptr->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
lower.local_var_init_map_ = existing.value();
}
lower.storage_align_ = CollectStorageAlignAnnotation(fptr->body);
fptr->body = lower(std::move(fptr->body));
if (!lower.local_var_init_map_.empty()) {
f = WithAttr(std::move(f), tl::attr::kLocalVarInit,
lower.local_var_init_map_);
}
return f;
}
private:
......@@ -59,7 +72,13 @@ private:
if (!is_one(predicate)) {
body = IfThenElse(predicate, std::move(body));
}
// Step 3. Handle allocations in reverse order
// Step 3. Handle annotations, block annotations are not preserved by
// default.
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true,
new_block->alloc_buffers);
// Step 4. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
const Buffer &buffer = new_block->alloc_buffers[i - 1];
Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
......@@ -74,14 +93,15 @@ private:
}
allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
}
auto init_it = local_var_init_map_.find(buffer->data);
if (init_it != local_var_init_map_.end()) {
const PrimExpr &init = (*init_it).second;
allocate_annotations.Set(tl::attr::kLocalVarInit, init);
}
body = Allocate(buffer->data, buffer->dtype, allocation_shape,
const_true(), std::move(body), allocate_annotations);
}
// Step 4. Handle annotations, block annotations are not preserved by
// default.
std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
// Step 5. Insert attribute statements converted from pragmas
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
}
......@@ -188,13 +208,34 @@ private:
Map<String, ffi::Any>
HandleAnnotations(const Map<String, ffi::Any> &annotations,
std::vector<std::pair<std::string, PrimExpr>> *pragma_attrs,
bool is_block) {
bool is_block,
const Array<Buffer> &alloc_buffers = Array<Buffer>()) {
Map<String, ffi::Any> preserved_annotations;
pragma_attrs->clear();
for (const auto &kv : annotations) {
const String &key = kv.first;
if (tir::attr::IsPragmaKey(key)) {
pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
} else if (key == tl::attr::kLocalVarInit) {
if (auto local_init_map = kv.second.try_cast<Map<Var, PrimExpr>>()) {
for (const auto &pair : local_init_map.value()) {
local_var_init_map_.Set(pair.first, pair.second);
}
} else if (auto init_expr = kv.second.try_cast<PrimExpr>()) {
ICHECK(is_block) << "`" << tl::attr::kLocalVarInit
<< "` on non-block annotations is not supported";
Buffer target = ResolveLocalVarBuffer(alloc_buffers);
if (!target.defined()) {
LOG(WARNING) << "Failed to resolve buffer for `"
<< tl::attr::kLocalVarInit << "` annotation";
continue;
}
local_var_init_map_.Set(target->data, init_expr.value());
} else {
LOG(FATAL) << "Expected `" << tl::attr::kLocalVarInit
<< "` to be a PrimExpr or Map<Var, PrimExpr>, but got "
<< kv.second.GetTypeKey();
}
} else if (!is_block) {
// the loop annotation is preserved
preserved_annotations.Set(key, kv.second);
......@@ -206,6 +247,19 @@ private:
return preserved_annotations;
}
/*!
 * \brief Pick the buffer a scalar `tl.local_var_init` annotation refers to.
 *
 * Prefers the first allocated buffer whose storage scope contains
 * "local.var"; otherwise falls back to the most recently declared buffer.
 *
 * \param alloc_buffers Buffers allocated by the enclosing block.
 * \return The resolved buffer, or an undefined Buffer when none exists.
 */
Buffer ResolveLocalVarBuffer(const Array<Buffer> &alloc_buffers) const {
  for (const Buffer &buf : alloc_buffers) {
    const std::string buf_scope = buf.scope();
    if (buf_scope.find("local.var") != std::string::npos) {
      return buf;
    }
  }
  // No local.var buffer found; fall back to the last allocation, if any.
  return alloc_buffers.empty() ? Buffer() : alloc_buffers.back();
}
/*! \brief Record the loop_var and loop start value of unit loops, whose
* extent is one. */
std::unordered_map<Var, PrimExpr> unit_loop_vars_;
......@@ -215,12 +269,13 @@ private:
/*! \brief The map from buffer var to its storage alignment information. */
std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
/*! \brief Local var initializers collected from block annotations. */
Map<Var, PrimExpr> local_var_init_map_;
};
/*!
 * \brief Entry point: lower opaque blocks of a PrimFunc.
 *
 * Delegates entirely to OpaqueBlockLower::Rewrite, which now owns the
 * CopyOnWrite handling and the `tl.local_var_init` attribute plumbing.
 *
 * \param f The function to lower.
 * \return The lowered function.
 */
PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
  return OpaqueBlockLower::Rewrite(std::move(f));
}
tir::transform::Pass LowerOpaqueBlock() {
......
......@@ -25,6 +25,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/type.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/analysis.h>
......@@ -468,8 +469,10 @@ public:
using AllocEntry = LinearAccessPatternFinder::AllocEntry;
Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
bool reuse_require_exact_matched_dtype) {
bool reuse_require_exact_matched_dtype,
Map<Var, PrimExpr> local_var_init_map = {}) {
detect_inplace_ = detect_inplace;
local_var_init_map_ = std::move(local_var_init_map);
// plan the rewrite
LinearAccessPatternFinder finder;
finder(stmt);
......@@ -694,6 +697,17 @@ private:
}
return body;
}
/*!
 * \brief Build the annotation map for a rewritten Allocate node.
 *
 * When the buffer variable has a recorded `tl.local_var_init` initializer,
 * the returned map carries it so codegen can emit the custom init value.
 *
 * \param buffer_var The allocation's buffer variable.
 * \return Annotations to attach to the new Allocate (possibly empty).
 */
Map<String, ffi::Any> MakeAllocateAnnotations(const Var &buffer_var) const {
  Map<String, ffi::Any> annotations;
  if (!local_var_init_map_.defined()) {
    return annotations;
  }
  auto it = local_var_init_map_.find(buffer_var);
  if (it != local_var_init_map_.end()) {
    annotations.Set(tl::attr::kLocalVarInit, (*it).second);
  }
  return annotations;
}
// Remap the index
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) {
if (e->bits_offset == 0)
......@@ -766,9 +780,11 @@ private:
if (all_allocs_identical) {
// simply use the original allocation.
e->alloc_nest.push_back(
Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0)));
Map<String, ffi::Any> annotations =
MakeAllocateAnnotations(e->alloc_var);
e->alloc_nest.push_back(Allocate(
e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0), std::move(annotations)));
if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
e->alloc_nest.push_back(DeclBuffer(
RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0)));
......@@ -824,9 +840,11 @@ private:
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type,
{combo_size}, const_true(),
Evaluate(0)));
Map<String, ffi::Any> annotations =
MakeAllocateAnnotations(e->alloc_var);
e->alloc_nest.push_back(
Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate(0), std::move(annotations)));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
......@@ -875,8 +893,10 @@ private:
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
(total_bits + type_bits - 1) / type_bits);
Map<String, ffi::Any> annotations = MakeAllocateAnnotations(e->alloc_var);
e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size},
const_true(), Evaluate(0)));
const_true(), Evaluate(0),
std::move(annotations)));
if (info.defined()) {
ICHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
......@@ -1178,6 +1198,8 @@ private:
// Any buffers that is accessed at some point. DeclBuffer instances
// that do not appear in this list may be removed.
std::unordered_set<const BufferNode *> all_buffers_accessed_;
// Initial values for local variable buffers.
Map<Var, PrimExpr> local_var_init_map_;
// analyzer
arith::Analyzer analyzer_;
};
......@@ -1795,7 +1817,7 @@ public:
DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
op->condition, op->body, op->annotations);
}
Stmt VisitStmt_(const AllocateConstNode *op) final {
......@@ -1941,10 +1963,16 @@ Pass StorageRewrite() {
// Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU
reuse_require_exact_matched_dtype = true;
}
Map<Var, PrimExpr> local_var_init_map;
if (auto init_map =
f->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
local_var_init_map = init_map.value();
}
auto *n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace,
enable_reuse,
reuse_require_exact_matched_dtype);
StoragePlanRewriter plan_rewriter;
n->body = plan_rewriter.Rewrite(
std::move(n->body), detect_inplace, enable_reuse,
reuse_require_exact_matched_dtype, std::move(local_var_init_map));
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
......
......@@ -81,5 +81,85 @@ def test_alloc_var_add():
run_alloc_var_add(1024, 128, "float16")
def alloc_var_with_initializer(
    N,
    block_N,
    dtype,
    init_value,
):
    """Build a kernel whose local var is created via ``T.alloc_var(dtype, init)``."""
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            # Scalar variable initialized from the user-provided value.
            tmp = T.alloc_var(dtype, init_value)
            T.copy(A[bx * block_N], B[bx * block_N])
            for idx in T.Parallel(block_N):
                B[bx * block_N + idx] = tmp

    return main
def run_alloc_var_with_initializer(
    N,
    block_N,
    dtype,
    init_value,
):
    """Compile the single-var kernel and check the emitted initializer."""
    program = alloc_var_with_initializer(N, block_N, dtype, init_value)
    kernel = tilelang.compile(program, out_idx=[1])
    source = kernel.get_kernel_source()
    print(source)
    # The generated variable must be assigned the initializer, not zero.
    assert f"= {init_value};" in source
def test_alloc_var_with_initializer():
    # Smoke-test: an int32 local var initialized to 5 survives to codegen.
    run_alloc_var_with_initializer(256, 64, "int32", 5)
def alloc_multi_vars_with_initializer(
    N,
    block_N,
    dtype,
):
    """Build a kernel with two local vars carrying distinct initializers."""
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            # Two scalar variables, each with its own initializer.
            tmp0 = T.alloc_var(dtype, 1)
            tmp1 = T.alloc_var(dtype, 2)
            T.copy(A[bx * block_N], B[bx * block_N])
            for idx in T.Parallel(block_N):
                B[bx * block_N + idx] = tmp0 + tmp1

    return main
def run_alloc_multi_vars_with_initializer(
    N,
    block_N,
    dtype,
):
    """Compile the two-var kernel and ensure each initializer is emitted once."""
    program = alloc_multi_vars_with_initializer(N, block_N, dtype)
    kernel = tilelang.compile(program, out_idx=[1])
    source = kernel.get_kernel_source()
    print(source)
    # Each variable keeps its own initializer; neither is duplicated or merged.
    assert source.count("= 1;") == 1
    assert source.count("= 2;") == 1
def test_alloc_multi_vars_with_initializer():
    # Smoke-test: two int32 local vars with different initializers.
    run_alloc_multi_vars_with_initializer(256, 64, "int32")
if __name__ == "__main__":
    # Run all test_* functions in this module through tilelang's test runner.
    tilelang.testing.main()
......@@ -16,6 +16,9 @@ with the appropriate memory scope.
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr
from typing import Union
def alloc_shared(shape, dtype, scope="shared.dyn"):
......@@ -64,17 +67,54 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr, None] = None):
    """Allocate a single-element variable buffer.

    Args:
        dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
        *args: Optional positional arguments. A single positional string is treated
            as the scope for backward compatibility. A single non-string positional
            argument (or keyword ``init``) specifies the initializer. When two
            positional arguments are provided, they are interpreted as
            ``(init, scope)``.
        scope (str, optional): The memory scope. Defaults to "local.var".
            Use as keyword argument for clarity when also providing an initializer.
        init (PrimExpr, optional): The optional initializer value. When provided,
            the generated code will initialize the variable with this value instead
            of defaulting to zero.

    Returns:
        T.Buffer: A TVM buffer object allocated as a single-element variable

    Raises:
        TypeError: If the initializer is given both positionally and by keyword,
            if the scope is not a string, or if too many positional arguments
            are supplied.
    """
    parsed_scope = scope
    parsed_init = init
    if len(args) == 1:
        arg = args[0]
        if isinstance(arg, str) and parsed_init is None and scope == "local.var":
            # Backward compatibility: alloc_var(dtype, "some.scope").
            parsed_scope = arg
        else:
            if parsed_init is not None:
                raise TypeError("Initializer specified multiple times in alloc_var.")
            parsed_init = arg
    elif len(args) == 2:
        # Two positionals are interpreted as (init, scope).
        if parsed_init is not None:
            raise TypeError("Initializer specified multiple times in alloc_var.")
        parsed_init, parsed_scope_arg = args
        if not isinstance(parsed_scope_arg, str):
            raise TypeError("Scope must be provided as a string in alloc_var.")
        parsed_scope = parsed_scope_arg
    elif len(args) > 2:
        raise TypeError(
            f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
    if not isinstance(parsed_scope, str):
        raise TypeError("Scope must be a string in alloc_var.")
    buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
    if parsed_init is not None:
        # Record the initializer on the block so codegen emits `var = init;`
        # instead of the default zero initialization.
        block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
    return buffer
def alloc_barrier(arrive_count: int):
......@@ -141,7 +181,6 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
Returns:
T.Buffer: A TVM buffer object allocated in thread-private storage, available to reduce values in T.Parallel loops.
"""
import tilelang.language as TL
assert op in ["sum", "max", "min"]
# TODO: support automatic layout
......@@ -150,7 +189,7 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
assert replication in ["all", "none"]
reducer = T.alloc_buffer(shape, dtype, scope="local.fragment")
TL.block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}})
block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}})
return reducer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment