# fragment_mma_load_a.py
# Build the MMA load fragment layout for operand A and plot it at the
# base (single tile), warp, and block levels.
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size


def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment:
    """
    Create the base fragment layout for loading one MMA operand tile.

    The layout maps every element (i, j) of a single operand tile to the lane
    that holds it and to the local index inside that lane's fragment. It wraps
    the matching ``shared_*_to_mma_*_layout_sr_{a,b}`` transform from
    ``tilelang.intrinsics.mma_layout`` in an ``IndexMap``.

    Parameters
    ----------
    dtype : str
        The data type of the matrix. Its bit width (32, 16 or 8) selects the
        16x8, 16x16 or 16x32 transform respectively.
    matrix : Literal["A", "B"]
        The mma operand to be loaded.
    transposed : bool
        Whether the matrix is transposed, by default False.

    Returns
    -------
    T.Fragment
        Describes how threads and indices in fragment are laid out.

    Raises
    ------
    ValueError
        If ``matrix`` is not "A"/"B" or the bit width of ``dtype`` is unsupported.
    """
    from tilelang.intrinsics.mma_layout import (
        shared_16x8_to_mma_32x4_layout_sr_a,
        shared_16x16_to_mma_32x8_layout_sr_a,
        shared_16x32_to_mma_32x16_layout_sr_a,
        shared_16x8_to_mma_32x4_layout_sr_b,
        shared_16x16_to_mma_32x8_layout_sr_b,
        shared_16x32_to_mma_32x16_layout_sr_b,
    )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if matrix not in ("A", "B"):
        raise ValueError(f"Unsupported matrix {matrix}")

    # s represents the spatial axis, r represents the reduction axis.
    # "sr" means the tile's two dims are (spatial, reduction); "rs" the reverse.
    # Each transform is written for sr order; the tile size follows the bit width.
    sr_transforms_by_bits = {
        32: (shared_16x8_to_mma_32x4_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_b),
        16: (shared_16x16_to_mma_32x8_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_b),
        8: (shared_16x32_to_mma_32x16_layout_sr_a, shared_16x32_to_mma_32x16_layout_sr_b),
    }
    dtype_bits = DataType(dtype).bits
    if dtype_bits not in sr_transforms_by_bits:
        raise ValueError(f"Unsupported dtype {dtype}")
    transform_func_sr_a, transform_func_sr_b = sr_transforms_by_bits[dtype_bits]

    # A untransposed and B transposed are stored in (spatial, reduction) order.
    is_sr_axis_order = (matrix == "A" and not transposed) or (matrix == "B" and transposed)

    # NOTE(review): for 32-bit dtypes the transform domain is 16x8; confirm that
    # get_mma_micro_size(dtype) reports a matching reduction extent.
    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)

    # The layout of mma.sync is row.col, so the B matrix expects a transposed
    # basic layout (handled by the sr/rs logic above).
    transform_func_sr: Callable
    if matrix == "A":
        transform_func_sr = transform_func_sr_a
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    else:
        transform_func_sr = transform_func_sr_b
        # B's spatial extent is micro_size_y and its reduction extent is
        # micro_size_k, matching the sr_b transform domain (e.g. 16x32 for
        # 8-bit). These were previously swapped, which only mattered when
        # micro_size_k != micro_size_y.
        micro_size_s, micro_size_r = micro_size_y, micro_size_k

    if is_sr_axis_order:
        transform_func = transform_func_sr
        fragment_shape = [micro_size_s, micro_size_r]
    else:
        # The tile is stored (reduction, spatial): swap the indices before
        # applying the sr-ordered transform.
        def transform_func(i, j):
            return transform_func_sr(j, i)

        fragment_shape = [micro_size_r, micro_size_s]

    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the lane id that holds that element.
        """
        lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the element's local index within its lane's fragment.
        """
        _, local_id = inverse_mma_load_layout.map_indices([i, j])
        return local_id

    base_fragment = T.Fragment(
        fragment_shape,
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment


# Example tiling configuration used to grow the base tile into warp and
# block layouts. NOTE(review): warp_cols is currently unused below.
block_rows = 2
block_cols = 2
warp_rows = 4
warp_cols = 4
chunk = 2

from tilelang.tools import plot_layout  # noqa: E402

# ldmatrix layout 16x16 (operand A, float16, not transposed)
base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")

# warp layout 32x16: repeat the base tile block_rows times along the row
# dimension (repeat_on_thread=True), then replicate it block_cols times.
warp_layout = base_layout.repeat([block_rows, 1], repeat_on_thread=True).replicate(block_cols)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")

# block layout 128x32: repeat the warp layout warp_rows x chunk times
# (repeat_on_thread=False, lower_dim_first=False).
block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
print(block_layout)
plot_layout(block_layout, name="block_layout")