Unverified Commit 3cfefc8e authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Support python reflection for tile operators (#783)

* Implement Fill operator and related reflection methods in TileLang

- Added Fill operator implementation in `fill.cc` and `fill.h` for element-wise filling of buffers.
- Introduced reflection methods for Fill, AtomicAdd, Copy, Conv2DIm2Col, FinalizeReducer, Gemm, and Parallel operators to enhance introspection capabilities.
- Updated relevant files to register reflection methods and ensure proper initialization in static blocks.
- Removed outdated comments and unnecessary code in various operator files to improve clarity and maintainability.
- Added new Python bindings for the Fill operator in `tilelang/ir/fill.py` and updated the module imports accordingly.

* Refactor operator reflection methods and improve code clarity

- Updated reflection methods for AtomicAdd, Copy, FinalizeReducer, Gemm, and Parallel operators to enhance readability by using `empty()` instead of size checks.
- Consolidated static initialization blocks for various operators.
parent 141e01fb
......@@ -104,3 +104,5 @@ from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa
from .version import __version__ # noqa: F401
from .math import * # noqa: F403
from . import ir # noqa: F401
......@@ -94,6 +94,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Infer memory layouts for fragments and shared memory
mod = tilelang.transform.LayoutInference()(mod)
# Lower high-level tile operations to low-level operations.
# NOTE(review): removed leftover debug `print("LowerTileOp")` /
# `print(mod.script())` calls — they dumped the whole module to stdout on
# every compilation; use the logging module if tracing is ever needed here.
mod = tilelang.transform.LowerTileOp()(mod)
# Lower l2 persistent map
mod = tilelang.transform.LowerL2Persistent()(mod)
......
from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi
@tvm.ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
    """Python proxy for the C++ ``tl.Fill`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
    """Python proxy for the C++ ``tl.AtomicAdd`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
    """Python proxy for the C++ ``tl.Copy`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
    """Python proxy for the C++ ``tl.Conv2DIm2Col`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
    """Python proxy for the C++ ``tl.GemmWarpPolicy`` node (TVM FFI)."""
@tvm.ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
    """Python proxy for the C++ ``tl.Gemm`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
    """Python proxy for the C++ ``tl.GemmSP`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
    """Python proxy for the C++ ``tl.FinalizeReducerOp`` node (TVM FFI)."""
@tvm.ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
    """Python proxy for the C++ ``tl.ParallelOp`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
    """Python proxy for the C++ ``tl.ReduceOp`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
    """Python proxy for the C++ ``tl.CumSumOp`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
    """Python proxy for the C++ ``tl.RegionOp`` tile-operator node (TVM FFI)."""
@tvm.ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
    """Python proxy for the C++ ``tl.ReduceType`` node (TVM FFI)."""
......@@ -153,11 +153,16 @@ class TensorProxy(BaseTensorProxy):
def __call__(self,
             shape: Union[Tuple[Any], PrimExpr, int],
             dtype: str = "float32",
             data=None,
             scope=None) -> tir.Buffer:
    """Create a ``tir.Buffer`` with the given shape and dtype.

    A scalar ``shape`` (``int`` or ``PrimExpr``) is promoted to a 1-D
    tuple. Strides are always derived via ``_construct_strides`` so the
    buffer is contiguous. ``scope`` (default ``None``, backward
    compatible) is forwarded to the base proxy so callers can select a
    memory scope.

    NOTE(review): the scraped diff interleaved the pre- and post-change
    lines of this method; this body is the reconstructed post-change
    version (old signature/call lines removed).
    """
    if isinstance(shape, (int, PrimExpr)):
        shape = (shape,)
    return super().__call__(
        shape,
        dtype=dtype,
        strides=TensorProxy._construct_strides(shape),
        data=data,
        scope=scope)
class StridedTensorProxy(BaseTensorProxy):
......@@ -169,13 +174,14 @@ class StridedTensorProxy(BaseTensorProxy):
def __call__(self,
             shape: Tuple[Any],
             strides: Tuple[Any],
             dtype: str = "float32",
             scope=None) -> tir.Buffer:
    """Create a ``tir.Buffer`` with explicit strides.

    Raises ``ValueError`` when ``shape`` and ``strides`` have different
    ranks, or when the innermost stride is not 1 (the last dimension
    must be contiguous). ``scope`` (default ``None``, backward
    compatible) is forwarded to the base proxy.

    NOTE(review): the scraped diff interleaved the pre- and post-change
    lines of this method; this body is the reconstructed post-change
    version (old signature/return lines removed).
    """
    if len(shape) != len(strides):
        raise ValueError("Invalid shape/strides' dimensions")
    if not bool(strides[-1] == 1):
        # TODO(chenggang): shall we support non-contiguous even for the last dimension?
        raise ValueError("The stride of the last dimension must be 1 (contiguous)")
    return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)
class FragmentBufferProxy(BaseTensorProxy):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment