Unverified Commit 3cfefc8e authored by Lei Wang, committed by GitHub

[Refactor] Support python reflection for tile operators (#783)

* Implement Fill operator and related reflection methods in TileLang

- Added Fill operator implementation in `fill.cc` and `fill.h` for element-wise filling of buffers.
- Introduced reflection methods for Fill, AtomicAdd, Copy, Conv2DIm2Col, FinalizeReducer, Gemm, and Parallel operators to enhance introspection capabilities.
- Updated relevant files to register reflection methods and ensure proper initialization in static blocks.
- Removed outdated comments and unnecessary code in various operator files to improve clarity and maintainability.
- Added new Python bindings for the Fill operator in `tilelang/ir/fill.py` and updated the module imports accordingly.
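The new bindings follow the standard TVM pattern of registering a Python proxy class under the C++ type key. A minimal sketch of that pattern, using a stand-in registry in place of `tvm.ffi.register_object` so it runs without TVM installed:

```python
# Sketch of the registration pattern used by the new Python bindings.
# `register_object` here is a stand-in for tvm.ffi.register_object: it
# maps a type key such as "tl.Fill" to its Python proxy class so that
# objects coming back from C++ can be wrapped in the right class.
OBJECT_REGISTRY = {}

def register_object(type_key):
    def decorator(cls):
        OBJECT_REGISTRY[type_key] = cls
        return cls
    return decorator

@register_object("tl.Fill")
class Fill:
    """Python proxy for the C++ Fill tile operator."""
    ...

@register_object("tl.AtomicAdd")
class AtomicAdd:
    """Python proxy for the C++ AtomicAdd tile operator."""
    ...
```

With the real `tvm.ffi.register_object`, the type key must match the `_type_key` declared by the corresponding C++ node.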

* Refactor operator reflection methods and improve code clarity

- Updated reflection methods for AtomicAdd, Copy, FinalizeReducer, Gemm, and Parallel operators to enhance readability by using `empty()` instead of size checks.
- Consolidated static initialization blocks for various operators to a single line for improved consistency.
- Cleaned up whitespace and formatting in multiple files to adhere to coding standards and improve maintainability.
- Added new Python bindings for operators in the `tilelang/ir` module, ensuring proper registration and organization of imports.

* Refactor GEMM and AtomicAdd operations for improved clarity

- Updated the `GetArchInt` function in `atomic_add.cc` to use `std::string` and `std::stoi` for better readability and type safety.
- Removed unnecessary variables and comments in `gemm_sp.cc` and `gemm.cc` to streamline the `ComputeWarpPartition` method.
- Cleaned up the `layout_reducer.cc` file by removing unused variable declarations, enhancing code clarity.
- Added import for the `ir` module in `tilelang/__init__.py` to ensure proper organization of module imports.
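For illustration, a helper like `GetArchInt` presumably extracts the numeric compute capability from an arch string. A hypothetical Python re-implementation (the real helper is C++ using `std::string` and `std::stoi`; the `"sm_NN"` input format is an assumption):

```python
# Illustrative stand-in for a GetArchInt-style helper: pull the numeric
# compute capability out of an arch string such as "sm_80".  The input
# format is an assumption for this sketch, not taken from the commit.
def get_arch_int(arch: str) -> int:
    prefix = "sm_"
    if not arch.startswith(prefix):
        raise ValueError(f"unexpected arch string: {arch}")
    # std::stoi-style conversion of the trailing digits
    return int(arch[len(prefix):])
```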

* Remove deprecated operator files from the tilelang IR module

- Deleted files for Fill, AtomicAdd, Copy, Gemm, GemmSP, FinalizeReducer, Parallel, Reduce, and Region operators to streamline the codebase.
- This cleanup enhances maintainability by removing unused code and improving overall organization of the module.

* Refactor imports in tilelang IR module for improved organization

- Updated import statements in `tilelang/ir.py` to reflect changes in the TVM library structure, enhancing clarity and maintainability of the codebase.

* lint fix

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability

- Updated the `Gemm` and `GemmSP` classes to utilize a new `GemmWarpPolicy` object for warp partitioning, improving encapsulation and readability.
- Removed deprecated `ComputeWarpPartition` methods and replaced them with calls to the new policy object, streamlining the code.
- Cleaned up comments and unnecessary code in `gemm.cc`, `gemm_sp.cc`, and related header files to enhance overall clarity.
- Introduced a new `GemmWarpPolicyNode` class to manage warp policy attributes and methods, facilitating better organization of related functionalities.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
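A warp-partition policy decides how a fixed warp budget is split across the M and N tile dimensions of a GEMM. The following is a hypothetical sketch of that computation; the kind names (`FullRow`, `FullCol`, `Square`) and the heuristics are assumptions for illustration, not the actual `GemmWarpPolicyNode` logic:

```python
# Hypothetical sketch of what a GemmWarpPolicy computes: split
# `num_warps` warps across the M and N dimensions of a GEMM tile.
def compute_warp_partition(kind: str, num_warps: int) -> tuple:
    if kind == "FullRow":   # all warps along M
        return num_warps, 1
    if kind == "FullCol":   # all warps along N
        return 1, num_warps
    if kind == "Square":    # pick the divisor pair closest to square
        best = (1, num_warps)
        for m in range(1, num_warps + 1):
            if num_warps % m == 0:
                n = num_warps // m
                if abs(m - n) < abs(best[0] - best[1]):
                    best = (m, n)
        return best
    raise ValueError(f"unknown warp policy kind: {kind}")
```

Encapsulating this in a policy object, as the commit does, keeps the heuristic out of `gemm.cc` and `gemm_sp.cc` and lets both operators share one implementation.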

* Refactor Reduce operation to utilize ReduceType class for improved clarity and maintainability

- Replaced multiple conditional checks for reduce types with a single ReduceType object, simplifying the code structure.
- Introduced a new ReduceTypeNode class to encapsulate reduce type logic and methods, enhancing organization.
- Updated MakeInitValue, MakeReduce, and Lower methods to leverage the new ReduceType class, improving readability.
- Added Python bindings for the ReduceType class in tilelang IR module to ensure proper registration and usability.
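The shape of this refactor can be sketched as a single class that owns both the init value and the combine step, replacing scattered string comparisons. Method names and the type list here are illustrative; the real `ReduceTypeNode` lives in C++ and is exposed via reflection:

```python
# Sketch of encapsulating reduce-type logic in one class.  The type
# names mirror the ones used in the commit ("sum", "abssum", "absmax",
# "max", "min"); the method names are assumptions for illustration.
class ReduceType:
    _INIT = {"sum": 0.0, "abssum": 0.0, "absmax": 0.0,
             "max": float("-inf"), "min": float("inf")}

    def __init__(self, kind: str):
        if kind not in self._INIT:
            raise ValueError(f"unknown reduce type: {kind}")
        self.kind = kind

    def make_init_value(self) -> float:
        return self._INIT[self.kind]

    def make_reduce(self, lhs: float, b: float) -> float:
        rhs = b  # mirrors the rhs -> b parameter rename in the commit
        if self.kind == "sum":
            return lhs + rhs
        if self.kind == "abssum":
            return lhs + abs(rhs)
        if self.kind == "max":
            return max(lhs, rhs)
        if self.kind == "absmax":
            return max(lhs, abs(rhs))
        return min(lhs, rhs)  # "min"
```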

* comment

* Refactor operator header files for improved readability

- Cleaned up formatting and whitespace in `atomic_add.h`, `copy.h`, `fill.h`, `reduce.cc`, and `reduce.h` to enhance code clarity.
- Consolidated comments and adjusted line breaks for better organization and maintainability across multiple operator definitions.

* Refactor MakeReduce method in ReduceOpNode for clarity

- Renamed the MakeReduce parameter from `rhs` to `b` and assigned it to a local `rhs`, improving readability.
- This change enhances the clarity of the method's purpose and aligns with the overall refactoring efforts in the Reduce operation.

* Update Reduce operation type checks for consistency

- Changed string comparisons for reduce types in the MakeReduce method from "abs_sum" to "abssum" and "abs_max" to "absmax" for uniformity.
- This adjustment enhances the clarity and consistency of the reduce type handling in the codebase.
parent 141e01fb
@@ -104,3 +104,5 @@ from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa
 from .version import __version__  # noqa: F401
 from .math import *  # noqa: F403
+from . import ir  # noqa: F401
@@ -94,6 +94,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # Infer memory layouts for fragments and shared memory
     mod = tilelang.transform.LayoutInference()(mod)
     # Lower high-level tile operations to low-level operations
+    print("LowerTileOp")
+    print(mod.script())
     mod = tilelang.transform.LowerTileOp()(mod)
     # Lower l2 persistent map
     mod = tilelang.transform.LowerL2Persistent()(mod)
from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi


@tvm.ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
    ...


@tvm.ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
    ...
@@ -153,11 +153,16 @@ class TensorProxy(BaseTensorProxy):
     def __call__(self,
                  shape: Union[Tuple[Any], PrimExpr, int],
                  dtype: str = "float32",
-                 data=None) -> tir.Buffer:
+                 data=None,
+                 scope=None) -> tir.Buffer:
         if isinstance(shape, (int, PrimExpr)):
             shape = (shape,)
-        return super().__call__(
-            shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data)
+        return super().__call__(
+            shape,
+            dtype=dtype,
+            strides=TensorProxy._construct_strides(shape),
+            data=data,
+            scope=scope)


 class StridedTensorProxy(BaseTensorProxy):
@@ -169,13 +174,14 @@ class StridedTensorProxy(BaseTensorProxy):
     def __call__(self,
                  shape: Tuple[Any],
                  strides: Tuple[Any],
-                 dtype: str = "float32") -> tir.Buffer:
+                 dtype: str = "float32",
+                 scope=None) -> tir.Buffer:
         if len(shape) != len(strides):
             raise ValueError("Invalid shape/strides' dimensions")
         if not bool(strides[-1] == 1):
             # TODO(chenggang): shall we support non-contiguous even for the last dimension?
             raise ValueError("The stride of the last dimension must be 1 (contiguous)")
-        return super().__call__(shape, dtype=dtype, strides=strides)
+        return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)


 class FragmentBufferProxy(BaseTensorProxy):
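The proxy change above threads an optional `scope` argument from `__call__` down to the buffer constructor. A minimal runnable sketch of that pattern, with stand-ins for `BaseTensorProxy` and `tir.Buffer` so it runs without TVM:

```python
# Minimal sketch of threading an optional `scope` argument through a
# tensor proxy, as the diff does.  FakeBuffer and BaseProxy are
# stand-ins for tir.Buffer and BaseTensorProxy; the "global" default
# is an assumption for this sketch.
class FakeBuffer:
    def __init__(self, shape, dtype, strides, scope):
        self.shape, self.dtype = shape, dtype
        self.strides, self.scope = strides, scope

class BaseProxy:
    def __call__(self, shape, dtype="float32", strides=None, scope=None):
        # default the scope when the caller does not specify one
        return FakeBuffer(shape, dtype, strides, scope or "global")

class TensorProxy(BaseProxy):
    @staticmethod
    def _construct_strides(shape):
        # row-major strides: the last dimension is contiguous
        strides, acc = [], 1
        for dim in reversed(shape):
            strides.append(acc)
            acc *= dim
        return tuple(reversed(strides))

    def __call__(self, shape, dtype="float32", data=None, scope=None):
        if isinstance(shape, int):
            shape = (shape,)
        return super().__call__(
            shape, dtype=dtype,
            strides=TensorProxy._construct_strides(shape),
            scope=scope)
```

Because `scope` defaults to `None` at every level, existing call sites that omit it keep their previous behavior.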