"vscode:/vscode.git/clone" did not exist on "e82fcf30c678e743e08be111ba73ba7debbca135"
Unverified Commit f5d9da46 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Phaseout vmap for Tile Operators (#1334)



* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent fac04006
......@@ -125,7 +125,7 @@ class ReduceOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
ReduceOpNode);
TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL ReduceOp(Array<PrimExpr> args);
static const Op &Get();
};
......@@ -163,7 +163,7 @@ class CumSumOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
CumSumOpNode);
TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL CumSumOp(Array<PrimExpr> args);
static const Op &Get();
};
......
/*!
* \file tl/op/region.cc
* \brief Define region operator.
* \brief Define region operator (bridge to carry BufferRegion via Call args).
*
* Notes:
* - BufferLoad/Ramp cannot represent a general PrimExpr as a vector lane
* count. Dynamic extents like (H1 - H0) cannot be encoded as
* Ramp(lanes = H1 - H0), and lowering BufferRegion to BufferLoad loses the
* explicit extent information.
* - tl.region carries both mins and extents in Call args and lets the backend
* reconstruct a BufferRegion faithfully.
*/
#include "region.h"
......@@ -11,27 +18,7 @@ namespace tvm {
namespace tl {
using namespace tir;
/**
* @brief Construct a RegionOp from TL operator arguments.
*
* Parses the TL `region` operator call arguments to populate the RegionOpNode:
* - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension
* minima.
* - args[1] must be a constant integer used as the access mask.
* - args[2 + i] provides the extent for dimension `i`.
*
* The constructor validates that the number of load indices equals `args.size()
* - 2` and will abort via ICHECK on mismatch or if args[0] is not a
* `BufferLoad`.
*
* Parameters:
* - args: TL operator call arguments in the form
* [BufferLoad(min_i...), access_mask, extent_0, extent_1, ...,
* extent_{n-1}] where n = number of dimensions.
* - vmap: BufferMap passed through by the caller (not documented here as a
* generic utility).
*/
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
RegionOp::RegionOp(Array<PrimExpr> args) {
size_t n = args.size();
size_t ndim = n - 2;
auto load = args[0].as<BufferLoadNode>();
......@@ -39,10 +26,24 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
ICHECK(load->indices.size() == ndim)
<< "load->indices.size() = " << load->indices << " ndim = " << ndim;
Array<Range> ranges;
// Rebuild per-axis ranges from mins (BufferLoad indices) and provided extents
for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i];
PrimExpr index = load->indices[i];
PrimExpr extent = args[2 + i];
ranges.push_back(Range::FromMinExtent(min, extent));
if (const auto *ramp = index.as<RampNode>()) {
const auto *stride_imm = ramp->stride.as<IntImmNode>();
ICHECK(stride_imm && stride_imm->value == 1)
<< "RegionOp expects stride-1 Ramp for index";
if (const auto *lanes_imm = ramp->lanes.as<IntImmNode>()) {
if (const auto *ext_imm = extent.as<IntImmNode>()) {
ICHECK_EQ(lanes_imm->value, ext_imm->value)
<< "Ramp lanes and provided extent must match";
}
}
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, extent));
}
}
ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
node->buffer_ = load->buffer;
......@@ -51,26 +52,11 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a copy of this RegionOpNode and return it as a TileOperator.
*
* @return TileOperator A new TileOperator that owns a copied RegionOpNode.
*/
TileOperator RegionOpNode::Clone() const {
auto op = tvm::ffi::make_object<RegionOpNode>(*this);
return RegionOp(op);
}
/**
* @brief Check whether the region spans the entire underlying buffer.
*
* Returns true if for every dimension the range minimum is zero and the
* range extent is structurally equal to the corresponding buffer shape
* dimension. Otherwise returns false.
*
* @return true if the region covers the full buffer in all dimensions; false
* otherwise.
*/
bool RegionOpNode::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min))
......@@ -81,39 +67,26 @@ bool RegionOpNode::IsFullRegion() const {
return true;
}
/**
* @brief Lower the region operator to a TIR statement.
*
* Lowers this RegionOpNode into a TIR Stmt by delegating to the operator's
* evaluation path (currently `Evaluate(0)`).
*
* @param T Lowering context (provides buffers, producers/consumers and other
* environment required for lowering).
* @param analyzer Optional arithmetic analyzer used for simplification during
* lowering.
* @return Stmt The lowered TIR statement representing this region operation.
*/
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(0);
}
/**
* @brief Infers data layout for the region operator.
*
* This operator does not provide any layout inference; the function always
* returns an empty LayoutMap regardless of the provided arguments or inference
* level.
*
* @param T Layout inference arguments (ignored).
* @param level Inference granularity level (ignored).
* @return LayoutMap Empty map indicating no inferred layouts.
*/
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
}
TIR_REGISTER_TL_OP(RegionOp, region)
const Op &RegionOp::Get() {
static const Op &op = Op::Get("tl.region");
return op;
}
TVM_REGISTER_OP("tl.region")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "region")
.set_attr<OpBuilderFunc>("TLOpBuilder",
[](Array<PrimExpr> args) {
return RegionOp(args);
})
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
......
/*!
* \file tl/op/op.h
* \brief Tile library operations.
* \file tl/op/region.h
* \brief Tile memory region descriptor op (bridge to carry BufferRegion via
* Call args).
*
* Why tl.region instead of passing BufferRegion directly?
*
* - While TIR can represent a BufferRegion, when a BufferRegion is passed as a
* call argument through call_intrin/FFI, the Python->C++ conversion lowers it
* to a BufferLoad(indices). To encode an interval inside indices, the FFI
* typically uses Ramp(base, stride, lanes) to represent a contiguous slice.
* - Ramp(lanes) may only be a constant or vscale*k (scalable vector). A general
* PrimExpr (e.g., H1 - H0) is not allowed as lanes, so dynamic extents would
* make the lowered BufferLoad invalid.
* - Moreover, BufferLoad only carries indices, not per-axis extents. Downstream
* tile operators (e.g., tl.copy, tl.reduce) that require both min and extent
* cannot losslessly recover dynamic extents from a BufferLoad alone.
*
* tl.region is a small transport-only op that solves this:
* - The frontend packs buffer + mins (from BufferLoad.indices) + extents into
* Call args, allowing dynamic extents to be expressed explicitly.
* - The backend (NormalizeToBufferRegion) reconstructs a BufferRegion from the
* tl.region call without losing information.
* - The op itself carries no semantics in Lower/InferLayout and is only used as
* a bridge for argument passing.
*/
#ifndef TVM_TL_OP_REGION_H_
#define TVM_TL_OP_REGION_H_
#include "./operator.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
/**
* Tile operator representing a memory region (buffer + ranges) used by TL
* passes.
*
* Encapsulates the target tir::Buffer, the region extents as an Array<Range>,
* and an access mask that indicates permitted or intended accesses for lowering
* and layout inference.
*/
/**
* Lower this RegionOp into a TIR statement representing the region access.
*
* @param T Lowering-time arguments (e.g., loop/build context and value
* mappings).
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A tir::Stmt that implements the region access/mutation described by
* this operator.
*/
/**
* Infer the layout mapping for this region operator.
*
* Produces a LayoutMap describing how loop/axis indices map to buffer axes for
* layout-aware scheduling and subsequent operators.
*
* @param T Layout inference arguments (e.g., input layouts and shapes).
* @param level The inference detail level to use.
* @return A LayoutMap describing inferred mappings for the operator.
*/
/**
* Return true when this RegionOp represents the full buffer region (i.e.,
* ranges cover the entire buffer extent).
*/
/**
* Create a shallow copy of this operator as a TileOperator handle.
*
* @return A TileOperator that references a cloned RegionOpNode.
*/
/**
* Construct a RegionOp from argument expressions and a buffer map.
*
* @param args Positional expressions used to instantiate the operator
* (semantics depend on how RegionOp is invoked in TL pipelines).
* @param vmap Mapping from Buffer to replacement Buffer or buffer metadata used
* during creation.
*/
/**
* Return the global Op registration for RegionOp.
*
* @return Reference to the registered tvm::Op describing the RegionOp.
*/
namespace tvm {
namespace tl {
......@@ -80,6 +42,12 @@ public:
Array<Range> ranges_;
int access_mask_;
/*!
* access_mask_ encodes the intended access type when the region is used as
* an argument to tile operators: 1=read, 2=write, 3=read-write. The mask is
* transport metadata only and does not affect lowering.
*/
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
TileOperatorNode);
......@@ -107,8 +75,13 @@ class RegionOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
RegionOpNode);
TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
/*!
* Build a RegionOp from call arguments:
* - args[0]: BufferLoad whose indices are per-axis minima.
* - args[1]: Integer access mask (1=r, 2=w, 3=rw).
* - args[2 + i]: Extent of axis i (supports dynamic PrimExpr).
*/
TVM_DLL RegionOp(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -12,8 +12,7 @@ namespace tl {
using namespace tir;
BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
......@@ -38,23 +37,15 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
// Case 3: tl.region(...) — reconstruct via RegionOp (bridge)
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
RegionOp region(call->args);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap.at(var);
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
LOG(FATAL) << "Unsupported argument for BufferRegion (expect "
"BufferLoad/BufferRegion/tl.region): "
<< arg;
}
LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg;
......
......@@ -16,10 +16,10 @@ namespace tl {
using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region/tvm_access_ptr)
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so ops can uniformly consume regions.
TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap);
// Note: tvm_access_ptr is no longer supported here.
TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);
// Build a tvm_access_ptr(handle) from a BufferRegion.
// - If `require_2d` is true, checks buffer ndim >= 2.
......
......@@ -437,11 +437,13 @@ private:
if (op->op.as<GlobalVarNode>())
return;
auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), GetBufferMap());
auto p = ParseOperator(tvm::ffi::GetRef<Call>(op));
if (p.defined()) {
for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) {
addToUseList(buffer.value());
} else if (auto buffer = getBufferFromRegion(arg)) {
addToUseList(buffer.value());
}
}
// Compute thread_var_ and thread_bounds_
......@@ -495,6 +497,9 @@ private:
}
Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
if (auto bl = expr.as<BufferLoadNode>()) {
return bl->buffer;
}
auto call = expr.as<CallNode>();
if (!call) {
return std::nullopt;
......@@ -514,8 +519,18 @@ private:
}
}
return std::nullopt;
} else if (call->op.same_as(RegionOp::Get())) {
return call->args[0].as<BufferLoadNode>()->buffer;
}
return std::nullopt;
}
Optional<Buffer> getBufferFromRegion(const PrimExpr &expr) {
if (auto call = expr.as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
if (auto bl = call->args[0].as<BufferLoadNode>()) {
return bl->buffer;
}
return std::nullopt;
}
}
return std::nullopt;
}
......
......@@ -277,7 +277,7 @@ private:
if (op->op.same_as(Fill::Get())) {
ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>()) {
// Case 1: tl.region(...) — extract buffer var from its first arg
// tl.region(...) — extract buffer var from its first arg
if (arg0_call.value()->op.same_as(RegionOp::Get())) {
ICHECK(!arg0_call.value()->args.empty());
if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
......@@ -285,15 +285,14 @@ private:
if (reducer_info_map_.count(var)) {
ICHECK(inside_reducer_range_.count(var) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
"T.finalize_reducer before next.";
inside_reducer_range_.Set(var,
reducer_info_map_.Get(var).value());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
// builtin.tvm_access_ptr(...) — existing path (legacy)
if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
......@@ -305,10 +304,33 @@ private:
var.value(), reducer_info_map_.Get(var.value()).value());
}
}
} else if (auto bl = op->args[0].as<BufferLoadNode>()) {
Var var = bl->buffer->data;
if (reducer_info_map_.count(var)) {
ICHECK(inside_reducer_range_.count(var) == 0)
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var, reducer_info_map_.Get(var).value());
}
}
} else if (op->op.same_as(FinalizeReducerOp::Get())) {
ICHECK(op->args.size() == 1);
auto var = GetVarFromAccessPtr(op->args[0]);
Var var;
if (auto bl = op->args[0].as<BufferLoadNode>()) {
var = bl->buffer->data;
} else if (auto reg_call = op->args[0].as<Call>()) {
if (reg_call.value()->op.same_as(RegionOp::Get())) {
if (auto bl2 = reg_call.value()->args[0].as<BufferLoadNode>()) {
var = bl2->buffer->data;
} else {
LOG(FATAL) << "tl.region expects BufferLoad as first arg";
}
} else {
var = GetVarFromAccessPtr(op->args[0]);
}
} else {
var = GetVarFromAccessPtr(op->args[0]);
}
ICHECK(inside_reducer_range_.count(var) == 1)
<< "T.finalize_reducer must have a pairing T.fill ahead of it, "
"enclosing a reduction range.";
......
......@@ -606,8 +606,7 @@ private:
if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto tile_op =
ParseOperator(tvm::ffi::GetRef<Stmt>(op), buffer_data_to_buffer_);
auto tile_op = ParseOperator(tvm::ffi::GetRef<Stmt>(op));
if (!tile_op.defined())
return IRMutatorWithAnalyzer::VisitStmt_(op);
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
......
......@@ -17,7 +17,15 @@ def _empty_kernel():
return empty_kernel
@tilelang.testing.requires_cuda
def test_empty_kernel_lowering():
# Ensure a valid CUDA runtime context is current on this thread for the
# target device before using driver API calls. Without this, calls like
# cuModuleLoadData can fail with CUDA_ERROR_INVALID_CONTEXT, especially
# for kernels that don't touch any device memory or streams beforehand
# (e.g., "empty" kernels) and therefore haven't triggered context
# creation implicitly.
torch.cuda.set_device(0)
kernel = _empty_kernel()
kernel()
......@@ -59,7 +67,9 @@ def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):
return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding
@tilelang.testing.requires_cuda
def test_empty_kernel_with_binding_variants():
torch.cuda.set_device(0)
kernel = _empty_kernel_with_binding_variants()
kernel()
......
......@@ -2,14 +2,15 @@ from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tvm import tir
from tvm.ir import Range
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
from tvm.runtime import convert
from .utils import (
mfma_store_index_map,)
from .utils import (mfma_store_index_map)
from typing import Literal, Callable
from tilelang.utils import is_fragment
from tilelang.utils.language import to_buffer_region
from tilelang.utils.language import get_buffer_region_from_load
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
......@@ -268,7 +269,7 @@ class MatrixCoreIntrinEmitter:
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_region = self._legalize_to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
......@@ -314,7 +315,7 @@ class MatrixCoreIntrinEmitter:
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_region = self._legalize_to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
......@@ -655,6 +656,33 @@ class MatrixCoreIntrinEmitter:
forward_index_fn=forward_index,
)
@staticmethod
def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
"""
Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.
- Buffer -> full-region BufferRegion covering entire shape
- BufferRegion -> returned as-is
- BufferLoad -> best-effort convert via get_buffer_region_from_load;
if scalar, fall back to 1-sized ranges at given indices
"""
if isinstance(obj, BufferRegion):
return obj
if isinstance(obj, Buffer):
mins = [tir.IntImm("int32", 0) for _ in obj.shape]
ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
return BufferRegion(obj, ranges)
if isinstance(obj, BufferLoad):
region = get_buffer_region_from_load(obj)
if region is not None:
return region
# Fallback: scalar load -> 1-sized ranges at indices
mins = [idx for idx in obj.indices]
ones = [tir.IntImm("int32", 1) for _ in obj.indices]
ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
return BufferRegion(obj.buffer, ranges)
raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
......
......@@ -3,14 +3,16 @@ import tilelang.language as T
from typing import Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tvm import tir
from tvm.ir import Range
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
from tilelang import tvm as tvm
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)
from tilelang.utils import is_fragment, to_buffer_region
from tilelang.utils import is_fragment, get_buffer_region_from_load
from tilelang.intrinsics.mma_layout import (
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x8_to_mma_32x4_layout_sr_b,
......@@ -243,7 +245,7 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_region = self._legalize_to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
......@@ -294,7 +296,7 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_region = self._legalize_to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
......@@ -360,7 +362,7 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_region = self._legalize_to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
......@@ -397,7 +399,7 @@ class TensorCoreIntrinEmitter:
thread_binding = self.get_thread_binding()
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_region = self._legalize_to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
......@@ -798,6 +800,33 @@ class TensorCoreIntrinEmitter:
forward_index_fn=forward_index,
)
@staticmethod
def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
"""
Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.
- Buffer -> full-region BufferRegion covering entire shape
- BufferRegion -> returned as-is
- BufferLoad -> best-effort convert via get_buffer_region_from_load;
if scalar, fall back to 1-sized ranges at given indices
"""
if isinstance(obj, BufferRegion):
return obj
if isinstance(obj, Buffer):
mins = [tir.IntImm("int32", 0) for _ in obj.shape]
ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
return BufferRegion(obj, ranges)
if isinstance(obj, BufferLoad):
region = get_buffer_region_from_load(obj)
if region is not None:
return region
# Fallback: scalar load -> 1-sized ranges at indices
mins = [idx for idx in obj.indices]
ones = [tir.IntImm("int32", 1) for _ in obj.indices]
ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
return BufferRegion(obj.buffer, ranges)
raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
"""
......
......@@ -5,7 +5,7 @@ from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion
from tilelang import tvm as tvm
from tvm.runtime import convert
from tilelang.utils import is_fragment, to_buffer_region
from tilelang.utils import is_fragment
from tilelang.intrinsics.mma_sm70_layout import (
shared_16x4_to_mma_a_32x4_layout,
shared_4x16_to_mma_b_32x4_layout,
......@@ -207,7 +207,7 @@ class TensorCoreIntrinEmitter:
mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout
# legalize shared buffer to region
A_region = to_buffer_region(A_shared_buf)
A_region = self._legalize_to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
......@@ -248,7 +248,7 @@ class TensorCoreIntrinEmitter:
mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout
# legalize shared buffer to region
B_region = to_buffer_region(B_shared_buf)
B_region = self._legalize_to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
......
......@@ -4,10 +4,9 @@
from __future__ import annotations
import tilelang.language as T
from tvm import ir, tir
from tvm import ir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load, legalize_pairwise_extents
from tilelang.utils.language import to_buffer_region, legalize_pairwise_extents
_MEMORY_ORDER_ID_MAP = {
"relaxed": 0,
......@@ -203,24 +202,8 @@ def atomic_add(dst: Buffer,
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
def _to_region(data, access_type, extent):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer):
zeros = [tir.IntImm("int32", 0) for _ in extent]
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad):
region = get_buffer_region_from_load(data)
if region is None:
return buffer_load_to_tile_region(data, access_type, extent)
return buffer_region_to_tile_region(region, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
value = _to_region(value, "r", src_extent)
dst = _to_region(dst, "w", dst_extent)
value = to_buffer_region(value, access_type="r", extents=src_extent)
dst = to_buffer_region(dst, access_type="w", extents=dst_extent)
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
......
......@@ -3,11 +3,11 @@ from __future__ import annotations
from typing import Literal
from tilelang import language as T
from tilelang.utils.language import (
to_buffer_region,
get_buffer_region_from_load,
legalize_pairwise_extents,
)
from tvm import ir, tir
from tilelang.language.utils import buffer_region_to_tile_region, buffer_load_to_tile_region
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
......@@ -69,27 +69,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
# - otherwise -> error
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
def _to_region(data, access_type, extent):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer):
# Restrict a raw buffer to the computed copy extent by creating
# a BufferLoad at origin and passing the extents explicitly.
zeros = [tir.IntImm("int32", 0) for _ in extent]
return buffer_load_to_tile_region(tir.BufferLoad(data, zeros), access_type, extent)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
elif isinstance(data, tir.BufferLoad):
region = get_buffer_region_from_load(data)
if region is None:
return buffer_load_to_tile_region(data, access_type, extent)
return buffer_region_to_tile_region(region, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
# Use legalized extents for src and dst respectively.
src = _to_region(src, "r", src_extent)
dst = _to_region(dst, "w", dst_extent)
src = to_buffer_region(src, access_type="r", extents=src_extent)
dst = to_buffer_region(dst, access_type="w", extents=dst_extent)
if coalesced_width is None:
coalesced_width = -1 # PrimExpr can not be None
......@@ -129,6 +111,7 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy = 0
else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img.access_ptr("r"),
col.access_ptr("w"), nhw_step, c_step, kernel, stride, dilation, pad,
eviction_policy)
img_region = to_buffer_region(img, access_type="r")
col_region = to_buffer_region(col, access_type="w")
return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region,
nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy)
......@@ -3,6 +3,7 @@ from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from tilelang.utils.language import to_buffer_region
def gemm_sp(
......@@ -62,17 +63,18 @@ def gemm_sp(
K_A = A_sparse.shape[0] if transpose_A else A_sparse.shape[1]
K_B = B.shape[1] if transpose_B else B.shape[0]
assert K_A * 2 == K_B, f"T.gemm_sp K shape check failed: K_A = {K_A}, K_B = {K_B}"
Aptr = A_sparse.access_ptr("r")
Bptr = B.access_ptr("r")
Cptr = C.access_ptr("rw")
Eptr = E.access_ptr("r")
# Build tl.region descriptors for operands
A_arg = to_buffer_region(A_sparse, access_type="r")
E_arg = to_buffer_region(E, access_type="r")
B_arg = to_buffer_region(B, access_type="r")
C_arg = to_buffer_region(C, access_type="rw")
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_sp"),
Aptr,
Eptr,
Bptr,
Cptr,
A_arg,
E_arg,
B_arg,
C_arg,
transpose_A,
transpose_B,
M,
......
......@@ -2,12 +2,7 @@
from __future__ import annotations
from tvm import tir
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
from tilelang.language.utils import (
buffer_to_tile_region,
buffer_region_to_tile_region,
buffer_load_to_tile_region,
)
from tilelang.utils.language import get_buffer_region_from_load, to_buffer_region
def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr):
......@@ -24,26 +19,21 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
if isinstance(buffer, tir.Var) and has_let_value(buffer):
buffer = get_let_value(buffer)
# Convert to a tl.region descriptor (PrimExpr) with write access
region_call = None
# Build tl.region as argument
if isinstance(buffer, tir.Buffer):
region_call = buffer_to_tile_region(buffer, "w")
extents = list(buffer.shape)
elif isinstance(buffer, tir.BufferRegion):
extents = [r.extent for r in buffer.region]
region_call = buffer_region_to_tile_region(buffer, "w", extents)
elif isinstance(buffer, tir.BufferLoad):
region = get_buffer_region_from_load(buffer)
if region is not None:
extents = [r.extent for r in region.region]
region_call = buffer_region_to_tile_region(region, "w", extents)
else:
# Fallback: treat element access as 1-extent per dim
region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices))
extents = [tir.IntImm("int32", 1) for _ in buffer.indices]
else:
# As-is fallback (rare): pass through for downstream handling
region_call = buffer
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value)
extents = []
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"),
to_buffer_region(buffer, access_type="w", extents=extents), value)
def clear(buffer: tir.Buffer | tir.Var):
......
......@@ -7,10 +7,11 @@ from tilelang.utils.language import (
to_buffer_region,
retrieve_shape,
retrieve_stride,
retrieve_ptr,
retrieve_offset,
prim_expr_equal,
)
from tilelang.language.utils import (
buffer_region_to_tile_region,)
from tilelang.env import env as _env
......@@ -50,17 +51,17 @@ def _gemm_impl(
C = legalize_arguments(C)
mbar = legalize_arguments(mbar) if mbar is not None else None
# Normalize A/B/C to BufferRegion to pass into tl.gemm
A = to_buffer_region(A)
B = to_buffer_region(B)
C = to_buffer_region(C)
# Normalize A/B/C to BufferRegion for shape/stride/offset analysis
A_region = to_buffer_region(A)
B_region = to_buffer_region(B)
C_region = to_buffer_region(C)
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
C_shape = retrieve_shape(C)
A_shape = retrieve_shape(A_region)
B_shape = retrieve_shape(B_region)
C_shape = retrieve_shape(C_region)
A_stride = retrieve_stride(A)
B_stride = retrieve_stride(B)
A_stride = retrieve_stride(A_region)
B_stride = retrieve_stride(B_region)
assert len(C_shape) == 2, "current only support C as a 2D tensor"
assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
......@@ -82,18 +83,22 @@ def _gemm_impl(
stride_a = A_stride[-2]
stride_b = B_stride[-2]
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
A_offset = retrieve_offset(A_region)
B_offset = retrieve_offset(B_region)
assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]
mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32")
C_coords = [r.min for r in C.region]
return tir.call_intrin("handle", tir.op.Op.get(op_key), A, B, C, transpose_A, transpose_B, M, N,
K, policy, clear_accum, stride_a, stride_b, offset_a, offset_b, k_pack,
wg_wait, mbarptr, C_coords[0], C_coords[1])
mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, "uint32")
C_coords = [r.min for r in C_region.region]
# Convert BufferRegion to tl.region calls for arguments
A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
return tir.call_intrin("handle", tir.op.Op.get(op_key), A_arg, B_arg, C_arg, transpose_A,
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
offset_b, k_pack, wg_wait, mbar, C_coords[0], C_coords[1])
# Public wrappers
......
......@@ -2,7 +2,7 @@
from __future__ import annotations
from tvm import tir
from tilelang.language import copy, macro, alloc_shared, alloc_fragment
from tilelang.language.utils import buffer_to_tile_region
from tilelang.utils.language import to_buffer_region
from tilelang.utils.language import is_shared, is_fragment
from tvm.script.ir_builder import IRBuilder
......@@ -51,8 +51,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(red_frag_in, "r"),
buffer_to_tile_region(red_frag_out, "w"),
to_buffer_region(red_frag_in, access_type="r"),
to_buffer_region(red_frag_out, access_type="w"),
reduce_type,
dim,
clear,
......@@ -66,8 +66,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(red_frag_in, "r"),
buffer_to_tile_region(out, "w"),
to_buffer_region(red_frag_in, access_type="r"),
to_buffer_region(out, access_type="w"),
reduce_type,
dim,
clear,
......@@ -79,8 +79,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(buffer, "r"),
buffer_to_tile_region(red_frag_out, "w"),
to_buffer_region(buffer, access_type="r"),
to_buffer_region(red_frag_out, access_type="w"),
reduce_type,
dim,
clear,
......@@ -90,8 +90,8 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin(
"handle",
tir.op.Op.get("tl.reduce"),
buffer_to_tile_region(buffer, "r"),
buffer_to_tile_region(out, "w"),
to_buffer_region(buffer, access_type="r"),
to_buffer_region(out, access_type="w"),
reduce_type,
dim,
clear,
......@@ -246,8 +246,8 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
tir.call_intrin(
"handle",
tir.op.Op.get("tl.cumsum"),
buffer_to_tile_region(cumsum_smem, "r"),
buffer_to_tile_region(cumsum_smem, "w"),
to_buffer_region(cumsum_smem, access_type="r"),
to_buffer_region(cumsum_smem, access_type="w"),
dim,
reverse,
)
......@@ -300,8 +300,8 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.cumsum"),
buffer_to_tile_region(src, "r"),
buffer_to_tile_region(dst, "w"),
to_buffer_region(src, access_type="r"),
to_buffer_region(dst, access_type="w"),
dim,
reverse,
)
......@@ -323,7 +323,7 @@ def finalize_reducer(reducer: tir.Buffer):
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.finalize_reducer"),
reducer.access_ptr("w"),
to_buffer_region(reducer, access_type="w"),
)
......
from tilelang import tvm as tvm
from tvm import tir
from tvm.tir import PrimExpr, Buffer, BufferLoad, op
from tvm.tir import PrimExpr, BufferLoad, op
from tilelang import language as T
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""
Create a tile memory-region descriptor for a BufferLoad.
Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic
(1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.
Parameters:
buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices.
access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access.
*args (tir.PrimExpr): Extent expressions for each region dimension.
Returns:
tir.Call: A call to the `tl.region` intrinsic describing the memory region.
Raises:
KeyError: If access_type is not one of 'r', 'w', or 'rw'.
"""
"""Create a tl.region call for a BufferLoad and extents."""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args)
def buffer_to_tile_region(buffer: Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
indices = load.indices
"""Convert a BufferLoad to a tl.region call with explicit extents."""
indices = list(load.indices)
if len(indices) > len(extents):
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
new_extents = []
for _ in range(len(indices) - len(extents)):
new_extents.append(1)
for extent in extents:
new_extents.append(extent)
extents = new_extents
extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))
] + list(extents)
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: list[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
# Clamp extents element-wise so that the produced region respects the
# requested copy/fill extent, supporting dynamic PrimExpr via tir.min.
"""Clamp extents and return a tl.region call."""
mins = [r.min for r in buffer_region.region]
region_extents = [r.extent for r in buffer_region.region]
assert len(region_extents) >= len(extents), (
f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
)
clamped_extents = [
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
for i in range(len(region_extents))
]
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
def index_to_coordinates(index, shape) -> list[PrimExpr]:
......
......@@ -123,6 +123,10 @@ class GemmBase:
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32"))
@property
def mbar(self) -> tir.Buffer:
return getattr(self.gemm_node, "mbar", None)
@property
def C_coords(self):
coords = getattr(self.gemm_node, "cCoords", None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment