Unverified Commit bddb125e authored by Lei Wang, committed by GitHub

[Language] Support tilelang `alloc_var(dtype, init=x)` (#1092)

* - carry the existing local-var initializer map into OpaqueBlockLower, reattach it to
    the generated Allocates and the PrimFunc attrs
  - thread the map through FlattenBuffer and StorageRewrite so flattened/merged
    allocations keep their tl.local_var_init annotations
  - teach annotation handling to accept scalar initializers, resolve buffers, and merge
    with existing state

* lint fix

* enhance

* lint fix

* lint fix
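
For context, a minimal usage sketch of the new initializer support, assembled from the tests and the updated `alloc_var` signature in this diff; the shapes, dtype, and the element-wise kernel body are illustrative, not taken verbatim from the PR.

import tilelang
import tilelang.language as T


@T.prim_func
def main(A: T.Tensor((256,), "int32"), B: T.Tensor((256,), "int32")):
    with T.Kernel(T.ceildiv(256, 64), threads=64) as bx:
        # Positional initializer: the generated local.var starts at 5
        # instead of the default zero. The keyword form is equivalent:
        # tmp = T.alloc_var("int32", init=5)
        tmp = T.alloc_var("int32", 5)
        for i in T.Parallel(64):
            B[bx * 64 + i] = A[bx * 64 + i] + tmp


kernel = tilelang.compile(main, out_idx=[1])
# The emitted CUDA should contain an "= 5;" initializer for the local
# variable; this is what the new tests in this PR assert.
print(kernel.get_kernel_source())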
parent cdc67fc4
@@ -27,6 +27,7 @@ static constexpr const char *kWarpSpecializationScope =
     "kWarpSpecializationScope";
 static constexpr const char *kCustomWarpSpecialization =
     "kCustomWarpSpecialization";
+static constexpr const char *kLocalVarInit = "tl.local_var_init";
 } // namespace attr

 static constexpr const char *kDebugMergeSharedMemoryAllocations =
...
@@ -2201,8 +2201,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
   } else if (scope == "local") {
     stream << ' ' << vid << '[' << constant_size << "];\n";
   } else if (scope == "local.var") {
-    stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
-           << ";\n";
+    PrimExpr init = tir::make_const(op->dtype, 0);
+    auto init_it = op->annotations.find(tl::attr::kLocalVarInit);
+    if (init_it != op->annotations.end()) {
+      PrimExpr user_init = Downcast<PrimExpr>((*init_it).second);
+      if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) {
+        user_init = tir::Cast(op->dtype, user_init);
+      }
+      init = user_init;
+    }
+    stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
   } else if (scope != "local.descriptor") {
     ICHECK(false) << "Unsupported scope: " << scope;
   }
...
@@ -25,6 +25,7 @@
 #include "tir/transforms/ir_utils.h"
 #include <tvm/arith/iter_affine_map.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/data_type_rewriter.h>
 #include <tvm/tir/stmt_functor.h>
@@ -32,6 +33,8 @@
 #include <utility>

+#include "../op/builtin.h"
+
 namespace tvm {
 namespace tl {
@@ -46,6 +49,10 @@ public:
   static PrimFunc Flatten(PrimFunc func) {
     arith::Analyzer ana;
     auto pass = BufferFlattener(&ana);
+    if (auto init_map =
+            func->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
+      pass.local_var_init_map_ = init_map.value();
+    }
     auto writer = func.CopyOnWrite();
     pass.MarkBufferMapShapes(func);
     writer->body = pass.VisitStmt(func->body);
@@ -198,6 +205,13 @@ private:
     if (!new_extents.same_as(alloc->extents)) {
       alloc.CopyOnWrite()->extents = new_extents;
     }
+    if (!local_var_init_map_.empty()) {
+      auto init_it = local_var_init_map_.find(alloc->buffer_var);
+      if (init_it != local_var_init_map_.end()) {
+        const PrimExpr &init = (*init_it).second;
+        alloc.CopyOnWrite()->annotations.Set(tl::attr::kLocalVarInit, init);
+      }
+    }
     return std::move(alloc);
   }
@@ -354,6 +368,9 @@ private:
   /*! \brief The updated external buffer map. */
   Map<Var, Buffer> updated_extern_buffer_map_;
+
+  /*! \brief Local var initializers preserved from block annotations. */
+  Map<Var, PrimExpr> local_var_init_map_;
 };

 PrimFunc FlattenBufferRewriter(PrimFunc f) {
...
@@ -22,11 +22,14 @@
  */
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>

+#include <string>
 #include <utility>

+#include "../op/builtin.h"
 #include "tir/transforms/ir_utils.h"

 namespace tvm {
@@ -39,10 +42,20 @@ using namespace tir::attr;
  */
 class OpaqueBlockLower : public StmtExprMutator {
 public:
-  static Stmt Rewrite(Stmt body) {
+  static PrimFunc Rewrite(PrimFunc f) {
+    auto fptr = f.CopyOnWrite();
     OpaqueBlockLower lower;
-    lower.storage_align_ = CollectStorageAlignAnnotation(body);
-    return lower(std::move(body));
+    if (auto existing =
+            fptr->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
+      lower.local_var_init_map_ = existing.value();
+    }
+    lower.storage_align_ = CollectStorageAlignAnnotation(fptr->body);
+    fptr->body = lower(std::move(fptr->body));
+    if (!lower.local_var_init_map_.empty()) {
+      f = WithAttr(std::move(f), tl::attr::kLocalVarInit,
+                   lower.local_var_init_map_);
+    }
+    return f;
   }

private:
@@ -59,7 +72,13 @@ private:
     if (!is_one(predicate)) {
       body = IfThenElse(predicate, std::move(body));
     }
-    // Step 3. Handle allocations in reverse order
+    // Step 3. Handle annotations, block annotations are not preserved by
+    // default.
+    std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
+    HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true,
+                      new_block->alloc_buffers);
+    // Step 4. Handle allocations in reverse order
     for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
       const Buffer &buffer = new_block->alloc_buffers[i - 1];
       Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
@@ -74,14 +93,15 @@
       }
       allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
     }
+    auto init_it = local_var_init_map_.find(buffer->data);
+    if (init_it != local_var_init_map_.end()) {
+      const PrimExpr &init = (*init_it).second;
+      allocate_annotations.Set(tl::attr::kLocalVarInit, init);
+    }
     body = Allocate(buffer->data, buffer->dtype, allocation_shape,
                     const_true(), std::move(body), allocate_annotations);
   }
-  // Step 4. Handle annotations, block annotations are not preserved by
-  // default.
-  std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
-  HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true);
+  // Step 5. Insert attribute statements converted from pragmas
   for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
     body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
   }
@@ -188,13 +208,34 @@ private:
   Map<String, ffi::Any>
   HandleAnnotations(const Map<String, ffi::Any> &annotations,
                     std::vector<std::pair<std::string, PrimExpr>> *pragma_attrs,
-                    bool is_block) {
+                    bool is_block,
+                    const Array<Buffer> &alloc_buffers = Array<Buffer>()) {
     Map<String, ffi::Any> preserved_annotations;
     pragma_attrs->clear();
     for (const auto &kv : annotations) {
       const String &key = kv.first;
       if (tir::attr::IsPragmaKey(key)) {
         pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
+      } else if (key == tl::attr::kLocalVarInit) {
+        if (auto local_init_map = kv.second.try_cast<Map<Var, PrimExpr>>()) {
+          for (const auto &pair : local_init_map.value()) {
+            local_var_init_map_.Set(pair.first, pair.second);
+          }
+        } else if (auto init_expr = kv.second.try_cast<PrimExpr>()) {
+          ICHECK(is_block) << "`" << tl::attr::kLocalVarInit
+                           << "` on non-block annotations is not supported";
+          Buffer target = ResolveLocalVarBuffer(alloc_buffers);
+          if (!target.defined()) {
+            LOG(WARNING) << "Failed to resolve buffer for `"
+                         << tl::attr::kLocalVarInit << "` annotation";
+            continue;
+          }
+          local_var_init_map_.Set(target->data, init_expr.value());
+        } else {
+          LOG(FATAL) << "Expected `" << tl::attr::kLocalVarInit
+                     << "` to be a PrimExpr or Map<Var, PrimExpr>, but got "
+                     << kv.second.GetTypeKey();
+        }
       } else if (!is_block) {
         // the loop annotation is preserved
         preserved_annotations.Set(key, kv.second);
@@ -206,6 +247,19 @@ private:
     return preserved_annotations;
   }

+  Buffer ResolveLocalVarBuffer(const Array<Buffer> &alloc_buffers) const {
+    for (const Buffer &buffer : alloc_buffers) {
+      std::string scope = buffer.scope();
+      if (scope.find("local.var") != std::string::npos) {
+        return buffer;
+      }
+    }
+    if (!alloc_buffers.empty()) {
+      return alloc_buffers.back();
+    }
+    return Buffer();
+  }
+
   /*! \brief Record the loop_var and loop start value of unit loops, whose
    * extent is one. */
   std::unordered_map<Var, PrimExpr> unit_loop_vars_;
@@ -215,12 +269,13 @@
   /*! \brief The map from buffer var to its storage alignment information. */
   std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
+
+  /*! \brief Local var initializers collected from block annotations. */
+  Map<Var, PrimExpr> local_var_init_map_;
 };

 PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
-  auto fptr = f.CopyOnWrite();
-  fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
-  return f;
+  return OpaqueBlockLower::Rewrite(std::move(f));
 }

 tir::transform::Pass LowerOpaqueBlock() {
...
@@ -25,6 +25,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir/type.h>
 #include <tvm/target/target_info.h>
 #include <tvm/tir/analysis.h>
@@ -468,8 +469,10 @@ public:
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;

   Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
-               bool reuse_require_exact_matched_dtype) {
+               bool reuse_require_exact_matched_dtype,
+               Map<Var, PrimExpr> local_var_init_map = {}) {
     detect_inplace_ = detect_inplace;
+    local_var_init_map_ = std::move(local_var_init_map);
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
@@ -694,6 +697,17 @@ private:
     }
     return body;
   }
+
+  Map<String, ffi::Any> MakeAllocateAnnotations(const Var &buffer_var) const {
+    Map<String, ffi::Any> annotations;
+    if (local_var_init_map_.defined()) {
+      auto it = local_var_init_map_.find(buffer_var);
+      if (it != local_var_init_map_.end()) {
+        const PrimExpr &init = (*it).second;
+        annotations.Set(tl::attr::kLocalVarInit, init);
+      }
+    }
+    return annotations;
+  }
   // Remap the index
   PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) {
     if (e->bits_offset == 0)
@@ -766,9 +780,11 @@ private:
       if (all_allocs_identical) {
         // simply use the original allocation.
-        e->alloc_nest.push_back(
-            Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
-                     e->allocs[0]->condition, Evaluate(0)));
+        Map<String, ffi::Any> annotations =
+            MakeAllocateAnnotations(e->alloc_var);
+        e->alloc_nest.push_back(Allocate(
+            e->alloc_var, alloc_type, e->allocs[0]->extents,
+            e->allocs[0]->condition, Evaluate(0), std::move(annotations)));
         if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
           e->alloc_nest.push_back(DeclBuffer(
               RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0)));
@@ -824,9 +840,11 @@ private:
           combo_size = combo_size + make_const(DataType::Int(32), 1);
         }
         combo_size = analyzer_.Simplify(combo_size);
-        e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type,
-                                         {combo_size}, const_true(),
-                                         Evaluate(0)));
+        Map<String, ffi::Any> annotations =
+            MakeAllocateAnnotations(e->alloc_var);
+        e->alloc_nest.push_back(
+            Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(),
+                     Evaluate(0), std::move(annotations)));
         if (IsSpecialTaggedMemory(e->scope)) {
           MemoryInfo info = GetMemoryInfo(e->scope.to_string());
           if (info.defined()) {
@@ -875,8 +893,10 @@ private:
     uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
     PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
                                      (total_bits + type_bits - 1) / type_bits);
+    Map<String, ffi::Any> annotations = MakeAllocateAnnotations(e->alloc_var);
    e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size},
-                                     const_true(), Evaluate(0)));
+                                     const_true(), Evaluate(0),
+                                     std::move(annotations)));
     if (info.defined()) {
       ICHECK_LE(total_bits, info->max_num_bits)
           << "Allocation exceed bound of memory tag " << e->scope.to_string();
@@ -1178,6 +1198,8 @@ private:
   // Any buffers that is accessed at some point. DeclBuffer instances
   // that do not appear in this list may be removed.
   std::unordered_set<const BufferNode *> all_buffers_accessed_;
+  // Initial values for local variable buffers.
+  Map<Var, PrimExpr> local_var_init_map_;
   // analyzer
   arith::Analyzer analyzer_;
 };
@@ -1795,7 +1817,7 @@ public:
     DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
                << info.new_element_dtype << " extents: " << extents;
     return Allocate(new_buffer_var, info.new_element_dtype, extents,
-                    op->condition, op->body);
+                    op->condition, op->body, op->annotations);
   }

   Stmt VisitStmt_(const AllocateConstNode *op) final {
@@ -1941,10 +1963,16 @@ Pass StorageRewrite() {
      // Require exactly same-dtype matching in smem reuse for Vulkan and WebGPU
      reuse_require_exact_matched_dtype = true;
    }
+    Map<Var, PrimExpr> local_var_init_map;
+    if (auto init_map =
+            f->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
+      local_var_init_map = init_map.value();
+    }
    auto *n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace,
-                                            enable_reuse,
-                                            reuse_require_exact_matched_dtype);
+    StoragePlanRewriter plan_rewriter;
+    n->body = plan_rewriter.Rewrite(
+        std::move(n->body), detect_inplace, enable_reuse,
+        reuse_require_exact_matched_dtype, std::move(local_var_init_map));
    // Parameters may not be rewritten, but internal allocations may.
    // Vectorization of AllocateConst is currently disabled, as it has
    // indexing issues for types that include padding (e.g. int8x3
...
@@ -81,5 +81,85 @@ def test_alloc_var_add():
     run_alloc_var_add(1024, 128, "float16")


+def alloc_var_with_initializer(
+    N,
+    block_N,
+    dtype,
+    init_value,
+):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
+            tmp = T.alloc_var(dtype, init_value)
+            T.copy(A[bx * block_N], B[bx * block_N])
+            for i in T.Parallel(block_N):
+                B[bx * block_N + i] = tmp
+
+    return main
+
+
+def run_alloc_var_with_initializer(
+    N,
+    block_N,
+    dtype,
+    init_value,
+):
+    program = alloc_var_with_initializer(N, block_N, dtype, init_value)
+    kernel = tilelang.compile(program, out_idx=[1])
+    code = kernel.get_kernel_source()
+    print(code)
+    assert f"= {init_value};" in code
+
+
+def test_alloc_var_with_initializer():
+    run_alloc_var_with_initializer(256, 64, "int32", 5)
+
+
+def alloc_multi_vars_with_initializer(
+    N,
+    block_N,
+    dtype,
+):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
+            tmp0 = T.alloc_var(dtype, 1)
+            tmp1 = T.alloc_var(dtype, 2)
+            T.copy(A[bx * block_N], B[bx * block_N])
+            for i in T.Parallel(block_N):
+                B[bx * block_N + i] = tmp0 + tmp1
+
+    return main
+
+
+def run_alloc_multi_vars_with_initializer(
+    N,
+    block_N,
+    dtype,
+):
+    program = alloc_multi_vars_with_initializer(N, block_N, dtype)
+    kernel = tilelang.compile(program, out_idx=[1])
+    code = kernel.get_kernel_source()
+    print(code)
+    assert code.count("= 1;") == 1
+    assert code.count("= 2;") == 1
+
+
+def test_alloc_multi_vars_with_initializer():
+    run_alloc_multi_vars_with_initializer(256, 64, "int32")
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
@@ -16,6 +16,9 @@ with the appropriate memory scope.
 from tilelang import tvm as tvm
 from tvm.script import tir as T
+from tvm.tir import PrimExpr
+from tvm.script.parser.tir import block_attr
+from typing import Union


 def alloc_shared(shape, dtype, scope="shared.dyn"):
@@ -64,17 +67,54 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
     return T.alloc_buffer(shape, dtype, scope=scope)


-def alloc_var(dtype, scope="local.var"):
+def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):
     """Allocate a single-element variable buffer.

     Args:
         dtype (str): The data type of the buffer (e.g., 'float32', 'int32')
-        scope (str, optional): The memory scope. Defaults to "local.var"
+        *args: Optional positional arguments. A single positional string is treated
+            as the scope for backward compatibility. A single non-string positional
+            argument (or keyword ``init``) specifies the initializer. When two
+            positional arguments are provided, they are interpreted as
+            ``(init, scope)``.
+        scope (str, optional): The memory scope. Defaults to "local.var".
+            Use as keyword argument for clarity when also providing an initializer.
+        init (PrimExpr, optional): The optional initializer value. When provided,
+            the generated code will initialize the variable with this value instead
+            of defaulting to zero.

     Returns:
         T.Buffer: A TVM buffer object allocated as a single-element variable
     """
-    return T.alloc_buffer([1], dtype, scope=scope)
+    parsed_scope = scope
+    parsed_init = init
+    if len(args) == 1:
+        arg = args[0]
+        if isinstance(arg, str) and parsed_init is None and scope == "local.var":
+            parsed_scope = arg
+        else:
+            if parsed_init is not None:
+                raise TypeError("Initializer specified multiple times in alloc_var.")
+            parsed_init = arg
+    elif len(args) == 2:
+        if parsed_init is not None:
+            raise TypeError("Initializer specified multiple times in alloc_var.")
+        parsed_init, parsed_scope_arg = args
+        if not isinstance(parsed_scope_arg, str):
+            raise TypeError("Scope must be provided as a string in alloc_var.")
+        parsed_scope = parsed_scope_arg
+    elif len(args) > 2:
+        raise TypeError(
+            f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
+    if not isinstance(parsed_scope, str):
+        raise TypeError("Scope must be a string in alloc_var.")
+    buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
+    if parsed_init is not None:
+        block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
+    return buffer


 def alloc_barrier(arrive_count: int):
@@ -141,7 +181,6 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
     Returns:
         T.Buffer: A TVM buffer object allocated in thread-private storage, available to reduce values in T.Parallel loops.
     """
-    import tilelang.language as TL

     assert op in ["sum", "max", "min"]
     # TODO: support automatic layout
@@ -150,7 +189,7 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
     assert replication in ["all", "none"]

     reducer = T.alloc_buffer(shape, dtype, scope="local.fragment")
-    TL.block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}})
+    block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}})
     return reducer
...
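
To recap the argument handling implemented above, a short sketch of the call forms `alloc_var` now accepts; the surrounding kernel, buffer names, and values are illustrative, only the call forms themselves follow the parsing logic in this diff.

import tilelang.language as T


@T.prim_func
def demo(A: T.Tensor((64,), "float32")):
    with T.Kernel(1, threads=64) as bx:
        a = T.alloc_var("float32")                  # zero-initialized, default scope "local.var"
        b = T.alloc_var("float32", "local.var")     # single positional string -> scope (legacy form)
        c = T.alloc_var("float32", 3)               # single non-string positional -> initializer
        d = T.alloc_var("float32", 3, "local.var")  # two positionals -> (init, scope)
        e = T.alloc_var("float32", init=3)          # keyword initializer
        for i in T.Parallel(64):
            A[i] = a + b + c + d + e

Supplying an initializer both positionally and via `init=`, passing a non-string scope, or giving more than two positional arguments raises a TypeError, per the checks in the new `alloc_var` body.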