# fragment_mfma_load_a.py
import tilelang.language as T
from typing import Literal, Callable
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size

from tilelang.intrinsics.mfma_layout import (
    shared_16x4_to_local_64x1_layout_A,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x32_to_local_64x8_layout_A,
    shared_16x64_to_local_64x16_layout_A,
)


def make_mfma_load_base_layout(dtype: str = "float16",
                               matrix: Literal["A", "B"] = "A",
                               k_dim: int = 16,
                               transposed: bool = False) -> T.Fragment:
    """
    Create the base fragment layout for loading one MFMA operand tile.

    For every element of a single operand tile, the returned fragment records
    which thread (lane) holds it and at which local index. The mapping is
    taken from the ``shared_16xK_to_local_64xN_layout_A`` helpers selected by
    ``k_dim``.

    Parameters
    ----------
    dtype : str
        The data type of the operand. Used to query the spatial micro tile
        sizes via ``get_mma_micro_size``.
    matrix : Literal["A", "B"]
        The MFMA operand to be loaded.
    k_dim : int
        The k (reduction) dimension of the MFMA; one of 4, 16, 32 or 64.
        It also sets the reduction extent of the returned fragment.
    transposed : bool
        Whether the operand is stored transposed, by default False.
        The layout is in spatial-reduction ("sr") order for A when not
        transposed and for B when transposed; otherwise it is in
        reduction-spatial ("rs") order.

    Returns
    -------
    T.Fragment
        Describes how threads and indices in the fragment are laid out. Its
        shape is ``[spatial, k_dim]`` in sr order and ``[k_dim, spatial]``
        in rs order.

    Raises
    ------
    ValueError
        If ``k_dim`` is not one of the supported values.
    """

    assert matrix in ["A", "B"], "matrix should be either A or B"
    # s represents spatial axis
    # r represents reduction axis
    # sr represents the two dims are spatial + reduction
    # rs represents the two dims are reduction + spatial

    # Each helper maps (spatial index, reduction index) of a 16 x k_dim tile to
    # (thread id, local id). The same "A" helper is used for both operands:
    # B in sr order has the same (spatial, reduction) structure as A.
    sr_layouts = {
        4: shared_16x4_to_local_64x1_layout_A,
        16: shared_16x16_to_local_64x4_layout_A,
        32: shared_16x32_to_local_64x8_layout_A,
        64: shared_16x64_to_local_64x16_layout_A,
    }
    if k_dim not in sr_layouts:
        raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
    transform_func_sr: Callable = sr_layouts[k_dim]

    is_sr_axis_order = (matrix == "A" and not transposed) or (matrix == "B" and transposed)

    micro_size_x, micro_size_y, _ = get_mma_micro_size(dtype)
    # The reduction extent must match the layout helper chosen by k_dim (it is
    # only defined for reduction indices in [0, k_dim)), so it cannot come from
    # the dtype default.
    micro_size_k = k_dim

    if matrix == "A":
        # A: spatial axis is M (x), reduction axis is K.
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    elif matrix == "B":
        # B: spatial axis is N (y), reduction axis is K.
        micro_size_s, micro_size_r = micro_size_y, micro_size_k
    else:
        raise ValueError(f"Unsupported matrix {matrix}")

    # The helpers expect (spatial, reduction) arguments; for rs-ordered storage
    # the fragment indices arrive as (reduction, spatial) and must be swapped.
    transform_func: Callable = transform_func_sr if is_sr_axis_order else lambda i, j: transform_func_sr(
        j, i)

    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the thread (lane) id that holds that element.
        """
        lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the local index of that element within its thread.
        """
        _, local_id = inverse_mma_load_layout.map_indices([i, j])
        return local_id

    base_fragment = T.Fragment(
        [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment


# Tiling factors used to grow the base operand layout into warp/block layouts.
block_rows, block_cols = 2, 2
warp_rows, warp_cols = 2, 2
chunk = 2  # NOTE(review): not referenced below

from tilelang.tools import plot_layout


def _show(layout, name):
    """Print ``layout`` and render it with ``plot_layout`` under ``name``."""
    print(layout)
    plot_layout(layout, name=name)


# Base MFMA operand-A layout: one 16x16 (M x K) tile for float16 with the default k_dim.
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
_show(base_layout, "base_layout")

# Warp layout 32x32: the base tile repeated warp_rows x warp_cols times without
# adding threads (repeat_on_thread=False).
warp_layout = base_layout.repeat(
    [warp_rows, warp_cols],
    repeat_on_thread=False,
    lower_dim_first=False,
)
_show(warp_layout, "warp_layout")

# Block layout 64x32: the warp layout repeated block_rows times across threads,
# then replicated block_cols times.
_block_stacked = warp_layout.repeat(
    [block_rows, 1],
    repeat_on_thread=True,
    lower_dim_first=True,
)
block_layout = _block_stacked.replicate(block_cols)
_show(block_layout, "block_layout")