customize.py 2.67 KB
Newer Older
1
"""The language interface for tl programs."""
2
from __future__ import annotations
3
import tilelang.language as T
4
from tvm.tir import PrimExpr, Buffer, op
5
from tilelang.utils.language import (bits_product, prim_expr_equal)
6
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store  # noqa: F401
7

8

9
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
    """Perform a 4-element dot product with accumulation (DP4A).

    Emits an extern call to ``DP4A``, passing the addresses of the three
    operands in the order ``A``, ``B``, ``C``.

    Args:
        A (Buffer): First input buffer
        B (Buffer): Second input buffer
        C (Buffer): Accumulation buffer

    Returns:
        PrimExpr: Handle to the DP4A operation
    """
    # The extern routine receives pointers, so each operand is passed by address.
    operand_addrs = [T.address_of(operand) for operand in (A, B, C)]
    return T.call_extern("handle", "DP4A", *operand_addrs)
21
22


23
24
def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
    """Clamp the input value ``dst`` to the range [min_val, max_val].

    The lower bound is applied first and the upper bound last, i.e. the
    result is ``T.min(T.max(dst, min_val), max_val)``.

    Args:
        dst: Input value to be clamped
        min_val: Minimum value
        max_val: Maximum value

    Returns:
        Value clamped to the specified range
    """
    lower_bounded = T.max(dst, min_val)  # Ensure value is not less than minimum
    return T.min(lower_bounded, max_val)  # Ensure value is not greater than maximum
37
38


39
def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
    """Reshape the input buffer to the specified shape.

    The element dtype is kept; only the shape changes. The total bit size
    implied by ``shape`` must equal that of ``src.shape`` (checked with
    ``bits_product`` / ``prim_expr_equal``).

    Args:
        src (Buffer): Input buffer to be reshaped
        shape (list[PrimExpr]): New shape for the buffer

    Returns:
        Buffer: A new buffer view with the specified shape, built on
        ``src.data``

    Raises:
        AssertionError: If the requested shape does not cover the same
            number of bits as the source buffer.
    """
    target_bits = bits_product(shape, src.dtype)
    source_bits = bits_product(src.shape, src.dtype)
    assert prim_expr_equal(
        target_bits, source_bits
    ), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
    return T.Tensor(shape, src.dtype, src.data)
53
54


55
def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer:
    """Return a Tensor view of the input buffer with an optional new shape and dtype.

    If `shape` is None the source buffer's shape is used; if `dtype` is None
    the source buffer's dtype is used. The returned buffer shares the same
    underlying data as `src` (no copy).

    Args:
        src (Buffer): Buffer to reinterpret
        shape (list[PrimExpr] | None): Shape of the view; defaults to ``src.shape``
        dtype (str | None): Element type of the view; defaults to ``src.dtype``

    Returns:
        Buffer: A tensor built on ``src.data`` with the requested shape and dtype

    Raises:
        AssertionError: If the requested shape/dtype does not cover the same
            number of bits as the source buffer.
    """
    # Fall back to the source's own layout for anything the caller left unset.
    shape = src.shape if shape is None else shape
    dtype = src.dtype if dtype is None else dtype

    # A view may change shape and dtype, but the total bit size must be preserved.
    target_bits = bits_product(shape, dtype)
    source_bits = bits_product(src.shape, src.dtype)
    assert prim_expr_equal(target_bits, source_bits), "T.reshape/view shape check failed."
    return T.Tensor(shape, dtype, src.data)
67
68


69
70
def loop_break():
    """Break out of the current loop.

    Returns:
        tir.Call: A call to the `tl.loop_break` intrinsic.
    """
    # Look up the registered intrinsic by name, then emit a call to it.
    break_op = op.Op.get("tl.loop_break")
    return T.call_intrin("handle", break_op)