Commit 9dda774a authored by Chaofan Lin, committed by GitHub

[BugFix] Use BufferRegion in tl.cumsum to infer buffer shape (#1321)



* [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape

* remove debug lines

* remove rubbish

* Fix decorator syntax for atomic_different_memory_orders_program

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent c30df2a1
@@ -16,6 +16,7 @@
 #include "../transform/loop_partition.h"
 #include "region.h"
 #include "tir/transforms/ir_utils.h"
+#include "tvm/tir/stmt.h"

 namespace tvm {
 namespace tl {
@@ -57,12 +58,65 @@ static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
       RegionOp region(call->args, vmap);
       return BufferRegion(region->GetBuffer(), region->GetRanges());
     }
+    // builtin.tvm_access_ptr(...): map the pointer's Var back to its Buffer
+    // and take the full region.
+    if (call->op.same_as(builtin::tvm_access_ptr())) {
+      Var var = Downcast<Var>(call->args[1]);
+      Buffer buf = vmap[var];
+      Array<Range> ranges;
+      for (PrimExpr extent : buf->shape) {
+        ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
+      }
+      return BufferRegion(buf, ranges);
+    }
   }
   LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
   throw; // Unreachable
 }

+// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
+// BufferRegion. Offset is computed from all but the last two dimensions;
+// extent is the product of the last two extents. rw_mask: 1=read, 2=write,
+// 3=readwrite.
+static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
+                                        int rw_mask) {
+  Buffer buf = region->buffer;
+  int ndim = static_cast<int>(buf->shape.size());
+  ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims";
+  PrimExpr offset, extent;
+  if (ndim == 1) {
+    // Simple 1D region: offset and extent come from the single axis.
+    auto axis = region->region[0];
+    offset = axis->min;
+    extent = axis->extent;
+  } else {
+    // Compute row-major strides for ndim >= 2.
+    std::vector<PrimExpr> strides(ndim);
+    PrimExpr one = make_const(buf->shape[0].dtype(), 1);
+    PrimExpr cur = one;
+    for (int i = ndim - 1; i >= 0; --i) {
+      strides[i] = cur;
+      cur = cur * buf->shape[i];
+    }
+    // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
+    offset = make_const(buf->shape[0].dtype(), 0);
+    for (int i = 0; i < ndim - 2; ++i) {
+      offset = offset + region->region[i]->min * strides[i];
+    }
+    // Extent: product of the last two extents (in elements).
+    extent =
+        region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
+  }
+  // Type annotation and the resulting access handle.
+  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
+  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
+                           IntImm(DataType::Int(32), rw_mask)};
+  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
+}
+
 ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
   // Accept BufferRegion/BufferLoad/tl.region for src/dst
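To make the offset/extent arithmetic above concrete, here is a plain-Python mirror of MakeAccessPtrFromRegion's index math (an illustrative sketch; the real helper builds PrimExprs and, per the ICHECK, only ever sees 1D or 2D buffers):

# Plain-Python mirror of MakeAccessPtrFromRegion's offset/extent math
# (illustrative sketch; the C++ code builds PrimExprs, not ints).
def tile_offset_extent(shape, region):
    """shape: per-axis sizes; region: one (min, extent) pair per axis."""
    ndim = len(shape)
    if ndim == 1:
        return region[0][0], region[0][1]
    # Row-major strides: strides[i] = product of shape[i+1:].
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    # Offset from all but the last two axes; extent from the last two.
    offset = sum(region[i][0] * strides[i] for i in range(ndim - 2))
    extent = region[-2][1] * region[-1][1]
    return offset, extent

# 1D: a sub-range starting at element 8 with 32 elements maps straight through.
assert tile_offset_extent((128,), [(8, 32)]) == (8, 32)
# 2D: the offset sum is empty, so the pointer starts at element 0 and the
# extent is rows * cols.
assert tile_offset_extent((64, 128), [(0, 64), (0, 128)]) == (0, 64 * 128)

For ndim == 2 the offset loop contributes nothing, so the generalized stride code only starts to matter if the rank restriction is ever relaxed.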
@@ -231,6 +285,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto dst_scope = this->dst.scope();
   if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
     Buffer src_buffer = get_buffer(this->src);
     Buffer dst_buffer = get_buffer(this->dst);
     Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
@@ -518,6 +573,16 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

+// Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the
+// ranges.
+static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
+  Array<Range> ranges;
+  for (PrimExpr extent : buf->shape) {
+    ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
+  }
+  return BufferRegion(buf, ranges);
+}
+
 CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
   /// CumSum constructor arguments:
   /// - src: input buffer
@@ -526,11 +591,19 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
   /// - reverse: whether to cumsum in reverse order
   CHECK_EQ(args.size(), 4);
   ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
-  node->src = vmap[GetVarFromAccessPtr(args[0])];
-  node->dst = vmap[GetVarFromAccessPtr(args[1])];
+  // node->src = vmap[GetVarFromAccessPtr(args[0])];
+  // node->dst = vmap[GetVarFromAccessPtr(args[1])];
+  node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
+  node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
+  node->src = node->srcRegion_->buffer;
+  node->dst = node->dstRegion_->buffer;
   node->dim = args[2].as<IntImm>().value()->value;
   node->reverse = args[3].as<Bool>().value();
-  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
+  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()))
+      << "The dim of cumsum should be less than the number of dimensions. "
+         "Got dim="
+      << node->dim << ", but src has " << node->src->shape.size() << " dims.";
   data_ = std::move(node);
 }
@@ -546,18 +619,22 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto threads = T.thread_bounds->extent;
   Array<PrimExpr> args;
   int ndim = static_cast<int>(src->shape.size());
+  // Build access pointers from the regions locally.
+  PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1);
+  PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2);
   if (ndim == 1) {
     ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
                          "= 0.";
     ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
        << ">::run";
-    args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
-            src->shape[0]};
+    args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0]};
   } else if (ndim == 2) {
     ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
        << (reverse ? "true" : "false") << ">::run";
-    args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
-            src->shape[0], src->shape[1]};
+    args = {StringImm(ss.str()), srcPtr, dstPtr, src->shape[0],
+            src->shape[1]};
   } else {
     LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
                << ndim << "D.";
...
@@ -133,8 +133,10 @@ public:
 class CumSumOpNode : public TileOperatorNode {
 public:
   tir::Buffer src, dst; ///< Source and destination buffers
+  // Optional: keep the original regions used to construct this op
+  BufferRegion srcRegion_, dstRegion_;
   int dim;      ///< Dimension along which to compute cumulative sum
   bool reverse; ///< Whether to compute in reverse order
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
                                     TileOperatorNode);
@@ -143,6 +145,8 @@ public:
     refl::ObjectDef<CumSumOpNode>()
         .def_ro("src", &CumSumOpNode::src)
         .def_ro("dst", &CumSumOpNode::dst)
+        .def_ro("srcRegion", &CumSumOpNode::srcRegion_)
+        .def_ro("dstRegion", &CumSumOpNode::dstRegion_)
         .def_ro("dim", &CumSumOpNode::dim)
         .def_ro("reverse", &CumSumOpNode::reverse);
   }
...
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    })
def _cumsum_view_infer_layout(hidden):
    num_tokens = T.dynamic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype='float')
            T.copy(x[pid, :], smem)
            T.cumsum(T.view(smem, (1, hidden)), dim=1)

    return buggy_kernel


def test_cumsum_view_infer_layout():
    hidden = 128
    x = torch.randn(1, hidden, device='cuda', dtype=torch.float)
    kernel = _cumsum_view_infer_layout(hidden)
    kernel(x)


if __name__ == '__main__':
    tilelang.testing.main()
@@ -260,7 +260,7 @@ def test_atomic_addx2():
     run_atomic_addx2(32, 64, 8, 16)

-@tilelang.jit(debug_root_path="./testing/python/language")
+@tilelang.jit
 def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
     @T.prim_func
...
"""Tilelang IR analysis & visitors.""" """Tilelang IR analysis & visitors."""
from .ast_printer import ASTPrinter # noqa: F401
from .nested_loop_checker import NestedLoopChecker # noqa: F401 from .nested_loop_checker import NestedLoopChecker # noqa: F401
from tvm import tir
from tvm.tir import PrimFunc
from tvm.tir.transform import prim_func_pass
from tvm.tir.stmt_functor import ir_transform
def ASTPrinter():
"""
Print the AST of a given tilelang module for debugging.
"""
def pre_visit(statement: tir.Stmt) -> None:
"""
Pre-order visitor to print all visited statements.
"""
print(f"Visiting statement: {type(statement)}")
def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
new_body = ir_transform(func.body, pre_visit, None)
return func.with_body(new_body)
return prim_func_pass(pass_fn, opt_level=0)
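For reference, a minimal usage sketch for the new debug pass (fill_zero is a made-up example PrimFunc, and T here is tvm.script's tir module, not tilelang.language; the intended call site is the commented-out line in PreLowerSemanticCheck below):

import tvm
import tilelang.analysis
from tvm.script import tir as T


@T.prim_func
def fill_zero(a: T.handle):
    A = T.match_buffer(a, (8,), "float32")
    for i in range(8):
        A[i] = T.float32(0)


# ASTPrinter() returns a prim_func pass; applying it to a module prints
# the type of every statement it visits.
mod = tvm.IRModule({"main": fill_zero})
tilelang.analysis.ASTPrinter()(mod)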
@@ -74,6 +74,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
     Note: This is a validation-only pipeline of passes and does not modify or return the module.
     """
+    # Debug
+    # tilelang.analysis.ASTPrinter()(mod)
+
     # Check if there are any invalid nested loops.
     tilelang.analysis.NestedLoopChecker()(mod)
...
@@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
         tir.call_intrin(
             "handle",
             tir.op.Op.get("tl.cumsum"),
-            cumsum_smem.access_ptr("r"),
-            cumsum_smem.access_ptr("w"),
+            buffer_to_tile_region(cumsum_smem, "r"),
+            buffer_to_tile_region(cumsum_smem, "w"),
             dim,
             reverse,
         )

@@ -300,8 +300,8 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
     return tir.call_intrin(
         "handle",
         tir.op.Op.get("tl.cumsum"),
-        src.access_ptr("r"),
-        dst.access_ptr("w"),
+        buffer_to_tile_region(src, "r"),
+        buffer_to_tile_region(dst, "w"),
         dim,
         reverse,
     )
...
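buffer_to_tile_region is the tilelang language utility the frontend now calls in place of Buffer.access_ptr. A hedged stand-in for its assumed behavior (illustrative only, not the real tilelang API): it wraps the buffer in a tl.region call covering the full extent of every axis, tagged with the access mode, so the C++ side can recover both the Buffer and its possibly viewed shape.

# Illustrative stand-in for buffer_to_tile_region, not the real API.
# The masks follow the convention noted in the C++ above: 1=read,
# 2=write, 3=readwrite.
ACCESS_MASK = {"r": 1, "w": 2, "rw": 3}

def buffer_to_tile_region_sketch(shape, access_type):
    ranges = [(0, extent) for extent in shape]  # full extent on every axis
    return {"ranges": ranges, "mask": ACCESS_MASK[access_type]}

# A (1, 128) view keeps both axes, so CumSumOp sees a 2D buffer:
region = buffer_to_tile_region_sketch((1, 128), "r")
assert region["ranges"] == [(0, 1), (0, 128)]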