"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "7a7fe160866b7b2893be698d77b70cc8cf754fb5"
Unverified Commit b10ef75f authored by Chaofan Lin's avatar Chaofan Lin Committed by GitHub
Browse files

[Analysis] Enhance NestedLoopChecker with tile op cases (#1358)

* [Analysis] Enhance NestedLoopChecker with tile op cases

* fix tileop issue
parent 1b42c87b
...@@ -539,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -539,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop; return vectorized_thread_loop;
} }
TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -2037,7 +2037,7 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const { ...@@ -2037,7 +2037,7 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, // - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
// eviction_policy // eviction_policy
// - Marked as opaque since it has side effects (memory writes) // - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy) TIR_REGISTER_TL_TILE_OP(Copy, copy)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
...@@ -2062,7 +2062,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -2062,7 +2062,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, // - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
// dilation, padding, eviction_policy // dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes) // - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(9) .set_num_inputs(9)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -209,7 +209,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, ...@@ -209,7 +209,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
TIR_REGISTER_TL_OP(Fill, fill) TIR_REGISTER_TL_TILE_OP(Fill, fill)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -159,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const { ...@@ -159,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const {
return TileOperator(node); return TileOperator(node);
} }
TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -826,7 +826,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, ...@@ -826,7 +826,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
TIR_REGISTER_TL_OP(Gemm, gemm) TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -318,7 +318,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, ...@@ -318,7 +318,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
TIR_REGISTER_TL_OP(GemmPy, gemm_py) TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -302,7 +302,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, ...@@ -302,7 +302,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
TIR_REGISTER_TL_OP(GemmSP, gemm_sp) TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -77,12 +77,12 @@ TileOperator ParseOperator(Stmt stmt); ...@@ -77,12 +77,12 @@ TileOperator ParseOperator(Stmt stmt);
using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>; using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \ #define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \
const Op &Entry::Get() { \ const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \ static const Op &op = Op::Get("tl.tileop." #OpName); \
return op; \ return op; \
} \ } \
TVM_REGISTER_OP("tl." #OpName) \ TVM_REGISTER_OP("tl.tileop." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \ .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \ .set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); }) "TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); })
......
...@@ -478,7 +478,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -478,7 +478,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
TIR_REGISTER_TL_OP(ReduceOp, reduce) TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
...@@ -563,7 +563,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -563,7 +563,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
TIR_REGISTER_TL_OP(CumSumOp, cumsum) TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -76,17 +76,7 @@ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -76,17 +76,7 @@ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
const Op &RegionOp::Get() { TIR_REGISTER_TL_TILE_OP(RegionOp, region)
static const Op &op = Op::Get("tl.region");
return op;
}
TVM_REGISTER_OP("tl.region")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "region")
.set_attr<OpBuilderFunc>("TLOpBuilder",
[](Array<PrimExpr> args) {
return RegionOp(args);
})
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure)); Integer(CallEffectKind::kPure));
......
...@@ -550,5 +550,178 @@ def test_mixed_pp(): ...@@ -550,5 +550,178 @@ def test_mixed_pp():
run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1]) run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1])
"""
TiledOp in a T.Parallel is also not permitted.
"""
def matmul_with_parallel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    """Build a pipelined matmul kernel that (deliberately) nests a tile-op
    inside a T.Parallel loop.

    The shared-memory loads are written as explicit element-wise T.Parallel
    loops, and the T.gemm tile-op is wrapped in ``for _ in T.Parallel(1)``.
    The companion test (run_gemm_tiled_op_with_parallel) compiles the result
    and asserts that compilation raises ValueError, i.e. the semantic checker
    rejects a tile-op inside a parallel loop.

    Returns:
        The T.prim_func implementing the (intentionally invalid) kernel.
    """
    A_shape = (M, K)
    B_shape = (K, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                # Element-wise shared-memory loads, written as plain
                # T.Parallel element loops instead of T.copy.
                for i, j in T.Parallel(block_M, block_K):
                    A_shared[i, j] = A[by * block_M + i, k * block_K + j]
                for i, j in T.Parallel(block_K, block_N):
                    B_shared[i, j] = B[k * block_K + i, bx * block_N + j]
                # T.copy(A[by * block_M, k * block_K], A_shared)
                # T.copy(B[k * block_K, bx * block_N], B_shared)
                # Tile-op inside T.Parallel — this is the construct the
                # NestedLoopChecker is expected to reject.
                for _ in T.Parallel(1):
                    T.gemm(A_shared, B_shared, C_local, False, False)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
def run_gemm_tiled_op_with_parallel(
    order,
    stage,
):
    """Run the tile-op-in-parallel compile check for a 1024x1024x1024 GEMM.

    First compiles and numerically validates a legal pipelined matmul, then
    builds the invalid variant (matmul_with_parallel, which wraps T.gemm in
    T.Parallel) and asserts that compiling it raises ValueError.

    Args:
        order: pipeline `order` annotation passed through to the kernels.
        stage: pipeline `stage` annotation passed through to the kernels.
    """
    M = 1024
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    in_dtype = "float16"
    out_dtype = "float16"
    dtypeAccum = "float32"
    num_threads = 128
    # NOTE(review): this builds matmul_nested_pipa (defined elsewhere in this
    # file), not matmul_with_parallel — presumably as a sanity check that a
    # legal pipelined matmul still compiles and matches the reference before
    # exercising the failing case below; confirm this is intentional.
    program = matmul_nested_pipa(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_threads,
        order,
        stage,
    )
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        # Reference matmul computed in float32 via torch.
        import torch
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically; the -0x1000 bias on the raw int32 bits
            # emulates the tf32 mantissa truncation — TODO confirm the exact
            # rounding behaviour.
            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
    # Now the invalid program: T.gemm nested inside T.Parallel must be
    # rejected by the semantic checker at compile time.
    program1 = matmul_with_parallel(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_threads,
        order,
        stage,
    )
    with pytest.raises(ValueError):
        tilelang.compile(
            program1,
            out_idx=[2],
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            })
@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"):
    """ReLU kernel using nested T.Parallel loops with only a plain TIR op.

    T.max is an element-wise TIR intrinsic, not a tile-op, so this kernel is
    expected to compile and run; the companion test checks the result against
    torch.relu. Output buffer B is returned via out_idx=[1].
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    # Element-wise max with 0.0 == ReLU.
                    B[i * block + j] = T.max(A[i * block + j], 0.0)

    return main
@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"):
    """Copy-then-atomic-add kernel using nested T.Parallel loops.

    Exercises a scalar T.atomic_add on individual buffer elements inside
    parallel loops; the companion test expects this to compile and produce
    B == A + 1.0. Output buffer B is returned via out_idx=[1].
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = A[i * block + j]
                    # Scalar atomic add on a single element; presumably
                    # lowers to a TIR intrinsic rather than a tile-op, so
                    # it is allowed here — the test below confirms.
                    T.atomic_add(B[i * block + j], 1.0)

    return main
def test_tiled_op_with_parallel():
    """End-to-end checks for tile-op vs. TIR-op usage inside T.Parallel."""
    # Tile-op (T.gemm) inside T.Parallel must be rejected at compile time.
    run_gemm_tiled_op_with_parallel(order=[0, 1, 2], stage=[0, 0, 1])
    # Plain TIR op (T.max) inside nested parallel loops still compiles/runs.
    kernel1 = tir_op_with_parallel(length=256, block=16)
    data = _require_cuda_tensor((256,), torch.float32)
    result1 = kernel1(data)
    torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5)
    # Scalar atomic_add inside nested parallel loops also works.
    kernel2 = customize_op_with_parallel(length=256, block=16)
    result2 = kernel2(data)
    torch.testing.assert_close(result2, data + 1, atol=1e-5, rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -14,7 +14,7 @@ def ASTPrinter(): ...@@ -14,7 +14,7 @@ def ASTPrinter():
Pre-order visitor to print all visited statements. Pre-order visitor to print all visited statements.
""" """
print(f"Visiting statement: {type(statement)}") print(f"Visiting statement: {type(statement)}, {statement}")
def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
new_body = ir_transform(func.body, pre_visit, None) new_body = ir_transform(func.body, pre_visit, None)
......
from tvm import tir from tvm import tir
from tvm.tir import ( from tvm.tir import (
For, For,
Call,
PrimFunc, PrimFunc,
PyStmtExprVisitor, PyStmtExprVisitor,
) )
...@@ -17,6 +18,12 @@ def is_pipelined_for(op: For) -> bool: ...@@ -17,6 +18,12 @@ def is_pipelined_for(op: For) -> bool:
return any(key in op.annotations for key in anno_keys) return any(key in op.annotations for key in anno_keys)
def is_tile_op(op: Call) -> bool:
    """Return True when the call targets a registered tile-op.

    Tile-ops are identified by the presence of the "TLOpBuilder" attribute
    on the callee op (registered via the tile-op registration macro).
    """
    builder = op.op.get_attr("TLOpBuilder")
    return builder is not None
@tir.functor.visitor @tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor): class _NestedLoopCheckVisitor(PyStmtExprVisitor):
...@@ -39,7 +46,7 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor): ...@@ -39,7 +46,7 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
"Nested parallel loops are not allowed. " "Nested parallel loops are not allowed. "
"Please check your loop structure.") "Please check your loop structure.")
self.in_parallel_context = True self.in_parallel_context = True
self.visit_stmt(child) super().visit_for_(op)
self.in_parallel_context = False self.in_parallel_context = False
return return
elif is_pipelined_for(op): elif is_pipelined_for(op):
...@@ -48,7 +55,14 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor): ...@@ -48,7 +55,14 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
"Pipelined loop cannot be nested inside a parallel loop. " "Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure.") "Please check your loop structure.")
self.visit_stmt(op.body) super().visit_for_(op)
def visit_call_(self, op: Call) -> None:
    """Reject tile-op calls encountered inside a T.Parallel body.

    Raises:
        ValueError: if the call is a tile-op (per is_tile_op) and the
            visitor is currently inside a parallel loop context.
    """
    if self.in_parallel_context and is_tile_op(op):
        raise ValueError("[Tilelang Semantic Check] "
                         "Only elementwise operations are allowed inside a parallel loop. " \
                         f"Got a tile-op \"{op.op}\"."
                         )
def NestedLoopChecker(): def NestedLoopChecker():
......
...@@ -76,10 +76,8 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: ...@@ -76,10 +76,8 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
# Debug # Debug
# tilelang.analysis.ASTPrinter()(mod) # tilelang.analysis.ASTPrinter()(mod)
# Check if there are any invalid nested loops. # Check if there are any invalid nested loops.
tilelang.analysis.NestedLoopChecker()(mod) tilelang.analysis.NestedLoopChecker()(mod)
# Check if there are any invalid symbolic T.Parallel + fragment access. # Check if there are any invalid symbolic T.Parallel + fragment access.
tilelang.analysis.FragmentLoopChecker()(mod) tilelang.analysis.FragmentLoopChecker()(mod)
......
...@@ -212,9 +212,9 @@ def atomic_add(dst: Buffer, ...@@ -212,9 +212,9 @@ def atomic_add(dst: Buffer,
"return_prev is not supported for tile-region-based atomic operations") "return_prev is not supported for tile-region-based atomic operations")
if memory_order is None: if memory_order is None:
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0) return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0)
else: else:
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma,
_MEMORY_ORDER_ID_MAP[memory_order]) _MEMORY_ORDER_ID_MAP[memory_order])
......
...@@ -90,7 +90,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -90,7 +90,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
eviction_policy = 0 eviction_policy = 0
else: else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width,
disable_tma, eviction_policy) disable_tma, eviction_policy)
...@@ -124,5 +124,5 @@ def c2d_im2col(img: tir.Buffer, ...@@ -124,5 +124,5 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
img_region = to_buffer_region(img, access_type="r") img_region = to_buffer_region(img, access_type="r")
col_region = to_buffer_region(col, access_type="w") col_region = to_buffer_region(col, access_type="w")
return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region, return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region,
nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy)
...@@ -70,7 +70,7 @@ def gemm_sp( ...@@ -70,7 +70,7 @@ def gemm_sp(
C_arg = to_buffer_region(C, access_type="rw") C_arg = to_buffer_region(C, access_type="rw")
return tir.call_intrin( return tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.gemm_sp"), tir.op.Op.get("tl.tileop.gemm_sp"),
A_arg, A_arg,
E_arg, E_arg,
B_arg, B_arg,
......
...@@ -32,7 +32,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim ...@@ -32,7 +32,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents = [tir.IntImm("int32", 1) for _ in buffer.indices] extents = [tir.IntImm("int32", 1) for _ in buffer.indices]
else: else:
extents = [] extents = []
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"),
to_buffer_region(buffer, access_type="w", extents=extents), value) to_buffer_region(buffer, access_type="w", extents=extents), value)
......
...@@ -116,7 +116,7 @@ def gemm_v1( ...@@ -116,7 +116,7 @@ def gemm_v1(
): ):
"""GEMM v1: use op tl.gemm.""" """GEMM v1: use op tl.gemm."""
return _gemm_impl( return _gemm_impl(
"tl.gemm", "tl.tileop.gemm",
A, A,
B, B,
C, C,
...@@ -145,7 +145,7 @@ def gemm_v2( ...@@ -145,7 +145,7 @@ def gemm_v2(
): ):
"""GEMM v2: use op tl.gemm_py.""" """GEMM v2: use op tl.gemm_py."""
return _gemm_impl( return _gemm_impl(
"tl.gemm_py", "tl.tileop.gemm_py",
A, A,
B, B,
C, C,
......
...@@ -13,6 +13,9 @@ def _legalize_dim(buffer: tir.Buffer, dim: int): ...@@ -13,6 +13,9 @@ def _legalize_dim(buffer: tir.Buffer, dim: int):
return dim return dim
_REDUCE_OP_KEY = "tl.tileop.reduce"
def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
"""Perform a reduction operation on a buffer along a specified dimension. """Perform a reduction operation on a buffer along a specified dimension.
...@@ -50,7 +53,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -50,7 +53,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
copy(buffer, red_frag_in) copy(buffer, red_frag_in)
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(red_frag_in, access_type="r"),
to_buffer_region(red_frag_out, access_type="w"), to_buffer_region(red_frag_out, access_type="w"),
reduce_type, reduce_type,
...@@ -65,7 +68,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -65,7 +68,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
copy(buffer, red_frag_in) copy(buffer, red_frag_in)
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(red_frag_in, access_type="r"),
to_buffer_region(out, access_type="w"), to_buffer_region(out, access_type="w"),
reduce_type, reduce_type,
...@@ -78,7 +81,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -78,7 +81,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(buffer, access_type="r"), to_buffer_region(buffer, access_type="r"),
to_buffer_region(red_frag_out, access_type="w"), to_buffer_region(red_frag_out, access_type="w"),
reduce_type, reduce_type,
...@@ -89,7 +92,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -89,7 +92,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
elif is_fragment(buffer) and is_fragment(out): elif is_fragment(buffer) and is_fragment(out):
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.reduce"), tir.op.Op.get(_REDUCE_OP_KEY),
to_buffer_region(buffer, access_type="r"), to_buffer_region(buffer, access_type="r"),
to_buffer_region(out, access_type="w"), to_buffer_region(out, access_type="w"),
reduce_type, reduce_type,
...@@ -245,7 +248,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - ...@@ -245,7 +248,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
copy(src, cumsum_smem) copy(src, cumsum_smem)
tir.call_intrin( tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.cumsum"), tir.op.Op.get("tl.tileop.cumsum"),
to_buffer_region(cumsum_smem, access_type="r"), to_buffer_region(cumsum_smem, access_type="r"),
to_buffer_region(cumsum_smem, access_type="w"), to_buffer_region(cumsum_smem, access_type="w"),
dim, dim,
...@@ -299,7 +302,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse ...@@ -299,7 +302,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
return cumsum_fragment(src, dst, dim, reverse) return cumsum_fragment(src, dst, dim, reverse)
return tir.call_intrin( return tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.cumsum"), tir.op.Op.get("tl.tileop.cumsum"),
to_buffer_region(src, access_type="r"), to_buffer_region(src, access_type="r"),
to_buffer_region(dst, access_type="w"), to_buffer_region(dst, access_type="w"),
dim, dim,
...@@ -309,7 +312,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse ...@@ -309,7 +312,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse
def finalize_reducer(reducer: tir.Buffer): def finalize_reducer(reducer: tir.Buffer):
""" """
Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic. Finalize a reducer buffer by emitting the `tl.tileop.finalize_reducer` intrinsic.
This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer. This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer.
The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR. The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.
...@@ -322,7 +325,7 @@ def finalize_reducer(reducer: tir.Buffer): ...@@ -322,7 +325,7 @@ def finalize_reducer(reducer: tir.Buffer):
""" """
return tir.call_intrin( return tir.call_intrin(
"handle", "handle",
tir.op.Op.get("tl.finalize_reducer"), tir.op.Op.get("tl.tileop.finalize_reducer"),
to_buffer_region(reducer, access_type="w"), to_buffer_region(reducer, access_type="w"),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment