Unverified Commit a03df604 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Feature] Enhance fill operation to support various buffer types (#1189)

* [Feature] Enhance fill operation to support various buffer types

- Added support for `BufferLoad` in the `fill` function to handle different buffer types.
- Updated `Fill` class to process region descriptors and buffer regions, improving flexibility in buffer handling.
- Introduced checks for static bounds in region definitions to ensure safety during operations.
- Refactored loop induction variable handling in `FillNode` to accommodate sliced regions.

* lint fix
parent 1768cbef
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "builtin.h" #include "builtin.h"
#include "region.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -62,7 +63,30 @@ using namespace tir; ...@@ -62,7 +63,30 @@ using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>(); ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
if (args[0]->IsInstance<BufferLoadNode>()) { // Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}
// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;
// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]); auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) { for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) { if (const auto *ramp = index.as<RampNode>()) {
...@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { ...@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
} }
} }
node->dst = buffer_load->buffer; node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else { } else {
node->dst = vmap[GetVarFromAccessPtr(args[0])]; node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) { for (int i = 0; i < node->dst->shape.size(); i++) {
...@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { ...@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
<< " != " << node->dst->shape.size(); << " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) { for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static // bound check if region is static
if (node->region[i]->min.as<IntImm>()) { if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value; int64_t min = min_imm->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0"; ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
} }
if (node->region[i]->extent.as<IntImm>()) { if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value; // Only perform the upper-bound check when the destination shape
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value) // extent is also statically known. If the shape is symbolic (e.g., Var),
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i]; // skip this static check to avoid invalid downcasts.
if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
ICHECK_LE(extent_imm->value, shape_imm->value)
<< "region[" << i << "] = " << extent_imm->value << " > "
<< node->dst->shape[i];
}
} }
} }
data_ = std::move(node); data_ = std::move(node);
...@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (int i = 0; i < ndim; i++) { for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype); Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
loop_vars.push_back({region[i], var, IterVarType::kDataPar}); loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var); // Offset the loop induction variable by region min to honor sliced regions
dst_indices.push_back(region[i]->min + var);
} }
Stmt body = BufferStore(dst, value, dst_indices); Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) { for (int i = ndim - 1; i >= 0; i--) {
...@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop; return vectorized_thread_loop;
} else { } else {
LOG(FATAL) << "Unsupported scope " << dst.scope(); LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
} }
} }
......
import torch
import tilelang
import tilelang.testing
from tilelang import language as T
@tilelang.jit(
    pass_configs={
        # Disable warp specialization and TMA lowering so the kernel takes the
        # plain SIMT lowering path that tl.fill's region handling exercises.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_static_region_kernel():
    """Build a kernel that fills a static slice of a symbolically-shaped buffer.

    The buffer length is symbolic while the filled region [0, 128) is static,
    so the compile-time upper-bound check against the shape must be skipped.
    """
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            # Fill only the first 128 elements via a sliced region.
            T.fill(x[0:128], 0)

    return buggy_kernel
@tilelang.jit(
    pass_configs={
        # Disable warp specialization and TMA lowering so the kernel takes the
        # plain SIMT lowering path that tl.fill's region handling exercises.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_dynamic_region_kernel():
    """Build a kernel that fills a fully dynamic slice of a symbolic buffer.

    Both slice bounds are runtime variables, so no static bound check can be
    applied to the fill region at compile time.
    """
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            # Slice bounds come from uninitialized runtime variables; the test
            # only checks that compilation and launch succeed.
            a, b = T.alloc_var('int'), T.alloc_var('int')
            T.fill(x[a:b], 0)

    return buggy_kernel
def test_fill_with_static_region_kernel():
    """Compile and launch the static-slice fill kernel on a 256-element buffer."""
    compiled = _fill_with_static_region_kernel()
    data = torch.zeros((256,), dtype=torch.int64, device='cuda')
    compiled(data)
def test_fill_with_dynamic_region_kernel():
    """Compile and launch the dynamic-slice fill kernel on a 256-element buffer."""
    compiled = _fill_with_dynamic_region_kernel()
    data = torch.zeros((256,), dtype=torch.int64, device='cuda')
    compiled(data)
if __name__ == '__main__':
    # Discover and run all test_* functions in this file via tilelang's runner.
    tilelang.testing.main()
...@@ -4,9 +4,14 @@ from __future__ import annotations ...@@ -4,9 +4,14 @@ from __future__ import annotations
from tvm import tir from tvm import tir
from tilelang.language import has_let_value, get_let_value from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
from tilelang.language.utils import (
buffer_to_tile_region,
buffer_region_to_tile_region,
buffer_load_to_tile_region,
)
def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr): def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value. """Fill a buffer or buffer region with a specified value.
Args: Args:
...@@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr): ...@@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
Returns: Returns:
A TVM intrinsic call that performs the fill operation A TVM intrinsic call that performs the fill operation
""" """
# Normalize Var with let value to its underlying object
if isinstance(buffer, tir.Var) and has_let_value(buffer):
buffer = get_let_value(buffer)
# Convert to a tl.region descriptor (PrimExpr) with write access
region_call = None
if isinstance(buffer, tir.Buffer): if isinstance(buffer, tir.Buffer):
buffer = buffer.access_ptr("w") # Get write pointer if input is a Buffer region_call = buffer_to_tile_region(buffer, "w")
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value) elif isinstance(buffer, tir.BufferRegion):
extents = [r.extent for r in buffer.region]
region_call = buffer_region_to_tile_region(buffer, "w", extents)
elif isinstance(buffer, tir.BufferLoad):
region = get_buffer_region_from_load(buffer)
if region is not None:
extents = [r.extent for r in region.region]
region_call = buffer_region_to_tile_region(region, "w", extents)
else:
# Fallback: treat element access as 1-extent per dim
region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices))
else:
# As-is fallback (rare): pass through for downstream handling
region_call = buffer
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value)
def clear(buffer: tir.Buffer | tir.Var): def clear(buffer: tir.Buffer | tir.Var):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment