Commit be0bf36d authored by Lei Wang, committed by LeiWang1999

[Language] Proxy tvm ir to make linter happy (#287)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to use parallel processing for better performance (a sketch follows this list).
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.
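
For orientation, a minimal sketch of the parallelized write-back described above, assuming TileLang's `T.Parallel` primitive; the buffer names and shapes are illustrative, not the exact ones from the flash-attention example:

```python
import tilelang.language as T

block_M = 64  # illustrative tile size

@T.prim_func
def rescale_lse(lse_local_split: T.Tensor((block_M,), "float32"),
                lse_logsum: T.Tensor((block_M,), "float32")):
    with T.Kernel(1, threads=128) as bx:
        # Each element is handled by a parallel lane instead of a serial loop,
        # letting layout inference spread the work across the thread block.
        for k in T.Parallel(block_M):
            lse_local_split[k] = lse_local_split[k] - lse_logsum[k]
```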

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared-memory scenarios (see the sketch after this list).
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
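
To illustrate what the new scope handling permits, here is a minimal sketch (shapes and names illustrative) assuming TileLang's `T.alloc_shared`, `T.fill`, and `T.copy` primitives:

```python
import tilelang.language as T

@T.prim_func
def fill_then_copy(A: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, threads=128) as bx:
        # alloc_shared allocates in the "shared.dyn" scope, which Fill can now target;
        # the operator lowers to a parallel, vectorized loop over the shared buffer.
        A_shared = T.alloc_shared((128, 128), "float16")
        T.fill(A_shared, 0)
        T.copy(A_shared, A)
```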

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the deprecated decorator from the main module and added a new implementation in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added `make_tensor` function in `proxy.py` to streamline tensor creation from pointers.
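
A condensed sketch of what the updated test exercises, assuming the `T.make_tensor(ptr, shape, dtype)` signature implied above and TileLang's usual tile primitives; the pointer annotation (written here as `T.ptr`) and the exact body of `matmul_test` are assumptions, not a copy of the test:

```python
import tilelang.language as T

def matmul_test(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16"):

    @T.prim_func
    def main(a_ptr: T.ptr, b_ptr: T.ptr, c_ptr: T.ptr):
        # Rebuild typed tensors from the raw pointer arguments inside the kernel.
        A = T.make_tensor(a_ptr, (M, K), dtype)
        B = T.make_tensor(b_ptr, (K, N), dtype)
        C = T.make_tensor(c_ptr, (M, N), dtype)

        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```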

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions.
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions.
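
A short illustration of how the renamed annotations read in kernel helpers, assuming `T.SharedBuffer` and `T.FragmentBuffer` map to the `shared.dyn` and `local.fragment` scopes shown in the `proxy.py` diff below; the macro itself is illustrative:

```python
import tilelang.language as T

block_M, block_N, block_K = 128, 128, 32

@T.macro
def mma_step(A_shared: T.SharedBuffer((block_M, block_K), "float16"),
             B_shared: T.SharedBuffer((block_K, block_N), "float16"),
             C_local: T.FragmentBuffer((block_M, block_N), "float32")):
    # Operands staged in shared memory, accumulator kept in a register fragment.
    T.gemm(A_shared, B_shared, C_local)
```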

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped to `torch.Tensor`.
- This change enhances clarity and consistency in tensor usage across the codebase.
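
A small sketch of the effect, assuming both names resolve to the same proxy object as described above:

```python
import tilelang.language as T

# Because Tensor is an alias of Buffer, both annotations describe the same proxy;
# the JIT maps T.Tensor-annotated parameters to torch.Tensor at the call boundary.
@T.prim_func
def copy_vec(A: T.Tensor((128,), "float32"), B: T.Buffer((128,), "float32")):
    with T.Kernel(1, threads=128) as bx:
        for i in T.Parallel(128):
            B[i] = A[i]
```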

* [Refactor] Revamp cache management and enhance documentation in env.py and proxy.py

- Replaced the global cache functions with a `CacheState` class to improve encapsulation and management of kernel caching (a usage sketch follows this list).
- Updated the `from_ptr` method in BufferProxy and BaseTensorProxy classes to include detailed docstrings for better clarity on parameters and return values.
- Enhanced class docstrings across various proxy classes to provide clearer descriptions of their purpose and functionality, improving overall code documentation.
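
A usage sketch of the refactored cache control, assuming `enable_cache`, `disable_cache`, and `is_cache_enabled` remain importable from the `tilelang` package root (the `CacheState` implementation appears in the diff further down):

```python
import tilelang

# The old entry points keep working; they are now bound to CacheState classmethods.
tilelang.disable_cache()
assert not tilelang.is_cache_enabled()
tilelang.enable_cache()
assert tilelang.is_cache_enabled()
```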

* [Refactor] Update imports in __init__.py for tir compatibility

- Added imports for `prim_func` and `tir.op` to enhance compatibility with the upstream tir script.
- Marked imports with `# noqa: F401` to suppress linting warnings for unused imports, indicating future removal once compatibility is achieved.

* lint fix

* [Refactor] Update imports in tir.ir.py for improved compatibility

- Replaced the import of `PrimExpr` from `tvm.script.ir_builder.tir` with the correct import from `tvm.tir`.
- Added import for `tir.ir` in `__init__.py` to enhance module accessibility and maintain compatibility with upstream changes.

* [Refactor] Update function calls in tir.ir.py to return values

- Modified the `serial`, `parallel`, `vectorized`, `unroll`, `thread_binding`, and `grid` functions to return the results of their respective calls to `_ir` methods, enhancing clarity and ensuring proper value propagation.
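
The returned frames are what drive the scripted `for` sugar, so without the `return` these helpers would hand back `None`. A minimal sketch (shapes illustrative):

```python
import tilelang.language as T

@T.prim_func
def elementwise_add(A: T.Tensor((8, 8), "float32"),
                    B: T.Tensor((8, 8), "float32"),
                    C: T.Tensor((8, 8), "float32")):
    # T.grid (and T.serial, T.parallel, ...) now return the ForFrame produced by the
    # underlying _ir builders, so the loop binds its iteration variables as expected.
    for i, j in T.grid(8, 8):
        C[i, j] = A[i, j] + B[i, j]
```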

* bugfix

* [Enhancement] Add support for uint16 data type in TLCUDASourceWrapper

- Introduced the "uint16" mapping to the type dictionary in the TLCUDASourceWrapper class, expanding the range of supported data types for CUDA operations.

* bugfix

* Uncomment main function call
parent 76435ca8
@@ -59,7 +59,7 @@ SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
 TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
 if TVM_IMPORT_PYTHON_PATH is not None:
-    os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", ""))
+    os.environ["PYTHONPATH"] = TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
     sys.path.insert(0, TVM_IMPORT_PYTHON_PATH)
 else:
     install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
@@ -122,26 +122,32 @@ if os.environ.get("TL_TEMPLATE_PATH", None) is None:
 else:
     logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)

-# Cache control
-_ENABLE_TILELANG_KERNEL_CACHE = True  # Default cache state
-
-
-def enable_cache():
-    """Enable kernel caching globally."""
-    global _ENABLE_TILELANG_KERNEL_CACHE
-    _ENABLE_TILELANG_KERNEL_CACHE = True
-
-
-def disable_cache():
-    """Disable kernel caching globally."""
-    global _ENABLE_TILELANG_KERNEL_CACHE
-    _ENABLE_TILELANG_KERNEL_CACHE = False
-
-
-def is_cache_enabled() -> bool:
-    """Return current cache state."""
-    return _ENABLE_TILELANG_KERNEL_CACHE
+# Cache control
+class CacheState:
+    """Class to manage global kernel caching state."""
+    _enabled = True
+
+    @classmethod
+    def enable(cls):
+        """Enable kernel caching globally."""
+        cls._enabled = True
+
+    @classmethod
+    def disable(cls):
+        """Disable kernel caching globally."""
+        cls._enabled = False
+
+    @classmethod
+    def is_enabled(cls) -> bool:
+        """Return current cache state."""
+        return cls._enabled
+
+
+# Replace the old functions with class methods
+enable_cache = CacheState.enable
+disable_cache = CacheState.disable
+is_cache_enabled = CacheState.is_enabled

 __all__ = [
     "CUTLASS_INCLUDE_DIR",
...
@@ -101,6 +101,7 @@ class TLCUDASourceWrapper(object):
         "int8": "int8_t",
         "uint8": "uint8_t",
         "int16": "int16_t",
+        "uint16": "uint16_t",
         "uchar": "uint8_t",
     }
...
@@ -4,7 +4,13 @@ from typing import Optional
 # from .parser import *
 # now is fully compatible with the upstream
 # tir script
+# TODO(lei): remove this import once the
+# upstream tir script is fully compatible
 from tvm.script.parser.tir import *
+from .tir import (
+    prim_func,  # noqa: F401
+)
+from .tir.ir import *  # noqa: F401
 from tilelang.layout import Layout, Fragment  # noqa: F401
 from .proxy import (
     ptr,  # noqa: F401
...
@@ -49,12 +49,30 @@ class BufferProxy:
             return self(keys)
         return self(*keys)  # type: ignore[attr-defined] # pylint: disable=no-member

-    def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> Buffer:
-        return match_buffer(ptr, shape, dtype=dtype)
+    def from_ptr(self,
+                 pointer_var: Var,
+                 shape: tuple[PrimExpr, ...],
+                 dtype: str = "float32") -> Buffer:
+        """Create a buffer from a pointer, shape, and data type.
+
+        Args:
+            pointer_var: The pointer variable
+            shape: The shape of the buffer
+            dtype: The data type of the buffer (default: float32)
+
+        Returns:
+            A buffer created from the given parameters
+        """
+        return match_buffer(pointer_var, shape, dtype=dtype)


 class BaseTensorProxy:
-    """Base proxy class for tensor types with configurable defaults"""
+    """Base proxy class for tensor types with configurable defaults.
+
+    This class serves as a foundation for different tensor proxy types, providing
+    customizable default values for scope, alignment, and offset factors. It implements
+    the core functionality for creating TIR buffers with specific memory configurations.
+    """
     default_scope = "global"
     default_align = 0
     default_offset_factor = 0
@@ -97,23 +115,55 @@ class BaseTensorProxy:
             return self(keys)
         return self(*keys)

-    def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
-        return match_buffer(ptr, shape, dtype=dtype)
+    def from_ptr(self,
+                 pointer_var: Var,
+                 shape: tuple[PrimExpr, ...],
+                 dtype: str = "float32") -> tir.Buffer:
+        """Create a buffer from a pointer, shape, and data type.
+
+        Args:
+            pointer_var: The pointer variable
+            shape: The shape of the buffer
+            dtype: The data type of the buffer (default: float32)
+
+        Returns:
+            A buffer created from the given parameters
+        """
+        return match_buffer(pointer_var, shape, dtype=dtype)


 class TensorProxy(BaseTensorProxy):
-    """Main tensor proxy with default global scope"""
+    """Main tensor proxy class for global scope buffers.
+
+    This class implements the default tensor proxy with global memory scope,
+    inheriting all functionality from BaseTensorProxy without modifications.
+    """


 class FragmentBufferProxy(BaseTensorProxy):
+    """Proxy class for fragment memory buffers.
+
+    This class represents tensor proxies specifically for local fragment memory,
+    typically used in GPU tensor core operations.
+    """
     default_scope = "local.fragment"


 class SharedBufferProxy(BaseTensorProxy):
+    """Proxy class for shared memory buffers.
+
+    This class represents tensor proxies for dynamic shared memory,
+    commonly used in GPU shared memory operations.
+    """
     default_scope = "shared.dyn"


 class LocalBufferProxy(BaseTensorProxy):
+    """Proxy class for local memory buffers.
+
+    This class represents tensor proxies for local memory scope,
+    typically used for temporary computations in GPU kernels.
+    """
     default_scope = "local"
...
from .entry import prim_func # noqa: F401
from .ir import * # noqa: F401
...
from typing import Callable, Optional, Union
from tvm.tir.function import PrimFunc
import tvm.script.parser.tir.entry as _tir_entry
import inspect
from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None,
private: bool = False,
check_well_formed=True) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
----------
func : Callable
The function to be parsed as prim func.
(Listed as optional to allow the decorator to be used
without arguments, like `@prim_func`,
or with an argument, `@prim_func(private=True)`)
private : bool, optional
Whether the function should be treated as private.
A private function has no global symbol attribute;
if the function is not private, it will have a global symbol
matching the function name.
Returns
-------
res : Union[PrimFunc, Callable]
The parsed tir prim func.
"""
# pylint: disable=unused-argument
# (private will be used in the parser, but not immediately)
# need to capture this var outside the wrapper because the wrapper
# adds to the stack
outer_stack = inspect.stack()
def decorator_wrapper(func):
if not inspect.isfunction(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__) # noqa: B010
return f
if func is not None:
# no optional args given => use wrapper directly
return decorator_wrapper(func)
else:
# if there is an optional arg given, return a new decorator
# that will then be invoked
setattr(decorator_wrapper, "dispatch_token", "tir") # noqa: B010
return decorator_wrapper
setattr(prim_func, "dispatch_token", "tir") # noqa: B010
def macro(*args, hygienic: bool = True) -> Callable:
"""Decorator for macro definitions.
Parameters
----------
hygienic: bool
Specifies whether the macro is hygienic or not.
A macro is hygienic if all symbols used in the macro's body are resolved
to values from the location of the macro definition. A non-hygienic macro
will have its symbols resolved to values at the time of the macro's use.
Example:
```
import tvm
from tvm.script import tir as T
x_value = 128
@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value] ### x_value binds to 128
@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value] ### x_value will bind at the time of use
@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]
@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
"""
def _decorator(func: Callable) -> _tir_entry.TIRMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = _tir_entry.TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj
if len(args) == 0:
return _decorator
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])
raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
setattr(macro, "dispatch_token", "tir") # noqa: B010
...
import tvm.script.ir_builder.tir.ir as _ir
from tvm.script.ir_builder.tir import frame
from tvm.tir import PrimExpr
from typing import Any, Dict
import tilelang.language.tir.op as _tir_op
import functools
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.serial(start=start, stop=stop, annotations=annotations)
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.parallel(start=start, stop=stop, annotations=annotations)
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.vectorized(start=start, stop=stop, annotations=annotations)
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.unroll(start=start, stop=stop, annotations=annotations)
def thread_binding(
start: PrimExpr,
stop: PrimExpr = None,
thread: str = None,
*,
annotations: Dict[str, Any] = None,
) -> frame.ForFrame:
"""The thread-binding For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
thread : str
The thread for loop variable to bind.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.thread_binding(start=start, stop=stop, thread=thread, annotations=annotations)
def grid(*extents: PrimExpr) -> frame.ForFrame:
"""The grid For statement.
Parameters
----------
extents : PrimExpr
The extents of the iteration.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.grid(*extents)
def _dtype_forward(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
args = (kwargs.pop("dtype"),) + args
return func(*args, **kwargs)
return wrapped
def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
kwargs.pop("dtype")
return func(*args, **kwargs)
return wrapped
abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
acos = _op_wrapper(_tir_op.acos)
acosh = _op_wrapper(_tir_op.acosh)
address_of = _op_wrapper(_tir_op.address_of)
asin = _op_wrapper(_tir_op.asin)
asinh = _op_wrapper(_tir_op.asinh)
atan = _op_wrapper(_tir_op.atan)
atan2 = _op_wrapper(_tir_op.atan2)
atanh = _op_wrapper(_tir_op.atanh)
bitwise_and = _op_wrapper(_tir_op.bitwise_and)
bitwise_not = _op_wrapper(_tir_op.bitwise_not)
bitwise_or = _op_wrapper(_tir_op.bitwise_or)
bitwise_xor = _op_wrapper(_tir_op.bitwise_xor)
ceil = _op_wrapper(_tir_op.ceil)
clz = _op_wrapper(_tir_op.clz)
copysign = _op_wrapper(_tir_op.copysign)
cos = _op_wrapper(_tir_op.cos)
cosh = _op_wrapper(_tir_op.cosh)
erf = _op_wrapper(_tir_op.erf)
exp = _op_wrapper(_tir_op.exp)
exp2 = _op_wrapper(_tir_op.exp2)
exp10 = _op_wrapper(_tir_op.exp10)
floor = _op_wrapper(_tir_op.floor)
ceildiv = _op_wrapper(_tir_op.ceildiv)
floordiv = _op_wrapper(_tir_op.floordiv)
floormod = _op_wrapper(_tir_op.floormod)
fmod = _op_wrapper(_tir_op.fmod)
hypot = _op_wrapper(_tir_op.hypot)
if_then_else = _op_wrapper(_tir_op.if_then_else)
infinity = _op_wrapper(_tir_op.infinity)
isfinite = _op_wrapper(_tir_op.isfinite)
isinf = _op_wrapper(_tir_op.isinf)
isnan = _op_wrapper(_tir_op.isnan)
isnullptr = _op_wrapper(_tir_op.isnullptr)
ldexp = _op_wrapper(_tir_op.ldexp)
likely = _op_wrapper(_tir_op.likely)
log = _op_wrapper(_tir_op.log)
log1p = _op_wrapper(_tir_op.log1p)
log2 = _op_wrapper(_tir_op.log2)
log10 = _op_wrapper(_tir_op.log10)
lookup_param = _op_wrapper(_tir_op.lookup_param)
max_value = _op_wrapper(_tir_op.max_value)
min_value = _op_wrapper(_tir_op.min_value)
nearbyint = _op_wrapper(_tir_op.nearbyint)
nextafter = _op_wrapper(_tir_op.nextafter)
popcount = _op_wrapper(_tir_op.popcount)
pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
ret = _op_wrapper(_tir_op.ret)
round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
rsqrt = _op_wrapper(_tir_op.rsqrt)
shift_left = _op_wrapper(_tir_op.shift_left)
shift_right = _op_wrapper(_tir_op.shift_right)
sigmoid = _op_wrapper(_tir_op.sigmoid)
sin = _op_wrapper(_tir_op.sin)
sinh = _op_wrapper(_tir_op.sinh)
sqrt = _op_wrapper(_tir_op.sqrt)
tan = _op_wrapper(_tir_op.tan)
tanh = _op_wrapper(_tir_op.tanh)
trunc = _op_wrapper(_tir_op.trunc)
truncdiv = _op_wrapper(_tir_op.truncdiv)
truncmod = _op_wrapper(_tir_op.truncmod)
tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
tvm_check_return = _op_wrapper(_tir_op.tvm_check_return)
call_packed = _op_wrapper(_tir_op.call_packed)
call_cpacked = _op_wrapper(_tir_op.call_cpacked)
call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
tvm_struct_get = _tir_op.tvm_struct_get
tvm_thread_invariant = _op_wrapper(_tir_op.tvm_thread_invariant)
tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
tvm_storage_sync = _tir_op.tvm_storage_sync
tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
anylist_getitem = _op_wrapper(_tir_op.anylist_getitem)
anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem)
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
vscale = _op_wrapper(_tir_op.vscale)
reinterpret = _dtype_forward(_tir_op.reinterpret)
call_extern = _dtype_forward(_tir_op.call_extern)
call_intrin = _dtype_forward(_tir_op.call_intrin)
call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
mma_store = _dtype_forward(_tir_op.mma_store)
mma_fill = _dtype_forward(_tir_op.mma_fill)
vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
tvm_mfma = _dtype_forward(_tir_op.tvm_mfma)
tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)