"testing/python/vscode:/vscode.git/clone" did not exist on "181267c79a3893c4642ecb9a68da871e29ae0e02"
Unverified Commit 6c67a77f authored by Jiaxing Ding's avatar Jiaxing Ding Committed by GitHub
Browse files

[Layout] fix plot layout (#890)

parent 599264ca
...@@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16", ...@@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16",
---------- ----------
dtype : str dtype : str
The data type of the matrix. The data type of the matrix.
local_buf : tir.Buffer matrix : Literal["A", "B"]
The local buffer representing a fragment of a matrix. The mma operand to be loaded.
transposed : bool
Whether the matrix is transposed, by default False.
Returns Returns
------- -------
T.Fragment T.Fragment
A fragment object that describes how threads and indices Describes how threads and indices in fragment are laid out.
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
""" """
from tilelang.intrinsics.mma_layout import ( from tilelang.intrinsics.mma_layout import (
shared_16x16_to_mma_32x8_layout_sr, shared_16x8_to_mma_32x4_layout_sr_a,
shared_16x16_to_mma_32x8_layout_rs, shared_16x16_to_mma_32x8_layout_sr_a,
shared_16x32_to_mma_32x16_layout, shared_16x32_to_mma_32x16_layout_sr_a,
shared_32x16_to_mma_32x16_layout, shared_16x8_to_mma_32x4_layout_sr_b,
shared_16x16_to_mma_32x8_layout_sr_b,
shared_16x32_to_mma_32x16_layout_sr_b,
) )
assert matrix in ["A", "B"], "matrix should be either A or B" assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits dtype_bits = DataType(dtype).bits
assert transposed is False, "transposed is not supported yet"
# s represents spatial axis # s represents spatial axis
# r represents reduction axis # r represents reduction axis
# sr represents the two dims are spatial + reduction # sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial # rs represents the two dims are reduction + spatial
transform_func_sr: Callable = None transform_func_sr_a: Callable = None
transform_func_rs: Callable = None transform_func_sr_b: Callable = None
if dtype_bits == 16: if dtype_bits == 32:
transform_func_sr = shared_16x16_to_mma_32x8_layout_sr transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
transform_func_rs = shared_16x16_to_mma_32x8_layout_rs transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
elif dtype_bits == 16:
transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a
transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b
elif dtype_bits == 8: elif dtype_bits == 8:
transform_func_sr = shared_16x32_to_mma_32x16_layout transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a
transform_func_rs = shared_32x16_to_mma_32x16_layout transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b
else: else:
raise ValueError(f"Unsupported dtype {dtype}") raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False] is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed) is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed) is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions) is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
micro_size_s, _, micro_size_r = get_mma_micro_size(dtype) # the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
transform_func = transform_func
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
...@@ -81,7 +94,7 @@ def make_mma_load_base_layout(dtype: str = "float16", ...@@ -81,7 +94,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
return local_id return local_id
base_fragment = T.Fragment( base_fragment = T.Fragment(
[micro_size_r, micro_size_s], [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread, forward_thread_fn=forward_thread,
forward_index_fn=forward_index, forward_index_fn=forward_index,
) )
...@@ -109,4 +122,4 @@ plot_layout(warp_layout, name="warp_layout") ...@@ -109,4 +122,4 @@ plot_layout(warp_layout, name="warp_layout")
# block layout 128x32 # block layout 128x32
block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False) block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
print(block_layout) print(block_layout)
# plot_layout(block_layout, name="block_layout") plot_layout(block_layout, name="block_layout")
...@@ -490,7 +490,6 @@ class TensorCoreIntrinEmitter(object): ...@@ -490,7 +490,6 @@ class TensorCoreIntrinEmitter(object):
transform_func_sr_a: Callable = None transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None transform_func_sr_b: Callable = None
if dtype_bits == 32: if dtype_bits == 32:
...
transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
elif dtype_bits == 16: elif dtype_bits == 16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment