customize.py 13.7 KB
Newer Older
1
2
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
3
4
"""The language interface for tl programs."""

5
import tilelang.language as T
6
7
from tvm import ir
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
8
from typing import List, Union
9

10
11
12
13
14
15
16
17
18
# Maps a memory-order name to the integer id that the atomic helpers in this
# module append as the trailing argument of their extern calls.
# NOTE(review): the ids follow the declaration order of C++ `std::memory_order`;
# presumably the device-side Atomic* helpers rely on that -- confirm.
_MEMORY_ORDER_ID_MAP = {
    "relaxed": 0,
    "consume": 1,
    "acquire": 2,
    "release": 3,
    "acq_rel": 4,
    "seq_cst": 5,
}

19

20
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
    """Build a `tl.region` intrinsic call describing a tile memory region.

    Args:
        buffer (tir.BufferLoad): Load forwarded to the intrinsic as its first
            operand.
        access_type (str): 'r', 'w' or 'rw'; translated to the integer codes
            1, 2 and 3 respectively before being handed to the intrinsic.
        *args (tir.PrimExpr): Extents, forwarded in order after the access code.

    Returns:
        tir.Call: A "handle"-typed call to the `tl.region` intrinsic.

    Raises:
        KeyError: If access_type is not one of 'r', 'w' or 'rw'.
    """
    access_codes = {"r": 1, "w": 2, "rw": 3}
    # Translate first so an unknown access_type fails before the op lookup.
    code = access_codes[access_type]
    region_op = op.Op.get("tl.region")
    return T.call_intrin("handle", region_op, buffer, code, *args)


def buffer_to_tile_region(buffer: Buffer, access_type: str):
    """Describe an entire buffer as a tile region.

    The region starts at index 0 in every dimension and uses the buffer's
    shape as its extents.

    Args:
        buffer (tir.Buffer): Buffer to describe.
        access_type (str): 'r' for read, 'w' for write, 'rw' for read-write.

    Returns:
        tir.Call: The descriptor returned by `region` for the whole buffer.
    """
    full_extents = list(buffer.shape)
    origin = [0] * len(full_extents)
    return region(T.BufferLoad(buffer, origin), access_type, *full_extents)


def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
    """Describe the area starting at a buffer load as a tile region.

    When fewer extents than load indices are supplied, the extents are padded
    on the left with 1 so that the trailing dimensions keep the given sizes.

    Args:
        load (tir.BufferLoad): Load whose indices give the region's start.
        access_type (str): 'r' for read, 'w' for write, 'rw' for read-write.
        extents (List[tir.PrimExpr]): Region size; may cover only the trailing
            dimensions of the load.

    Returns:
        tir.Call: The descriptor returned by `region` for the load.

    Raises:
        AssertionError: If more extents than load indices are supplied.
    """
    indices = load.indices
    missing = len(indices) - len(extents)
    if missing > 0:
        # Leading dimensions without an explicit extent are treated as size 1.
        extents = [1] * missing + list(extents)
    assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
    return region(load, access_type, *extents)


def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
                                 extents: List[PrimExpr]):
    """Describe a BufferRegion as a tile region.

    The descriptor starts at the region's per-dimension minimums and uses the
    region's own extents. `extents` is only used for a rank check; its values
    are not forwarded.

    Args:
        buffer_region (tir.BufferRegion): Source region; each item of its
            `region` supplies a min and an extent.
        access_type (str): 'r' for read, 'w' for write, 'rw' for read-write.
        extents (List[PrimExpr]): Requested extents; must not have more entries
            than `buffer_region.region`.

    Returns:
        tir.Call: The descriptor returned by `region` for the buffer region.

    Raises:
        AssertionError: If `extents` has more entries than
            `buffer_region.region`.
    """
    ranges = buffer_region.region
    mins = [rng.min for rng in ranges]
    region_extents = [rng.extent for rng in ranges]
    assert len(region_extents) >= len(extents), (
        f"region_extents must be >= extents, region_extents = {region_extents}, "
        f"extents = {extents}")

    start = T.BufferLoad(buffer_region.buffer, mins)
    return region(start, access_type, *region_extents)


107
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: Union[str, None] = None) -> PrimExpr:
    """Emit an extern "AtomicMax" call on `dst` with an optional memory order.

    Args:
        dst (Buffer): Destination; its address is passed to the extern.
        value (PrimExpr): Value passed to the extern as the second operand.
        memory_order (str | None): Optional memory-order name ("relaxed",
            "consume", "acquire", "release", "acq_rel" or "seq_cst"). When
            given, its numeric id is appended as a third operand; when None,
            the extern is called with two operands only.

    Returns:
        PrimExpr: The "handle"-typed extern call.

    Raises:
        KeyError: If `memory_order` is given but is not a supported name.
    """
    # The annotation uses typing.Union rather than `str | None` so the module
    # still imports on interpreters that predate PEP 604.
    extern_args = [T.address_of(dst), value]
    if memory_order is not None:
        extern_args.append(_MEMORY_ORDER_ID_MAP[memory_order])
    return T.call_extern("handle", "AtomicMax", *extern_args)


def atomic_min(dst: Buffer, value: PrimExpr, memory_order: Union[str, None] = None) -> PrimExpr:
    """Emit an extern "AtomicMin" call on `dst` with an optional memory order.

    Args:
        dst (Buffer): Destination; its address is passed to the extern.
        value (PrimExpr): Value passed to the extern as the second operand.
        memory_order (str | None): Optional memory-order name ("relaxed",
            "consume", "acquire", "release", "acq_rel" or "seq_cst"). When
            given, its numeric id is appended as a third operand; when None,
            the extern is called with two operands only.

    Returns:
        PrimExpr: The "handle"-typed extern call.

    Raises:
        KeyError: If `memory_order` is given but is not a supported name.
    """
    # The annotation uses typing.Union rather than `str | None` so the module
    # still imports on interpreters that predate PEP 604.
    extern_args = [T.address_of(dst), value]
    if memory_order is not None:
        extern_args.append(_MEMORY_ORDER_ID_MAP[memory_order])
    return T.call_extern("handle", "AtomicMin", *extern_args)


def atomic_add(dst: Buffer, value: PrimExpr, memory_order: Union[str, None] = None) -> PrimExpr:
    """Atomically add `value` into `dst`.

    Two code paths exist:

    * Extern path: when neither argument is a Buffer or BufferRegion (after
      resolving let-bound Vars), an extern "AtomicAdd" call is emitted on the
      address of `dst`. `memory_order`, if given, is appended as a numeric id.
    * Tile-region path: otherwise both arguments are converted to tile regions
      and a `tl.atomicadd` intrinsic call is emitted. `memory_order` is not
      used on this path.

    Args:
        dst: Destination (Buffer, BufferRegion, BufferLoad or let-bound Var).
        value: Value or source tile to add.
        memory_order (str | None): Optional memory-order name ("relaxed",
            "consume", "acquire", "release", "acq_rel" or "seq_cst"); only
            honoured on the extern path.

    Returns:
        PrimExpr: A "handle"-typed call representing the atomic addition.

    Raises:
        KeyError: On the extern path, if `memory_order` is not a supported name.
        AssertionError: If no extents can be deduced on the tile-region path.
    """

    # The annotation uses typing.Union rather than `str | None` so the module
    # still imports on interpreters that predate PEP 604.

    def _resolve(data):
        """Replace a let-bound Var by the value it is bound to."""
        if isinstance(data, Var) and T.has_let_value(data):
            return T.get_let_value(data)
        return data

    def get_extent(data):
        """Return the extents of a Buffer/BufferRegion, or None for anything else."""
        data = _resolve(data)
        if isinstance(data, Buffer):
            return data.shape
        if isinstance(data, BufferRegion):
            return [rng.extent for rng in data.region]
        return None

    src_extent = get_extent(value)
    dst_extent = get_extent(dst)

    if dst_extent is None and src_extent is None:
        # Neither side exposes extents: fall back to the extern call.
        extern_args = [T.address_of(dst), value]
        if memory_order is not None:
            extern_args.append(_MEMORY_ORDER_ID_MAP[memory_order])
        return T.call_extern("handle", "AtomicAdd", *extern_args)

    if isinstance(dst, Buffer) and isinstance(value, Buffer):
        ir.assert_structural_equal(dst.shape, value.shape)

    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
    # A side without extents is treated as all-ones with the other side's rank.
    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
    # NOTE(review): `max` on two lists compares them lexicographically, not
    # element-wise -- confirm this is the intended way to pick the extent.
    extent = max(src_extent, dst_extent)

    def _to_region(data, access_type):
        """Convert a buffer-like argument into a tile-region descriptor."""
        data = _resolve(data)
        if isinstance(data, Buffer):
            return buffer_to_tile_region(data, access_type)
        if isinstance(data, BufferRegion):
            return buffer_region_to_tile_region(data, access_type, extent)
        return buffer_load_to_tile_region(data, access_type, extent)

    value = _to_region(value, "r")
    dst = _to_region(dst, "w")
    return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
211
212


213
def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
    """Emit an extern "AtomicAddx2" call (double-width atomic addition).

    Both operands are passed to the extern by address.

    Args:
        dst (Buffer): Destination of the atomic addition.
        value (PrimExpr): Double-width value to add.

    Returns:
        PrimExpr: The "handle"-typed extern call.
    """
    dst_addr = T.address_of(dst)
    value_addr = T.address_of(value)
    return T.call_extern("handle", "AtomicAddx2", dst_addr, value_addr)
224
225


226
def atomic_addx4(dst: Buffer, value: PrimExpr) -> PrimExpr:
    """Emit an extern "AtomicAddx4" call (quad-width atomic addition).

    Both operands are passed to the extern by address.

    Args:
        dst (Buffer): Destination of the atomic addition.
        value (PrimExpr): Quad-width value to add.

    Returns:
        PrimExpr: The "handle"-typed extern call.
    """
    dst_addr = T.address_of(dst)
    value_addr = T.address_of(value)
    return T.call_extern("handle", "AtomicAddx4", dst_addr, value_addr)


239
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
    """Emit an extern "DP4A" call (4-element dot product with accumulation).

    All three operands are passed to the extern by address.

    Args:
        A (Buffer): First input buffer.
        B (Buffer): Second input buffer.
        C (Buffer): Accumulation buffer.

    Returns:
        PrimExpr: The "handle"-typed extern call.
    """
    a_addr = T.address_of(A)
    b_addr = T.address_of(B)
    c_addr = T.address_of(C)
    return T.call_extern("handle", "DP4A", a_addr, b_addr, c_addr)
251
252


253
254
255
256
257
258
259
260
261
262
263
264
265
def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
    """Clamp `dst` to the range [min_val, max_val].

    Args:
        dst: Value to clamp.
        min_val: Lower bound.
        max_val: Upper bound.

    Returns:
        The expression `min(max(dst, min_val), max_val)`.
    """
    # Apply the lower bound first, then the upper bound.
    return T.min(T.max(dst, min_val), max_val)
267
268
269
270
271
272


def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
    """Create a tensor with a new shape over the data of an existing buffer.

    The result keeps the dtype of `src` and is built from `src.data`. This
    function itself does not check that `shape` is compatible with `src`.

    Args:
        src (Buffer): Buffer whose data and dtype are reused.
        shape (List[PrimExpr]): Shape of the resulting tensor.

    Returns:
        Buffer: The tensor returned by `T.Tensor` for the new shape.
    """
    elem_dtype = src.dtype
    backing = src.data
    return T.Tensor(shape, elem_dtype, backing)
280
281
282
283
284
285


def view(src: Buffer,
         shape: Union[List[PrimExpr], None] = None,
         dtype: Union[str, None] = None) -> Buffer:
    """Create a tensor over the data of `src`, optionally changing shape and dtype.

    Args:
        src (Buffer): Buffer whose data is reused.
        shape (List[PrimExpr] | None): Shape of the result; defaults to
            `src.shape` when None.
        dtype (str | None): Dtype of the result; defaults to `src.dtype`
            when None.

    Returns:
        Buffer: The tensor returned by `T.Tensor`, built from `src.data`.
    """
    target_shape = src.shape if shape is None else shape
    target_dtype = src.dtype if dtype is None else dtype
    return T.Tensor(target_shape, target_dtype, src.data)
295
296
297


def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
    """Emit an extern "AtomicLoad" call on `src` with the given memory order.

    Args:
        src (Buffer): Source; its address is passed to the extern and its
            dtype is used as the dtype of the call.
        memory_order (str): One of "relaxed", "consume", "acquire", "release",
            "acq_rel" or "seq_cst" (default); passed as its numeric id.

    Returns:
        PrimExpr: The extern call, typed with `src.dtype`.

    Raises:
        KeyError: If `memory_order` is not a supported name.
    """
    result_dtype = src.dtype
    src_addr = T.address_of(src)
    order_id = _MEMORY_ORDER_ID_MAP[memory_order]
    return T.call_extern(result_dtype, "AtomicLoad", src_addr, order_id)


def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
    """Emit an extern "AtomicStore" call writing `src` to `dst`.

    Args:
        dst (Buffer): Destination; its address is passed to the extern.
        src (PrimExpr): Value to store.
        memory_order (str, optional): One of "relaxed", "consume", "acquire",
            "release", "acq_rel" or "seq_cst" (default); passed as its
            numeric id.

    Returns:
        PrimExpr: The "handle"-typed extern call.

    Raises:
        KeyError: If `memory_order` is not a supported name.
    """
    dst_addr = T.address_of(dst)
    order_id = _MEMORY_ORDER_ID_MAP[memory_order]
    return T.call_extern("handle", "AtomicStore", dst_addr, src, order_id)