Commit be0bf36d authored by Lei Wang, committed by LeiWang1999

[Language] Proxy tvm ir to make linter happy (#287)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate the new memory scope handling (see the sketch below).
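
For illustration, a minimal sketch of what the new scope handling permits, assuming the usual TileLang surface (`T.alloc_shared`, `T.fill`, `T.copy`); names and shapes are illustrative, not this commit's code:

```python
import tilelang.language as T


@T.prim_func
def fill_then_copy(A: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, threads=128) as bx:
        # Buffer in "shared.dyn" scope; Fill now lowers through a parallel,
        # layout-inferred loop for shared scopes instead of rejecting them.
        A_shared = T.alloc_shared((128, 128), "float16")
        T.fill(A_shared, 0)
        T.copy(A_shared, A)
```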

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the deprecated decorator from the main module and added a new implementation in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added a `make_tensor` function in `proxy.py` to streamline tensor creation from pointers (see the sketch below).
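
A minimal sketch of the resulting pattern (hedged: the kernel body and the `T.handle` annotations are illustrative; only `T.make_tensor(ptr, shape, dtype)` comes from this change):

```python
import tilelang.language as T


def matmul_test(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16"):

    @T.prim_func
    def main(a_ptr: T.handle, b_ptr: T.handle, c_ptr: T.handle):
        # Rebuild typed buffer views from the raw pointers at the TIR level.
        A = T.make_tensor(a_ptr, (M, K), dtype)
        B = T.make_tensor(b_ptr, (K, N), dtype)
        C = T.make_tensor(c_ptr, (M, N), dtype)
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```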

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions.
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions (annotation sketch below).
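
A hedged illustration of the renamed annotations (the `@T.macro` helper, shapes, and dtypes are assumptions made for this sketch):

```python
import tilelang.language as T

block_M, block_N, dim = 64, 64, 128


@T.macro
def rescale(
        acc_s: T.FragmentBuffer([block_M, block_N], "float32"),  # "local.fragment" scope
        K_shared: T.SharedBuffer([block_N, dim], "float16"),  # "shared.dyn" scope
):
    for i, j in T.Parallel(block_M, block_N):
        acc_s[i, j] *= 0.5
```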

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped with `torch.Tensor` (usage sketched after this list).
- This change enhances clarity and consistency in tensor usage across the codebase.
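
Sketch of the practical effect (the surrounding factory function is assumed; only the `T.Tensor` alias is from this diff): `T.Tensor` parameters are the ones a JIT-compiled kernel exchanges with `torch.Tensor` objects, while scoped proxies stay kernel-internal.

```python
import tilelang.language as T


def copy_kernel(M=1024, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M,), dtype), B: T.Tensor((M,), dtype)):
        # A and B map to torch.Tensor arguments at the JIT boundary.
        with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:
            for i in T.Parallel(128):
                B[bx * 128 + i] = A[bx * 128 + i]

    return main
```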

* [Refactor] Revamp cache management and enhance documentation in env.py and proxy.py

- Replaced the global cache functions with a `CacheState` class to improve encapsulation and management of kernel caching (usage sketched after this list).
- Updated the `from_ptr` method in BufferProxy and BaseTensorProxy classes to include detailed docstrings for better clarity on parameters and return values.
- Enhanced class docstrings across various proxy classes to provide clearer descriptions of their purpose and functionality, improving overall code documentation.
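
Callers keep the old entry points; assuming `env.py` exposes the aliases shown in the diff below, usage is unchanged:

```python
from tilelang.env import disable_cache, enable_cache, is_cache_enabled

disable_cache()  # CacheState.disable under the hood
assert not is_cache_enabled()
enable_cache()
assert is_cache_enabled()
```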

* [Refactor] Update imports in __init__.py for tir compatibility

- Added imports for `prim_func` and `tir.op` to enhance compatibility with the upstream tir script.
- Marked imports with `# noqa: F401` to suppress linting warnings for unused imports, indicating future removal once compatibility is achieved.

* lint fix

* [Refactor] Update imports in tir.ir.py for improved compatibility

- Removed unused import of `PrimExpr` from `tvm.script.ir_builder.tir` and replaced it with the correct import from `tvm.tir`.
- Added import for `tir.ir` in `__init__.py` to enhance module accessibility and maintain compatibility with upstream changes.

* [Refactor] Update function calls in tir.ir.py to return values

- Modified the `serial`, `parallel`, `vectorized`, `unroll`, `thread_binding`, and `grid` functions to return the results of their respective calls to `_ir` methods, ensuring the frame values propagate to callers (see the sketch below).
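
Why the return matters: these proxies must hand back the `ForFrame` created by `tvm.script.ir_builder`, otherwise `for ... in T.serial(...)` would iterate over `None`. A small sketch, assuming the standard TVMScript spelling:

```python
import tilelang.language as T


@T.prim_func
def zero_fill(A: T.Buffer((128,), "float32")):
    # T.serial must return the underlying _ir.serial ForFrame for the
    # parser to iterate; T.grid likewise propagates its frame.
    for i in T.serial(128):
        A[i] = 0.0
```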

* bugfix

* [Enhancement] Add support for uint16 data type in TLCUDASourceWrapper

- Introduced the "uint16" mapping to the type dictionary in the TLCUDASourceWrapper class, expanding the range of supported data types for CUDA operations.

* bugfix

* Uncomment main function call
parent 76435ca8
@@ -59,7 +59,7 @@ SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
 TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
 if TVM_IMPORT_PYTHON_PATH is not None:
-    os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", ""))
+    os.environ["PYTHONPATH"] = TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
     sys.path.insert(0, TVM_IMPORT_PYTHON_PATH)
 else:
     install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
@@ -122,26 +122,32 @@ if os.environ.get("TL_TEMPLATE_PATH", None) is None:
 else:
     logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
-# Cache control
-_ENABLE_TILELANG_KERNEL_CACHE = True  # Default cache state
-
-
-def enable_cache():
-    """Enable kernel caching globally."""
-    global _ENABLE_TILELANG_KERNEL_CACHE
-    _ENABLE_TILELANG_KERNEL_CACHE = True
-
-
-def disable_cache():
-    """Disable kernel caching globally."""
-    global _ENABLE_TILELANG_KERNEL_CACHE
-    _ENABLE_TILELANG_KERNEL_CACHE = False
-
-
-def is_cache_enabled() -> bool:
-    """Return current cache state."""
-    return _ENABLE_TILELANG_KERNEL_CACHE
+# Cache control
+class CacheState:
+    """Class to manage global kernel caching state."""
+    _enabled = True
+
+    @classmethod
+    def enable(cls):
+        """Enable kernel caching globally."""
+        cls._enabled = True
+
+    @classmethod
+    def disable(cls):
+        """Disable kernel caching globally."""
+        cls._enabled = False
+
+    @classmethod
+    def is_enabled(cls) -> bool:
+        """Return current cache state."""
+        return cls._enabled
+
+
+# Replace the old functions with class methods
+enable_cache = CacheState.enable
+disable_cache = CacheState.disable
+is_cache_enabled = CacheState.is_enabled
 __all__ = [
     "CUTLASS_INCLUDE_DIR",
...
@@ -101,6 +101,7 @@ class TLCUDASourceWrapper(object):
         "int8": "int8_t",
         "uint8": "uint8_t",
         "int16": "int16_t",
+        "uint16": "uint16_t",
         "uchar": "uint8_t",
     }
...
@@ -4,7 +4,13 @@ from typing import Optional
 # from .parser import *
 # now is fully compatible with the upstream
 # tir script
+# TODO(lei): remove this import once the
+# upstream tir script is fully compatible
 from tvm.script.parser.tir import *
+from .tir import (
+    prim_func,  # noqa: F401
+)
+from .tir.ir import *  # noqa: F401
 from tilelang.layout import Layout, Fragment  # noqa: F401
 from .proxy import (
     ptr,  # noqa: F401
...
@@ -49,12 +49,30 @@ class BufferProxy:
             return self(keys)
         return self(*keys)  # type: ignore[attr-defined] # pylint: disable=no-member

-    def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> Buffer:
-        return match_buffer(ptr, shape, dtype=dtype)
+    def from_ptr(self,
+                 pointer_var: Var,
+                 shape: tuple[PrimExpr, ...],
+                 dtype: str = "float32") -> Buffer:
+        """Create a buffer from a pointer, shape, and data type.
+
+        Args:
+            pointer_var: The pointer variable
+            shape: The shape of the buffer
+            dtype: The data type of the buffer (default: float32)
+
+        Returns:
+            A buffer created from the given parameters
+        """
+        return match_buffer(pointer_var, shape, dtype=dtype)


 class BaseTensorProxy:
-    """Base proxy class for tensor types with configurable defaults"""
+    """Base proxy class for tensor types with configurable defaults.
+
+    This class serves as a foundation for different tensor proxy types, providing
+    customizable default values for scope, alignment, and offset factors. It implements
+    the core functionality for creating TIR buffers with specific memory configurations.
+    """
     default_scope = "global"
     default_align = 0
     default_offset_factor = 0
@@ -97,23 +115,55 @@ class BaseTensorProxy:
             return self(keys)
         return self(*keys)

-    def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
-        return match_buffer(ptr, shape, dtype=dtype)
+    def from_ptr(self,
+                 pointer_var: Var,
+                 shape: tuple[PrimExpr, ...],
+                 dtype: str = "float32") -> tir.Buffer:
+        """Create a buffer from a pointer, shape, and data type.
+
+        Args:
+            pointer_var: The pointer variable
+            shape: The shape of the buffer
+            dtype: The data type of the buffer (default: float32)
+
+        Returns:
+            A buffer created from the given parameters
+        """
+        return match_buffer(pointer_var, shape, dtype=dtype)


 class TensorProxy(BaseTensorProxy):
-    """Main tensor proxy with default global scope"""
+    """Main tensor proxy class for global scope buffers.
+
+    This class implements the default tensor proxy with global memory scope,
+    inheriting all functionality from BaseTensorProxy without modifications.
+    """


 class FragmentBufferProxy(BaseTensorProxy):
+    """Proxy class for fragment memory buffers.
+
+    This class represents tensor proxies specifically for local fragment memory,
+    typically used in GPU tensor core operations.
+    """
     default_scope = "local.fragment"


 class SharedBufferProxy(BaseTensorProxy):
+    """Proxy class for shared memory buffers.
+
+    This class represents tensor proxies for dynamic shared memory,
+    commonly used in GPU shared memory operations.
+    """
     default_scope = "shared.dyn"


 class LocalBufferProxy(BaseTensorProxy):
+    """Proxy class for local memory buffers.
+
+    This class represents tensor proxies for local memory scope,
+    typically used for temporary computations in GPU kernels.
+    """
     default_scope = "local"
...
from .entry import prim_func # noqa: F401
from .ir import * # noqa: F401
from typing import Callable, Optional, Union
from tvm.tir.function import PrimFunc
import tvm.script.parser.tir.entry as _tir_entry
import inspect
from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None,
private: bool = False,
check_well_formed=True) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
----------
func : Callable
The function to be parsed as prim func.
(Listed as optional to allow the decorator to be used
without arguments, like `@prim_func`,
or with an argument, `@prim_func(private=True)`)
private : bool, optional
Whether the function should be treated as private.
A private function has no global symbol attribute;
if the function is not private, it will have a global symbol
matching the function name.
Returns
-------
res : Union[PrimFunc, Callable]
The parsed tir prim func.
"""
# pylint: disable=unused-argument
# (private will be used in the parser, but not immediately)
# need to capture this var outside the wrapper because the wrapper
# adds to the stack
outer_stack = inspect.stack()
def decorator_wrapper(func):
if not inspect.isfunction(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__) # noqa: B010
return f
if func is not None:
# no optional args given => use wrapper directly
return decorator_wrapper(func)
else:
# if there is an optional arg given, return a new decorator
# that will then be invoked
setattr(decorator_wrapper, "dispatch_token", "tir") # noqa: B010
return decorator_wrapper
setattr(prim_func, "dispatch_token", "tir") # noqa: B010
def macro(*args, hygienic: bool = True) -> Callable:
"""Decorator for macro definitions.
Parameters
----------
hygienic: bool
Specifies whether the macro is hygienic or not.
A macro is hygienic if all symbols used in the macro's body are resolved
to values from the location of the macro definition. A non-hygienic macro
will have its symbols resolved to values at the time of the macro's use.
Example:
```
import tvm
from tvm.script import tir as T
x_value = 128
@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value] ### x_value binds to 128
@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value] ### x_value will bind at the time of use
@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]
@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
"""
def _decorator(func: Callable) -> _tir_entry.TIRMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = _tir_entry.TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj
if len(args) == 0:
return _decorator
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])
raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
setattr(macro, "dispatch_token", "tir") # noqa: B010
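# Usage sketch (an assumption for illustration, not part of the module): the
# proxied prim_func mirrors tvm.script's decorator, so both forms work:
#
#     @T.prim_func
#     def add_one(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
#         for i in T.serial(8):
#             B[i] = A[i] + 1.0
#
#     @T.prim_func(private=True)  # parsed without a global symbol attribute
#     def zero(A: T.Buffer((8,), "float32")):
#         for i in T.serial(8):
#             A[i] = 0.0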
import tvm.script.ir_builder.tir.ir as _ir
from tvm.script.ir_builder.tir import frame
from tvm.tir import PrimExpr
from typing import Any, Dict
import tilelang.language.tir.op as _tir_op
import functools
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.serial(start=start, stop=stop, annotations=annotations)
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.parallel(start=start, stop=stop, annotations=annotations)
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.vectorized(start=start, stop=stop, annotations=annotations)
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.unroll(start=start, stop=stop, annotations=annotations)
def thread_binding(
start: PrimExpr,
stop: PrimExpr = None,
thread: str = None,
*,
annotations: Dict[str, Any] = None,
) -> frame.ForFrame:
"""The thread-binding For statement.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
thread : str
The thread for loop variable to bind.
annotations : Dict[str, Any]
The optional annotations of the For statement.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.thread_binding(start=start, stop=stop, thread=thread, annotations=annotations)
def grid(*extents: PrimExpr) -> frame.ForFrame:
"""The grid For statement.
Parameters
----------
extents : PrimExpr
The extents of the iteration.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ir.grid(*extents)
def _dtype_forward(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
args = (kwargs.pop("dtype"),) + args
return func(*args, **kwargs)
return wrapped
def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
kwargs.pop("dtype")
return func(*args, **kwargs)
return wrapped
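# How the adapters behave (from the definitions above):
#   _op_wrapper drops a stray `dtype` kwarg:
#       exp(x, dtype="float32")          -> _tir_op.exp(x)
#   _dtype_forward moves a `dtype` kwarg to the front positional slot:
#       reinterpret(x, dtype="float16")  -> _tir_op.reinterpret("float16", x)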
abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
acos = _op_wrapper(_tir_op.acos)
acosh = _op_wrapper(_tir_op.acosh)
address_of = _op_wrapper(_tir_op.address_of)
asin = _op_wrapper(_tir_op.asin)
asinh = _op_wrapper(_tir_op.asinh)
atan = _op_wrapper(_tir_op.atan)
atan2 = _op_wrapper(_tir_op.atan2)
atanh = _op_wrapper(_tir_op.atanh)
bitwise_and = _op_wrapper(_tir_op.bitwise_and)
bitwise_not = _op_wrapper(_tir_op.bitwise_not)
bitwise_or = _op_wrapper(_tir_op.bitwise_or)
bitwise_xor = _op_wrapper(_tir_op.bitwise_xor)
ceil = _op_wrapper(_tir_op.ceil)
clz = _op_wrapper(_tir_op.clz)
copysign = _op_wrapper(_tir_op.copysign)
cos = _op_wrapper(_tir_op.cos)
cosh = _op_wrapper(_tir_op.cosh)
erf = _op_wrapper(_tir_op.erf)
exp = _op_wrapper(_tir_op.exp)
exp2 = _op_wrapper(_tir_op.exp2)
exp10 = _op_wrapper(_tir_op.exp10)
floor = _op_wrapper(_tir_op.floor)
ceildiv = _op_wrapper(_tir_op.ceildiv)
floordiv = _op_wrapper(_tir_op.floordiv)
floormod = _op_wrapper(_tir_op.floormod)
fmod = _op_wrapper(_tir_op.fmod)
hypot = _op_wrapper(_tir_op.hypot)
if_then_else = _op_wrapper(_tir_op.if_then_else)
infinity = _op_wrapper(_tir_op.infinity)
isfinite = _op_wrapper(_tir_op.isfinite)
isinf = _op_wrapper(_tir_op.isinf)
isnan = _op_wrapper(_tir_op.isnan)
isnullptr = _op_wrapper(_tir_op.isnullptr)
ldexp = _op_wrapper(_tir_op.ldexp)
likely = _op_wrapper(_tir_op.likely)
log = _op_wrapper(_tir_op.log)
log1p = _op_wrapper(_tir_op.log1p)
log2 = _op_wrapper(_tir_op.log2)
log10 = _op_wrapper(_tir_op.log10)
lookup_param = _op_wrapper(_tir_op.lookup_param)
max_value = _op_wrapper(_tir_op.max_value)
min_value = _op_wrapper(_tir_op.min_value)
nearbyint = _op_wrapper(_tir_op.nearbyint)
nextafter = _op_wrapper(_tir_op.nextafter)
popcount = _op_wrapper(_tir_op.popcount)
pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
ret = _op_wrapper(_tir_op.ret)
round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
rsqrt = _op_wrapper(_tir_op.rsqrt)
shift_left = _op_wrapper(_tir_op.shift_left)
shift_right = _op_wrapper(_tir_op.shift_right)
sigmoid = _op_wrapper(_tir_op.sigmoid)
sin = _op_wrapper(_tir_op.sin)
sinh = _op_wrapper(_tir_op.sinh)
sqrt = _op_wrapper(_tir_op.sqrt)
tan = _op_wrapper(_tir_op.tan)
tanh = _op_wrapper(_tir_op.tanh)
trunc = _op_wrapper(_tir_op.trunc)
truncdiv = _op_wrapper(_tir_op.truncdiv)
truncmod = _op_wrapper(_tir_op.truncmod)
tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
tvm_check_return = _op_wrapper(_tir_op.tvm_check_return)
call_packed = _op_wrapper(_tir_op.call_packed)
call_cpacked = _op_wrapper(_tir_op.call_cpacked)
call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
tvm_struct_get = _tir_op.tvm_struct_get
tvm_thread_invariant = _op_wrapper(_tir_op.tvm_thread_invariant)
tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
tvm_storage_sync = _tir_op.tvm_storage_sync
tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
anylist_getitem = _op_wrapper(_tir_op.anylist_getitem)
anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem)
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
vscale = _op_wrapper(_tir_op.vscale)
reinterpret = _dtype_forward(_tir_op.reinterpret)
call_extern = _dtype_forward(_tir_op.call_extern)
call_intrin = _dtype_forward(_tir_op.call_intrin)
call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
mma_store = _dtype_forward(_tir_op.mma_store)
mma_fill = _dtype_forward(_tir_op.mma_fill)
vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
tvm_mfma = _dtype_forward(_tir_op.tvm_mfma)
tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)
from typing import Any, Optional
import tvm
from tvm.ir import PrimExpr
from tvm.ir.base import Span
from tvm.runtime import const
from tvm.tir.expr import IntImm, PrimExprWithOp
import tvm.tir.op as _tvm_op
def call_packed(*args, span=None):
"""Build expression by call an external packed function.
The argument to packed function can be Expr or Buffer.
The argument is the corresponding POD type when Expr is presented.
When the argument is Buffer, the corresponding PackedFunc
will receive an TVMArrayHandle whose content is valid during the callback period.
If the PackedFunc is a python callback, then the corresponding argument is NDArray.
Parameters
----------
args : list of Expr or Buffer.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
See Also
--------
te.extern : Create tensor with extern function call.
"""
return _tvm_op.call_packed(*args, span=span)
def call_cpacked(*args, span=None):
"""Build expression by call an external packed function.
Same as call_packed, except that the first argument is the function name
(as in call_extern), and the last argument is the resource handle.
Parameters
----------
args : list of Expr or Buffer.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
See Also
--------
te.extern : Create tensor with extern function call.
"""
return _tvm_op.call_cpacked(*args, span=span)
def call_packed_lowered(*args, span=None):
"""Lowered version of call packed.
The arguments to the packed function can be Expr or Buffer.
An Expr argument is passed as the corresponding POD type.
When the argument is a Buffer, the corresponding PackedFunc
will receive a TVMArrayHandle whose content is valid during the callback period.
If the PackedFunc is a Python callback, the corresponding argument is an NDArray.
Parameters
----------
args : list of Expr or Buffer.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
See Also
--------
te.extern : Create tensor with extern function call.
"""
return _tvm_op.call_packed_lowered(*args, span=span)
def call_cpacked_lowered(*args, span=None):
"""Lowered version of call c-packed.
Same as call_packed, except that the first argument is the function name
(as in call_extern), and the last argument is the resource handle.
Parameters
----------
args : list of Expr or Buffer.
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
See Also
--------
te.extern : Create tensor with extern function call.
"""
return _tvm_op.call_cpacked_lowered(*args, span=span)
def call_intrin(dtype, func_name, *args, span=None):
"""Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via
the intrinsic translation rule.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The intrinsic function name.
args : list
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.call_intrin(dtype, func_name, *args, span=span)
def call_pure_extern(dtype, func_name, *args, span=None):
"""Build expression by calling a pure extern function.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The extern function name.
args : list
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.call_pure_extern(dtype, func_name, *args, span=span)
def call_extern(dtype, func_name, *args, span=None):
"""Build expression by calling a extern function.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The extern function name.
args : list
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.call_extern(dtype, func_name, *args, span=span)
def call_llvm_intrin(dtype, name, *args, span=None):
"""Build expression by calling a llvm intrinsic function
Parameters
----------
dtype : str
The data type of the result.
name : str
The name of the llvm intrinsic function.
args : list
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.call_llvm_intrin(dtype, name, *args, span=span)
def call_llvm_pure_intrin(dtype, name, *args, span=None):
"""Build expression by calling a pure llvm intrinsic function
Parameters
----------
dtype : str
The data type of the result.
name : str
The name of the llvm intrinsic function.
args : list
Positional arguments.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.call_llvm_pure_intrin(dtype, name, *args, span=span)
def tvm_check_return(expected, return_unexpected, nested_call):
"""Return new on stack dtype[num]
Parameters
----------
expected : int
The expected return code.
return_unexpected : int
The unexpected return code.
nested_call : PrimExpr
The call expression to check return.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_check_return(expected, return_unexpected, nested_call)
def tvm_stack_alloca(dtype_str, num):
"""Return new on stack dtype[num]
Parameters
----------
dtype_str : str
The data type of array.
num : int
The size of array.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_stack_alloca(dtype_str, num)
def tvm_stack_make_shape(*args):
"""Allocate a shape tuple on stack, return the handle
Parameters
----------
args : int
The tuple shape.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_stack_make_shape(*args)
def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset):
"""Allocate a NDArray(DLTensor) on stack, return the handle
Parameters
----------
data : Expr
The data of array.
shape : Expr
The shape of array.
strides : Expr
The strides of array.
ndim : Expr
The dimensions of array.
arr_dtype : Expr
The data type of array.
elem_offset : Expr
The element offset of array.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)
def assume(cond=None):
"""Provide a true statement that can be used for simplifications
Parameters
----------
cond : Expr
The constraint condition.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.assume(cond)
def undef():
"""Returns an initialized but arbitrary value
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.undef()
def call_tir(global_var: tvm.ir.GlobalVar, *args):
"""Performs a call into another PrimFunc in the same IRModule
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.call_tir(global_var, *args)
def start_profile_intrinsic(id):
"""Start profile intrinsic.
Parameters
----------
id : int
The intrinsic id.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.start_profile_intrinsic(id)
def end_profile_intrinsic(id):
"""End profile intrinsic.
Parameters
----------
id : int
The intrinsic id.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.end_profile_intrinsic(id)
def tvm_tuple(*value):
"""Create a tuple structure in value field of AttrStmt
Parameters
----------
value : Expr
The value in tuple.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_tuple(*value)
def tvm_struct_get(arr, index, field, dtype):
"""Get struct field value in array
Parameters
----------
dtype : str
The data type of the result.
arr : StructType*
The array of struct.
index : int
The index of struct.
field : int
The field of struct.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_struct_get(arr, index, field, dtype)
def tvm_struct_set(arr, index, field, value):
"""Set value in struct field in array
Parameters
----------
arr : StructType*
The array of struct.
index : int
The index of struct.
field : int
The field of struct.
value : Expr
The value to be set in field.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_struct_set(arr, index, field, value)
def address_of(buffer_load, span=None):
"""Returns the address of an element in the buffer
Parameters
----------
buffer_load: BufferLoad
The buffer load.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.address_of(buffer_load, span=span)
def lookup_param(param_name, span=None):
"""Returns the param by name
Parameters
----------
param_name : str
The name of param.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.lookup_param(param_name, span=span)
def tvm_thread_allreduce(*freduce_args):
"""Perform allreduce inside threadblock.
Parameters
----------
freduce_args : Expr
The args.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_thread_allreduce(*freduce_args)
def tvm_thread_invariant(cond):
"""Mark condition as thread invariant.
Parameters
----------
cond : Expr
The condition.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_thread_invariant(cond)
def tvm_storage_sync(storage_scope):
"""Perform synchronization in specified scope.
Parameters
----------
storage_scope : str
The storage scope to perform synchronization.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_storage_sync(storage_scope)
def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
"""Exchange value between threads inside a warp.
Parameters
----------
mask : PrimExpr
The warp mask indicates active threads inside warp.
value : PrimExpr
The value to exchange.
warp_id : PrimExpr
The source lane index to fetch value.
width : PrimExpr
The width of sub-sections to perform warp shuffle.
warp_size : PrimExpr
The warp size.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_warp_shuffle(mask, value, warp_id, width, warp_size)
def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
"""Copy value from a lane with lower (by offset) index relative to caller.
Parameters
----------
mask : PrimExpr
The warp mask indicates active threads inside warp.
value : PrimExpr
The value to exchange.
offset : PrimExpr
The difference between source lane index and destination lane index:
`offset = dst_lane_idx - src_lane_idx`
width : PrimExpr
The width of sub-sections to perform warp shuffle.
warp_size : PrimExpr
The warp size.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_warp_shuffle_up(mask, value, offset, width, warp_size)
def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
"""Copy value from a lane with higher (by offset) index relative to caller.
Parameters
----------
mask : PrimExpr
The warp mask indicates active threads inside warp.
value : PrimExpr
The value to exchange.
offset : PrimExpr
The difference between source lane index and destination lane index:
`offset = src_lane_idx - dst_lane_idx`
width : PrimExpr
The width of sub-sections to perform warp shuffle.
warp_size : PrimExpr
The warp size.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_warp_shuffle_down(mask, value, offset, width, warp_size)
def tvm_warp_activemask():
"""Return a 32-bit mask indicates currently active threads in a calling warp.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_warp_activemask()
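# Usage sketch (illustrative, not part of the original file): these warp
# intrinsics compose into shuffle-based reductions, e.g. a warp-wide sum:
#     mask = tvm_warp_activemask()
#     for delta in (16, 8, 4, 2, 1):
#         val = val + tvm_warp_shuffle_down(mask, val, delta, 32, 32)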
def type_annotation(dtype):
"""Create a type annotation expression
Parameters
----------
dtype : Expr
The data type.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.type_annotation(dtype)
def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
"""Get head access address with memory access pattern info
Parameters
----------
ptype : Expr
The data type of pointer.
data : DType*
The data of pointer.
offset : int
The offset of pointer.
extent : int
The extent of pointer.
rw_mask : int
The read write mask.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
def tvm_throw_last_error():
"""Throw TVMGetLastError()
Returns
-------
ret : PrimExpr
The return expression
"""
return _tvm_op.tvm_throw_last_error()
def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
"""TVM intrinsic for tensor core load operators
Parameters
----------
fragment : Var
The wmma fragment.
m : UIntImm
The shape of wmma fragment.
n : UIntImm
The shape of wmma fragment.
k : UIntImm
The shape of wmma fragment.
index : Expr
The fragment index.
buffer_ptr : Expr
The fragment buffer pointer.
stride : Expr
The fragment stride.
layout : Literal["row_major", "column_major"]
The fragment layout.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c,
index_c):
"""TVM intrinsic for tensor core mma_sync operators
Parameters
----------
fragment_d : Var
The wmma fragment_d.
index_d : Expr
The fragment_d index.
fragment_a : Var
The wmma fragment_a.
index_a : Expr
The fragment_a index.
fragment_b : Var
The wmma fragment_b.
index_b : Expr
The fragment_b index.
fragment_c : Var
The wmma fragment_c.
index_c : Expr
The fragment_c index.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
fragment_c, index_c)
def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c,
index_c):
"""TVM intrinsic for tensor core bmma_sync operators
Parameters
----------
fragment_d : Var
The bwmma fragment_d.
index_d : Expr
The fragment_d index.
fragment_a : Var
The bwmma fragment_a.
index_a : Expr
The fragment_a index.
fragment_b : Var
The bwmma fragment_b.
index_b : Expr
The fragment_b index.
fragment_c : Var
The bwmma fragment_c.
index_c : Expr
The fragment_c index.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
fragment_c, index_c)
def tvm_fill_fragment(fragment, m, n, k, index, value):
"""TVM intrinsic for tensor core fill_fragment operators
Parameters
----------
fragment : Var
The wmma fragment
m : UIntImm
The shape of wmma fragment.
n : UIntImm
The shape of wmma fragment.
k : UIntImm
The shape of wmma fragment.
index : Expr
The fragment index.
value : Expr
The value to be filled in fragment.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_fill_fragment(fragment, m, n, k, index, value)
def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
"""TVM intrinsic for tensor core store operators
Parameters
----------
fragment : Var
The wmma fragment.
m : UIntImm
The shape of wmma fragment.
n : UIntImm
The shape of wmma fragment.
k : UIntImm
The shape of wmma fragment.
index : Expr
The fragment index.
buffer_ptr : Expr
The fragment buffer pointer.
stride : Expr
The fragment stride.
layout : Literal["row_major", "column_major"]
The fragment layout.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
def ptx_mma(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
saturate,
operator=None,
):
"""TVM intrinsic for ptx tensor core mma instructions
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
Parameters
----------
dtype : str
The data type of the result.
shape : str
The shape of mma fragment.
A_layout : Literal["row", "col"]
The layout of multiplicand fragment A.
B_layout : Literal["row", "col"]
The layout of multiplicand fragment B.
A_dtype : str
The data type of multiplicand fragment A.
B_dtype : str
The data type of multiplicand fragment B.
C_dtype : str
The data type of accumulator fragment C.
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
saturate : bool
The optional saturation at the output.
operator : Optional[Literal["xor", "and"]]
The 1-bit operator.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_mma(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
saturate,
operator,
)
def ptx_mma_sp(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
metadata,
meta_index,
sparse_selector,
saturate,
):
"""TVM intrinsic for sparse tensor core ptx instructions
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
Parameters
----------
dtype : str
The data type of the result.
shape : str
The shape of mma fragment.
A_layout : Literal["row", "col"]
The layout of multiplicand fragment A.
B_layout : Literal["row", "col"]
The layout of multiplicand fragment B.
A_dtype : str
The data type of multiplicand fragment A.
B_dtype : str
The data type of multiplicand fragment B.
C_dtype : str
The data type of accumulator fragment C.
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
metadata : Expr
The metadata of operand.
meta_index : Expr
The metadata index of operand.
sparse_selector : Expr
The sparse selector indicating the thread that stores the metadata.
saturate : bool
The optional saturation at the output.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_mma_sp(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
metadata,
meta_index,
sparse_selector,
saturate,
)
def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
Parameters
----------
dtype : str
The data type of the result.
m : IntImm
The shape of mma fragment.
n : IntImm
The shape of mma fragment.
dst_ptr : Var
The destination pointer variable.
src_ptr : Var
The source pointer variable.
src_offset : Expr
The source offset.
dst_stride : Var
The destination stride.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
def mma_fill(dtype, local_size, local_ptr, offset):
"""TVM intrinsic for zero-initalizing an MMA accumulation register
Parameters
----------
dtype : str
The data type of the result.
local_size : IntImm
The number of elements.
local_ptr : Var
The destination pointer variable.
offset : Expr
The destination offset.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.mma_fill(dtype, local_size, local_ptr, offset)
def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
"""TVM intrinsic for ptx load matrix from shared memory
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
Parameters
----------
dtype : str
The data type of the result.
trans : bool
The matrix is loaded in column-major format.
num : IntImm
The number of matrices.
type : Literal[".b16"]
The data type of the matrices.
local_ptr : Var
The local pointer variable.
local_offset : Expr
The offset of local pointer.
smem_ptr : Var
The shared memory pointer variable.
smem_offset : Expr
The offset of the shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr,
smem_offset)
def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
Parameters
----------
dtype : str
The data type of the result.
shared_ptr : Var
The shared memory pointer variable.
shared_offset : Expr
The offset of shared memory pointer.
global_ptr : Var
The global memory pointer variable.
global_offset : Expr
The offset of global memory pointer.
bytes : int
The data size to copy.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)
def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes,
barrier_id):
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
Parameters
----------
dtype : str
The data type of the result.
shared_ptr : Var
The shared memory pointer variable.
shared_offset : Expr
The offset of shared memory pointer.
global_ptr : Var
The global memory pointer variable.
global_offset : Expr
The offset of global memory pointer.
bytes : int
The data size to copy.
barrier_id : int
The ID of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset,
bytes, barrier_id)
def ptx_commit_group():
"""TVM intrinsic for ptx async copy commit
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_commit_group()
def ptx_wait_group(num):
"""TVM intrinsic for ptx async copy wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group
Parameters
----------
num : int
The number of the most recent uncommitted pending cp.async groups to wait.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_wait_group(num)
def tvm_mfma(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
):
"""TVM intrinsic for amd matrix core mfma instructions
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
Parameters
----------
dtype : str
The data type of the result.
shape : str
The shape of mma fragment.
A_layout : Literal["row", "col"]
The layout of multiplicand fragment A.
B_layout : Literal["row", "col"]
The layout of multiplicand fragment B.
A_dtype : str
The data type of multiplicand fragment A.
B_dtype : str
The data type of multiplicand fragment B.
C_dtype : str
The data type of accumulator fragment C.
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mfma(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
)
def tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
Parameters
----------
dtype : str
The data type of the result.
m : IntImm
The shape of mma fragment.
n : IntImm
The shape of mma fragment.
dst_ptr : Var
The destination pointer variable.
src_ptr : Var
The source pointer variable.
src_offset : Expr
The source offset.
dst_stride : Var
The destination stride.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
def tvm_rdna_wmma(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
):
"""TVM intrinsic for amd matrix core mfma instructions
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
Parameters
----------
dtype : str
The data type of the result.
shape : str
The shape of mma fragment.
A_layout : Literal["row", "col"]
The layout of multiplicand fragment A.
B_layout : Literal["row", "col"]
The layout of multiplicand fragment B.
A_dtype : str
The data type of multiplicand fragment A.
B_dtype : str
The data type of multiplicand fragment B.
C_dtype : str
The data type of accumulator fragment C.
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_rdna_wmma(
dtype,
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
)
def tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
Parameters
----------
dtype : str
The data type of the result.
m : IntImm
The shape of mma fragment.
n : IntImm
The shape of mma fragment.
dst_ptr : Var
The destination pointer variable.
src_ptr : Var
The source pointer variable.
src_offset : Expr
The source offset.
dst_stride : Var
The destination stride.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
def ptx_cp_async_barrier(barrier_id):
"""TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
Parameters
----------
barrier_id : int
The ID of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_cp_async_barrier(barrier_id)
def ptx_init_barrier_thread_count(barrier_id, thread_count):
"""TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
Parameters
----------
barrier_id : int
The ID of the barrier shared memory pointer.
thread_count : int
Number of threads expected to arrive at the barrier.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_init_barrier_thread_count(barrier_id, thread_count)
def ptx_arrive_barrier(barrier_id):
"""TVM intrinsic for ptx barrier arrival using mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
Parameters
----------
barrier_id : int
The ID of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_arrive_barrier(barrier_id)
def ptx_arrive_barrier_expect_tx(barrier_id, byte_count):
"""TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
Parameters
----------
barrier_id : int
The ID of the barrier shared memory pointer.
byte_count : int
Increases the tx count of the mbarrier object to track completion of
additional async transactions.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)
def ptx_wait_barrier(barrier_id):
"""TVM intrinsic for ptx barrier wait using mbarrier.try_wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
Parameters
----------
barrier_id : int
The ID of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_wait_barrier(barrier_id)
def create_barriers(barrier_count):
"""TVM intrinsic to create N barriers
Parameters
----------
barrier_count : int
The number of barriers to create.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.create_barriers(barrier_count)
def vectorlow(dtype, vec):
"""Get the low level half of the vector
Parameters
----------
dtype : str
The data type of the result.
vec : list
The input vector.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.vectorlow(dtype, vec)
def vectorhigh(dtype, vec):
"""Get the high level half of the vector
Parameters
----------
dtype : str
The data type of the result.
vec : list
The input vector.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.vectorhigh(dtype, vec)
def vectorcombine(dtype, vec1, vec2):
"""Concat two vectors
Parameters
----------
vec1 : list
The input vector.
vec2 : list
The input vector.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.vectorcombine(dtype, vec1, vec2)
def ret(val):
"""Create a tir return expression
Parameters
----------
val : Expr
The returned tir expression, whose data type is int, float or void pointer.
Returns
-------
ret : PrimExpr
The return expression
"""
return _tvm_op.ret(val)
def any(*args, span=None):
"""Create a new expression of the union of all conditions in the arguments
Parameters
----------
args : list
List of symbolic boolean expressions
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
expr: Expr
Expression
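Example
-------
A small sketch building a boundary-check predicate:
.. code-block:: python

    from tvm import te, tir

    i = te.var("i")
    n = te.var("n")
    # true if either condition holds
    out_of_bounds = tir.any(i < 0, i >= n)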
"""
return _tvm_op.any(*args, span=span)
def all(*args, span=None):
"""Create a new expression of the intersection of all conditions in the
arguments
Parameters
----------
args : list
List of symbolic boolean expressions
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
expr: Expr
Expression
"""
return _tvm_op.all(*args, span=span)
def trace(args, trace_action="tvm.default_trace_action"):
"""Trace tensor data at the runtime.
The trace function allows tracing a specific tensor at runtime.
The value being traced should come as the last argument.
A trace action should be specified; by default
tvm.default_trace_action is used.
Parameters
----------
args : list of Expr or Buffers.
Positional arguments.
trace_action : str.
The name of the trace action.
Returns
-------
call : PrimExpr
The call expression.
See Also
--------
tvm.tir.call_packed : Creates packed function.
"""
return _tvm_op.trace(args, trace_action)
def min_value(dtype, span=None):
"""minimum value of dtype
Parameters
----------
dtype : str
The data type.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
value : tvm.Expr
The minimum value of dtype.
"""
return _tvm_op.min_value(dtype, span)
def max_value(dtype: str, span: Optional[Span] = None) -> Any:
"""maximum value of dtype
Parameters
----------
dtype : str
The data type.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
value : tvm.Expr
The maximum value of dtype.
"""
return _tvm_op.max_value(dtype, span)
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
"""infinity value of dtype
Parameters
----------
dtype : str
The data type.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
value : tvm.Expr
The infinity value of dtype.
"""
return _tvm_op.infinity(dtype, span)
def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
"""infinity value of dtype
Parameters
----------
dtype : str
The data type.
value : PrimExpr
The input value.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
value : tvm.Expr
The reinterpret cast value of dtype.
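Example
-------
A sketch of a bit-level cast (the bits are preserved, not the value):
.. code-block:: python

    from tvm import tir

    x = tir.const(1.0, "float32")
    # yields the IEEE-754 bit pattern of 1.0, i.e. 0x3F800000
    bits = tir.reinterpret("int32", x)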
"""
return _tvm_op.reinterpret(dtype, value, span)
def exp(x):
"""Take exponential of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.exp(x)
def exp2(x):
"""Calculate 2**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.exp2(x)
def exp10(x):
"""Calculate 10**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.exp10(x)
def erf(x):
"""Take gauss error function of the input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.erf(x)
def tanh(x):
"""Take hyperbolic tanh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.tanh(x)
def sigmoid(x):
"""Quick function to get sigmoid
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.sigmoid(x)
def log(x):
"""Take log of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.log(x)
def log2(x):
"""Take log2 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.log2(x)
def log10(x):
"""Take log10 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.log10(x)
def log1p(x):
"""Take log(x + 1) with respect to input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.log1p(x)
def tan(x):
"""Take tan of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.tan(x)
def cos(x):
"""Take cos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.cos(x)
def cosh(x):
"""Take cosh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.cosh(x)
def acos(x):
"""Take acos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.acos(x)
def acosh(x):
"""Take acos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.acosh(x)
def sin(x):
"""Take sin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.sin(x)
def sinh(x):
"""Take sinh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.sinh(x)
def asin(x):
"""Take asin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.asin(x)
def asinh(x):
"""Take asinh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.asinh(x)
def atan(x):
"""Take atan of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.atan(x)
def atanh(x):
"""Take atanh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.atanh(x)
def atan2(x1, x2):
"""Take arctan2(x1, x2).
Parameters
----------
x1 : PrimExpr
Input argument.
x2 : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.atan2(x1, x2)
def sqrt(x):
"""Take square root of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.sqrt(x)
def rsqrt(x):
"""Take reciprocal of square root of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.rsqrt(x)
def clz(x):
"""Count leading zero bits of an integer x.
Parameters
----------
x : PrimExpr
Input 32 or 64 bit integer.
The result is undefined if the input is 0.
Returns
-------
y : PrimExpr
The result.
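Example
-------
The count is relative to the input's bit width:
.. code-block:: python

    from tvm import tir

    # bit 0 is the only set bit, so 31 of the 32 bits lead with zero
    y = tir.clz(tir.const(1, "int32"))  # 31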
"""
return _tvm_op.clz(x)
def floor(x: PrimExprWithOp, span=None):
"""Take floor of float input x.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.floor(x, span)
def ceil(x, span=None):
"""Take ceil of float input x.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.ceil(x, span)
def trunc(x, span=None):
"""Get truncated value of the input.
The truncated value of the scalar x is the
nearest integer i which is closer to zero than x is.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
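Example
-------
Truncation rounds toward zero, unlike floor:
.. code-block:: python

    from tvm import tir

    tir.trunc(tir.const(-1.7, "float32"))  # -1.0
    tir.floor(tir.const(-1.7, "float32"))  # -2.0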
"""
return _tvm_op.trunc(x, span)
def abs(x, span=None):
"""Get absolute value of the input element-wise.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.abs(x, span)
def bitwise_and(x, y, span=None):
"""Take bitwise and of two values
Parameters
----------
x : PrimExpr
Left operand
y : PrimExpr
Right operand
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
res : PrimExpr
The result.
"""
return _tvm_op.bitwise_and(x, y, span)
def bitwise_not(x, span=None):
"""Take bitwise not of input value
Parameters
----------
x : PrimExpr
Input operand
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
res : PrimExpr
The result.
"""
return _tvm_op.bitwise_not(x, span)
def bitwise_or(x, y, span=None):
"""Take bitwise or of two values
Parameters
----------
x : PrimExpr
Left operand
y : PrimExpr
Right operand
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
res : PrimExpr
The result.
"""
return _tvm_op.bitwise_or(x, y, span)
def bitwise_xor(x, y, span=None):
"""Take bitwise xor of two values
Parameters
----------
x : PrimExpr
Left operand
y : PrimExpr
Right operand
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
res : PrimExpr
The result.
"""
return _tvm_op.bitwise_xor(x, y, span)
def round(x, span=None):
"""Round elements of the array to the nearest integer.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.round(x, span)
def nearbyint(x, span=None):
"""Round elements of the array to the nearest integer.
This intrinsic uses llvm.nearbyint instead of llvm.round,
which is faster but may produce results different from te.round.
Notably, nearbyint rounds according to the current rounding mode,
whereas te.round (llvm.round) ignores it.
For differences between the two see:
https://en.cppreference.com/w/cpp/numeric/math/round
https://en.cppreference.com/w/cpp/numeric/math/nearbyint
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
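Example
-------
Under the default round-to-nearest-even mode the two intrinsics can
disagree on exact halfway values:
.. code-block:: python

    from tvm import tir

    tir.round(tir.const(2.5, "float32"))      # 3.0, half away from zero
    tir.nearbyint(tir.const(2.5, "float32"))  # 2.0, half to even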
"""
return _tvm_op.nearbyint(x, span)
def nextafter(x1, x2):
"""Return the next floating-point value after x1 towards x2.
Parameters
----------
x1 : PrimExpr
Input argument.
x2 : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.nextafter(x1, x2)
def hypot(x1, x2):
"""Equivalent to sqrt(x1**2 + x2**2), element-wise.
Parameters
----------
x1 : PrimExpr
Input argument.
x2 : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.hypot(x1, x2)
def copysign(x1, x2):
"""Change the sign of x1 to that of x2, element-wise.
Parameters
----------
x1 : PrimExpr
Input argument.
x2 : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.copysign(x1, x2)
def ldexp(x1, x2):
"""Returns x1 * (2 ** x2).
Parameters
----------
x1 : PrimExpr
Input argument.
x2 : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.ldexp(x1, x2)
def likely(cond, span=None):
"""Mark condition as likely.
Parameters
----------
cond : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The marked expression.
"""
return _tvm_op.likely(cond, span)
def isnan(x, span=None):
"""Check if input value is Nan.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.isnan(x, span)
def isnullptr(x, span=None):
"""Check if input value is nullptr.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.isnullptr(x, span)
def isfinite(x, span=None):
"""Check if input value is finite.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.isfinite(x, span)
def isinf(x, span=None):
"""Check if input value is infinite.
Parameters
----------
x : PrimExpr
Input argument.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.isinf(x, span)
def power(x, y, span=None):
"""x power y
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
The exponent
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
z : PrimExpr
The result.
"""
return _tvm_op.power(x, y, span)
def pow(x, y, span=None):
"""x power y
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
The exponent
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
z : PrimExpr
The result.
"""
return _tvm_op.pow(x, y, span)
def popcount(x):
"""Count the number of set bits in input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return _tvm_op.popcount(x)
def q_multiply_shift(x, y, q, s):
"""Execute a multiplication between two Q-numbers x and y
followed by a right shift s. The mathematical expression is:
out = round(x*y*2^-s)
More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
The rounding rule is to the nearest value, rounding half up
(i.e., round(x.1) = x and round(x.5) = x + 1).
Parameters
----------
x : PrimExpr
First Q-number
y : PrimExpr
Second Q-number
q : PrimExpr
Number of fractional bits in x and y. Needs to be > 0
s : PrimExpr
Integer shift
Returns
-------
y : PrimExpr
The result.
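Example
-------
A worked Q1.15 example: 0.5 * 0.5 with s = 15 gives 0.25:
.. code-block:: python

    from tvm import tir

    x = tir.const(16384, "int32")  # 0.5 in Q1.15
    y = tir.const(16384, "int32")  # 0.5 in Q1.15
    # round(16384 * 16384 * 2**-15) = 8192, i.e. 0.25 in Q1.15
    z = tir.q_multiply_shift(x, y, tir.const(15, "int32"), tir.const(15, "int32"))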
"""
return _tvm_op.q_multiply_shift(x, y, q, s)
def q_multiply_shift_per_axis(
x: PrimExpr,
y: PrimExpr,
ls: PrimExpr,
rs: PrimExpr,
q: IntImm,
is_lshift_required: IntImm,
is_rshift_required: IntImm,
):
"""Execute a multiplication between two Q-numbers x and y
Parameters
----------
x : PrimExpr
First Q-number.
y : PrimExpr
Second Q-number.
ls : PrimExpr
Integer left shift.
rs : PrimExpr
Integer right shift.
q : IntImm
Number of fractional bits in x and y. Needs to be > 0.
is_lshift_required : IntImm
Whether we need to do left shift or not.
is_rshift_required : IntImm
Whether we need to do right shift or not.
Returns
-------
z : PrimExpr
The result.
"""
return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required,
is_rshift_required)
def shift_left(x, y, span=None):
"""Return the result of x left shifted by y bits.
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
The number of bits to shift by.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
z : PrimExpr
The result.
"""
return _tvm_op.shift_left(x, y, span)
def shift_right(x, y, span=None):
"""Return the result of x right shifted by y bits.
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
The number of bits to shift by.
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
z : PrimExpr
The result.
"""
return _tvm_op.shift_right(x, y, span)
def fmod(x, y):
"""Return the remainder of x divided by y with the same sign as x.
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
Input argument.
Returns
-------
z : PrimExpr
The result.
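Example
-------
The result keeps the sign of x:
.. code-block:: python

    from tvm import tir

    # -5.5 = -2 * 2.0 + (-1.5), so the remainder is -1.5
    r = tir.fmod(tir.const(-5.5, "float32"), tir.const(2.0, "float32"))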
"""
return _tvm_op.fmod(x, y)
def if_then_else(cond, t, f, span=None):
"""Conditional selection expression.
Parameters
----------
cond : PrimExpr
The condition
t : PrimExpr
The result expression if cond is true.
f : PrimExpr
The result expression if cond is false.
span : Optional[Span]
The location of this operator in the source.
Returns
-------
result : Node
The result of conditional expression.
Note
----
Unlike Select, if_then_else will not execute
the branch that does not satisfy the condition.
You can use it to guard against out of bound access.
Unlike Select, if_then_else cannot be vectorized
if some lanes in the vector have different conditions.
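Example
-------
A sketch using if_then_else to guard an out-of-bounds read with a zero
padding value:
.. code-block:: python

    from tvm import te, tir

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    # reads past the end return 0 instead of touching A[n]
    B = te.compute((n + 1,),
                   lambda i: tir.if_then_else(i < n, A[i], tir.const(0, A.dtype)))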
"""
return _tvm_op.if_then_else(cond, t, f, span)
def div(a, b, span=None):
"""Compute a / b as in C/C++ semantics.
Parameters
----------
a : PrimExpr
The left hand operand, known to be non-negative.
b : PrimExpr
The right hand operand, known to be non-negative.
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
Note
----
When operands are integers, returns truncdiv(a, b, span).
"""
return _tvm_op.div(a, b, span)
def indexdiv(a, b, span=None):
"""Compute floor(a / b) where a and b are non-negative.
Parameters
----------
a : PrimExpr
The left hand operand, known to be non-negative.
b : PrimExpr
The right hand operand, known to be non-negative.
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of operands'
non-negativeness.
"""
return _tvm_op.indexdiv(a, b, span)
def indexmod(a, b, span=None):
"""Compute the remainder of indexdiv. a and b are non-negative.
Parameters
----------
a : PrimExpr
The left hand operand, known to be non-negative.
b : PrimExpr
The right hand operand, known to be non-negative.
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
Note
----
Use this function to split non-negative indices.
This function may take advantage of operands'
non-negativeness.
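Example
-------
A sketch splitting a flat non-negative index into row and column:
.. code-block:: python

    from tvm import te, tir

    flat = te.var("flat")
    cols = 16
    row = tir.indexdiv(flat, cols)  # floor(flat / 16)
    col = tir.indexmod(flat, cols)  # flat - row * 16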
"""
return _tvm_op.indexmod(a, b, span)
def truncdiv(a, b, span=None):
"""Compute the truncdiv of two expressions.
Parameters
----------
a : PrimExpr
The left hand operand
b : PrimExpr
The right hand operand
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
Note
----
This is the default integer division behavior in C.
"""
return _tvm_op.truncdiv(a, b, span)
def truncmod(a, b, span=None):
"""Compute the truncmod of two expressions.
Parameters
----------
a : PrimExpr
The left hand operand
b : PrimExpr
The right hand operand
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
Note
----
This is the default integer remainder (modulo) behavior in C.
"""
return _tvm_op.truncmod(a, b, span)
def floordiv(a, b, span=None):
"""Compute the floordiv of two expressions.
Parameters
----------
a : PrimExpr
The left hand operand
b : PrimExpr
The right hand operand
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
"""
return _tvm_op.floordiv(a, b, span)
def floormod(a, b, span=None):
"""Compute the floormod of two expressions.
Parameters
----------
a : PrimExpr
The left hand operand
b : PrimExpr
The right hand operand
span : Optional[Span]
The location of this operator in the source.
Returns
-------
res : PrimExpr
The result expression.
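Example
-------
Floor-based and truncation-based division differ once operands are
negative:
.. code-block:: python

    from tvm import tir

    a = tir.const(-5, "int32")
    b = tir.const(3, "int32")
    tir.floordiv(a, b)  # -2, while tir.truncdiv(a, b) is -1
    tir.floormod(a, b)  #  1, while tir.truncmod(a, b) is -2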
"""
return _tvm_op.floormod(a, b, span)
def ceildiv(lhs, rhs, span=None):
"""Generic ceildiv operator.
Parameters
----------
lhs : object
The left operand.
rhs : object
The right operand.
span : Optional[Span]
The location of this operator in the source.
Returns
-------
op : tvm.Expr
The result Expr of ceildiv operation.
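Example
-------
A common use is computing how many fixed-size tiles cover an extent:
.. code-block:: python

    from tvm import tir

    # ceildiv(300, 128) == 3 tiles
    num_tiles = tir.ceildiv(tir.const(300, "int32"), tir.const(128, "int32"))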
"""
return _tvm_op.ceildiv(lhs, rhs, span)
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
Parameters
----------
fcombine : function(Expr -> Expr -> Expr)
A binary function which takes two Expr as input and returns an Expr.
fidentity : function(str -> Expr)
A function which takes a type string as input to return a const Expr.
Returns
-------
reducer : function
A function which creates a reduce expression over axis.
There are two ways to use it:
1. accept (expr, axis, where) to produce a Reduce Expr on
specified axis;
2. simply use it with multiple Exprs.
Example
-------
.. code-block:: python

    n = te.var("n")
    m = te.var("m")
    mysum = te.comm_reducer(lambda x, y: x + y,
                            lambda t: tvm.tir.const(0, dtype=t), name="mysum")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), name="k")
    B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
"""
return _tvm_op.comm_reducer(fcombine, fidentity, name)
def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
"""Backend function to allocate temporal workspace
Parameters
----------
device_type : int
The device type on which the space will be allocated.
device_id : int
The device id on which the space will be allocated.
nbytes : int
The size of the space requested.
dtype_code_hint : int
The type code of the array elements. Only used in certain backends such as OpenGL.
dtype_bits_hint : int
The type bits of the array elements. Only used in certain backends such as OpenGL.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint,
dtype_bits_hint)
def TVMBackendFreeWorkspace(device_type, device_id, ptr):
"""Backend function to free temporal workspace.
Parameters
----------
device_type : int
The device type on which the space was allocated.
device_id : int
The device id on which the space was allocated.
ptr : Var
The pointer to the allocated space.
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.TVMBackendFreeWorkspace(device_type, device_id, ptr)
def anylist_getitem(list_handle, index):
"""Returns an item from any list.
list_handle: Var
The handle to anylist
index : int
The index
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.anylist_getitem(list_handle, index)
def anylist_resetitem(list_handle, index):
"""Reset an item from any list.
list_handle: Var
The handle to anylist
index : int
The index
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.anylist_resetitem(list_handle, index)
def anylist_setitem_call_packed(list_handle, index, func_name, *args):
"""Set anylist item by result of packed call.
list_handle: Var
The handle to anylist
index : int
The index
func_name: str
The name of the function to be called.
args:
Extra arguments
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.anylist_setitem_call_packed(list_handle, index, func_name, *args)
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args):
"""Set anylist item by result of packed call.
list_handle: Var
The handle to anylist
index : int
The index
func_name: str
The name of the function to be called.
args:
Extra arguments
Returns
-------
call : PrimExpr
The call expression.
"""
return _tvm_op.anylist_setitem_call_cpacked(list_handle, index, func_name, *args)
def vscale():
"""Get the target's vscale value. It will be lowered to llvm.vscale intrinsic
(https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic)
Returns
-------
call : PrimExpr
Call to the vscale intrinsic
"""
return _tvm_op.vscale()
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _tvm_op._OpMin(x, y, None), max_value, name="min")
max = comm_reducer(lambda x, y: _tvm_op._OpMax(x, y, None), min_value, name="max")