Commit 41d4988b authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Enhance ReduceOp and JITKernel for improved dimension handling...

[Enhancement] Enhance ReduceOp and JITKernel for improved dimension handling and initialization (#507)

* [Refactor] Update reduce functions to support default dimension values and improve dimension handling

* Added a helper function `_legalize_dim` to handle negative dimension values in reduction operations.
* Updated `reduce_max`, `reduce_min`, `reduce_sum`, `reduce_abssum`, and `reduce_absmax` functions to accept a default dimension value of -1, enhancing usability and flexibility in buffer reduction operations.
* Ensured consistent dimension handling across all reduction functions for improved clarity and correctness.

* Update submodule `tvm` to latest commit c2921fd, ensuring compatibility with recent changes.

* [Refactor] Enhance ReduceOp and JITKernel for improved dimension handling and initialization

* Updated ReduceOp to handle 1D reduction cases and ensure correct dimension checks, improving robustness in reduction operations.
* Initialized prim_func in JITKernel to enhance clarity and prevent potential null reference issues.
* Added whitespace for better code readability in reduce.py.
parent 84ddb9e1
Subproject commit b16c9f298bc37fa502ffdb2ea809c2793e2a0bd6 Subproject commit c2921fdaf795b1103d21abc962e83a209c7258d7
...@@ -105,14 +105,28 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -105,14 +105,28 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto dst_buffer = T.buffer_remap[this->dst]; auto dst_buffer = T.buffer_remap[this->dst];
Fragment src_layout = T.layout_map[this->src].as<Fragment>().value(); Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value(); Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
ICHECK(src_layout->InputDim() == dst_layout->InputDim() + 1); size_t src_dim = src_layout->InputDim();
size_t dst_dim = dst_layout->InputDim();
bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
if (is_1d_reduce) {
ICHECK(is_one(dst_layout->OutputShape().back()))
<< "Reduce for scalar not implemented.";
} else {
ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
}
Array<IterVar> dst_vars; Array<IterVar> dst_vars;
for (size_t i = 0; i < dst_layout->InputDim(); i++) { for (size_t i = 0; i < dst_dim; i++) {
Var var = Var(std::string{char('i' + i)}); Var var = Var(std::string{char('i' + i)});
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var, dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
IterVarType::kDataPar)); IterVarType::kDataPar));
} }
Array<IterVar> src_vars = dst_vars; Array<IterVar> src_vars;
if (!is_1d_reduce) {
src_vars = dst_vars;
}
src_vars.insert(src_vars.begin() + this->dim, src_vars.insert(src_vars.begin() + this->dim,
{Range(0, src_layout->InputShape()[this->dim]), Var("rv"), {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
IterVarType::kDataPar}); IterVarType::kDataPar});
......
...@@ -28,7 +28,7 @@ class JITKernel(object): ...@@ -28,7 +28,7 @@ class JITKernel(object):
torch_function : Callable torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function. The compiled function that can be invoked as a PyTorch-compatible function.
""" """
prim_func: PrimFunc = None
artifact: CompiledArtifact = None artifact: CompiledArtifact = None
adapter: BaseKernelAdapter = None adapter: BaseKernelAdapter = None
torch_function: Callable = None torch_function: Callable = None
...@@ -71,6 +71,7 @@ class JITKernel(object): ...@@ -71,6 +71,7 @@ class JITKernel(object):
from_database : bool, optional from_database : bool, optional
Whether to create a TorchFunction from a database. Whether to create a TorchFunction from a database.
""" """
self.prim_func = func
self.execution_backend = execution_backend self.execution_backend = execution_backend
self.target = target self.target = target
self.target_host = target_host self.target_host = target_host
......
...@@ -5,6 +5,12 @@ from typing import Optional ...@@ -5,6 +5,12 @@ from typing import Optional
from tilelang.language import copy, macro, alloc_shared from tilelang.language import copy, macro, alloc_shared
def _legalize_dim(buffer: tir.Buffer, dim: int):
    """Normalize a reduction axis index to its non-negative form.

    A negative ``dim`` counts from the end of ``buffer``'s shape
    (Python-style indexing), so ``-1`` selects the last axis.

    Parameters
    ----------
    buffer : tir.Buffer
        Buffer whose rank (``len(buffer.shape)``) defines the valid axes.
    dim : int
        Axis index; may be negative.

    Returns
    -------
    int
        The equivalent non-negative axis index.
    """
    if dim >= 0:
        return dim
    # Wrap once from the end: -1 -> rank - 1, -2 -> rank - 2, ...
    return len(buffer.shape) + dim
def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
"""Perform a reduction operation on a buffer along a specified dimension. """Perform a reduction operation on a buffer along a specified dimension.
...@@ -31,7 +37,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -31,7 +37,7 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
) )
def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True): def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce max on input buffer, store the result to output buffer """Perform reduce max on input buffer, store the result to output buffer
Parameters Parameters
...@@ -48,10 +54,11 @@ def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True ...@@ -48,10 +54,11 @@ def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
------- -------
handle : PrimExpr handle : PrimExpr
""" """
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "max", dim, clear) return reduce(buffer, out, "max", dim, clear)
def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True): def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce min on input buffer, store the result to output buffer. """Perform reduce min on input buffer, store the result to output buffer.
Args: Args:
...@@ -63,10 +70,11 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True ...@@ -63,10 +70,11 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
Returns: Returns:
tir.Call: Handle to the reduction operation tir.Call: Handle to the reduction operation
""" """
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "min", dim, clear) return reduce(buffer, out, "min", dim, clear)
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True): def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce sum on input buffer, store the result to output buffer. """Perform reduce sum on input buffer, store the result to output buffer.
Args: Args:
...@@ -87,10 +95,11 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True ...@@ -87,10 +95,11 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
Returns: Returns:
tir.Call: Handle to the reduction operation tir.Call: Handle to the reduction operation
""" """
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "sum", dim, clear) return reduce(buffer, out, "sum", dim, clear)
def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int): def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1):
"""Perform reduce absolute sum on input buffer, store the result to output buffer. """Perform reduce absolute sum on input buffer, store the result to output buffer.
Args: Args:
...@@ -101,10 +110,11 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int): ...@@ -101,10 +110,11 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
Returns: Returns:
tir.Call: Handle to the reduction operation tir.Call: Handle to the reduction operation
""" """
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "abssum", dim, True) return reduce(buffer, out, "abssum", dim, True)
def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True): def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce absolute max on input buffer, store the result to output buffer. """Perform reduce absolute max on input buffer, store the result to output buffer.
Args: Args:
...@@ -115,6 +125,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = T ...@@ -115,6 +125,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = T
Returns: Returns:
tir.Call: Handle to the reduction operation tir.Call: Handle to the reduction operation
""" """
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "absmax", dim, clear) return reduce(buffer, out, "absmax", dim, clear)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment