Unverified commit 6c67a77f, authored by Jiaxing Ding and committed by GitHub
Browse files

[Layout] fix plot layout (#890)

parent 599264ca
......@@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16",
----------
dtype : str
The data type of the matrix.
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
matrix : Literal["A", "B"]
The mma operand to be loaded.
transposed : bool
Whether the matrix is transposed, by default False.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Describes how threads and indices in fragment are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.intrinsics.mma_layout import (
shared_16x16_to_mma_32x8_layout_sr,
shared_16x16_to_mma_32x8_layout_rs,
shared_16x32_to_mma_32x16_layout,
shared_32x16_to_mma_32x16_layout,
shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x32_to_mma_32x16_layout_sr_a,
shared_16x8_to_mma_32x4_layout_sr_b,
shared_16x16_to_mma_32x8_layout_sr_b,
shared_16x32_to_mma_32x16_layout_sr_b,
)
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
assert transposed is False, "transposed is not supported yet"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr: Callable = None
transform_func_rs: Callable = None
if dtype_bits == 16:
transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if dtype_bits == 32:
transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
elif dtype_bits == 16:
transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a
transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b
elif dtype_bits == 8:
transform_func_sr = shared_16x32_to_mma_32x16_layout
transform_func_rs = shared_32x16_to_mma_32x16_layout
transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a
transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs
micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)
micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
transform_func = transform_func
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
......@@ -81,7 +94,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
return local_id
base_fragment = T.Fragment(
[micro_size_r, micro_size_s],
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
......@@ -109,4 +122,4 @@ plot_layout(warp_layout, name="warp_layout")
# block layout 128x32
block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
print(block_layout)
# plot_layout(block_layout, name="block_layout")
plot_layout(block_layout, name="block_layout")
......@@ -490,7 +490,6 @@ class TensorCoreIntrinEmitter(object):
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if dtype_bits == 32:
...
transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
elif dtype_bits == 16:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment