Commit 41d4988b authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Enhance ReduceOp and JITKernel for improved dimension handling...

[Enhancement] Enhance ReduceOp and JITKernel for improved dimension handling and initialization (#507)

* [Refactor] Update reduce functions to support default dimension values and improve dimension handling

* Added a helper function `_legalize_dim` to handle negative dimension values in reduction operations.
* Updated `reduce_max`, `reduce_min`, `reduce_sum`, `reduce_abssum`, and `reduce_absmax` functions to accept a default dimension value of -1, enhancing usability and flexibility in buffer reduction operations.
* Ensured consistent dimension handling across all reduction functions for improved clarity and correctness.

* Update submodule `tvm` to latest commit c2921fd, ensuring compatibility with recent changes.

* [Refactor] Enhance ReduceOp and JITKernel for improved dimension handling and initialization

* Updated ReduceOp to handle 1D reduction cases and ensure correct dimension checks, improving robustness in reduction operations.
* Initialized prim_func in JITKernel to enhance clarity and prevent potential null reference issues.
* Added whitespace for better code readability in reduce.py.
parent 84ddb9e1
Subproject commit b16c9f298bc37fa502ffdb2ea809c2793e2a0bd6
Subproject commit c2921fdaf795b1103d21abc962e83a209c7258d7
......@@ -105,14 +105,28 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto dst_buffer = T.buffer_remap[this->dst];
Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
ICHECK(src_layout->InputDim() == dst_layout->InputDim() + 1);
size_t src_dim = src_layout->InputDim();
size_t dst_dim = dst_layout->InputDim();
bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
if (is_1d_reduce) {
ICHECK(is_one(dst_layout->OutputShape().back()))
<< "Reduce for scalar not implemented.";
} else {
ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
}
Array<IterVar> dst_vars;
for (size_t i = 0; i < dst_layout->InputDim(); i++) {
for (size_t i = 0; i < dst_dim; i++) {
Var var = Var(std::string{char('i' + i)});
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
IterVarType::kDataPar));
}
Array<IterVar> src_vars = dst_vars;
Array<IterVar> src_vars;
if (!is_1d_reduce) {
src_vars = dst_vars;
}
src_vars.insert(src_vars.begin() + this->dim,
{Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
IterVarType::kDataPar});
......
......@@ -28,7 +28,7 @@ class JITKernel(object):
torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function.
"""
prim_func: PrimFunc = None
artifact: CompiledArtifact = None
adapter: BaseKernelAdapter = None
torch_function: Callable = None
......@@ -71,6 +71,7 @@ class JITKernel(object):
from_database : bool, optional
Whether to create a TorchFunction from a database.
"""
self.prim_func = func
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
......
......@@ -5,6 +5,12 @@ from typing import Optional
from tilelang.language import copy, macro, alloc_shared
def _legalize_dim(buffer: tir.Buffer, dim: int):
if dim < 0:
dim = len(buffer.shape) + dim
return dim
def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
"""Perform a reduction operation on a buffer along a specified dimension.
......@@ -31,7 +37,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
)
def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce max on input buffer, store the result to output buffer
Parameters
......@@ -48,10 +54,11 @@ def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
-------
handle : PrimExpr
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "max", dim, clear)
def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce min on input buffer, store the result to output buffer.
Args:
......@@ -63,10 +70,11 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "min", dim, clear)
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce sum on input buffer, store the result to output buffer.
Args:
......@@ -87,10 +95,11 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "sum", dim, clear)
def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1):
"""Perform reduce absolute sum on input buffer, store the result to output buffer.
Args:
......@@ -101,10 +110,11 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "abssum", dim, True)
def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce absolute max on input buffer, store the result to output buffer.
Args:
......@@ -115,6 +125,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = T
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "absmax", dim, clear)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment