Commit fa511857 authored by Lei Wang, committed by GitHub
Browse files

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
...@@ -201,7 +201,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -201,7 +201,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including: In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:
- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilzing magic layout transformation and intrins to accelerate dequantize gemm. - [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilizing magic layout transformation and intrins to accelerate dequantize gemm.
- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax, and we also provide an example of auto tuning. - [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax, and we also provide an example of auto tuning.
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations. - [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col. - [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col.
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
- **Python Version**: >= 3.8 - **Python Version**: >= 3.8
- **CUDA Version**: >= 11.0 - **CUDA Version**: >= 11.0
The easiest way to install TileLang is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal. The easiest way to install TileLang is directly from the PyPi using pip. To install the latest version, run the following command in your terminal.
**Note**: Currently, TileLang whl is only supported on Ubuntu 20.04 or later version as we build the whl files on this platform. Currently we only provide whl files for CUDA>=11.0 and with Python>=3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/microsoft/TileLang/blob/main/docs/Installation.md#building-from-source).** **Note**: Currently, TileLang whl is only supported on Ubuntu 20.04 or later version as we build the whl files on this platform. Currently we only provide whl files for CUDA>=11.0 and with Python>=3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/microsoft/TileLang/blob/main/docs/Installation.md#building-from-source).**
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
import argparse import argparse
from functools import partial from functools import partial
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4 assert nbit == 4
assert dtype == "float16" assert dtype == "float16"
...@@ -24,12 +25,15 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: ...@@ -24,12 +25,15 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
s = f4 >> tir.const(3, "uint16") s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16") e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16") e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret("float16", val_f16 = tir.reinterpret(
"float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16")) ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16 return val_f16
def torch_convert(tensor): def torch_convert(tensor):
def print_bit(name, val): def print_bit(name, val):
val_cpu = val.cpu().item() val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}' binary_repr = f'{val_cpu:032b}'
...@@ -55,6 +59,7 @@ def torch_convert(tensor): ...@@ -55,6 +59,7 @@ def torch_convert(tensor):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor return new_tensor
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8" storage_dtype = "uint8"
...@@ -72,10 +77,7 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): ...@@ -72,10 +77,7 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.ceildiv(K, block_K),
num_stages=1
):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local) T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K): for i, j in T.Parallel(block_N, block_K):
...@@ -89,6 +91,7 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): ...@@ -89,6 +91,7 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
return main return main
def test_fp4_fp16_convert_close(): def test_fp4_fp16_convert_close():
N, K = 256, 256 N, K = 256, 256
block_N, block_K = 64, 64 block_N, block_K = 64, 64
...@@ -109,6 +112,7 @@ def test_fp4_fp16_convert_close(): ...@@ -109,6 +112,7 @@ def test_fp4_fp16_convert_close():
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out) assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass") print("Pass")
def get_configs(): def get_configs():
block_M = [128] block_M = [128]
block_N = [128, 256] block_N = [128, 256]
...@@ -118,13 +122,19 @@ def get_configs(): ...@@ -118,13 +122,19 @@ def get_configs():
splits = [1] splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [ configs = [{
{'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'threads': c[4], 'split': c[5]} 'block_M': c[0],
for c in _configs 'block_N': c[1],
] 'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
return configs return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8" storage_dtype = "uint8"
...@@ -142,10 +152,13 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -142,10 +152,13 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
B: T.Buffer(B_shape, storage_dtype), B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype), Ct: T.Buffer((N, M), out_dtype),
): ):
SplitC = T.alloc_buffer( SplitC = T.alloc_buffer([
[split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype split, (N + block_N - 1) // block_N * block_N,
) (M + block_M - 1) // block_M * block_M
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz): ], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
...@@ -154,12 +167,10 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -154,12 +167,10 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout( T.annotate_layout({
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
} })
)
T.clear(Ct_local) T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
...@@ -175,7 +186,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -175,7 +186,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
) )
T.copy(B_dequantize_local, B_dequantize_prev_local) T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype) acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc) T.clear(acc)
...@@ -190,7 +202,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -190,7 +202,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
B: T.Buffer(B_shape, storage_dtype), B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype), Ct: T.Buffer((N, M), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
...@@ -199,12 +212,10 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -199,12 +212,10 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout( T.annotate_layout({
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
} })
)
T.clear(Ct_local) T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages): for k in T.Pipelined(K // block_K, num_stages=num_stages):
...@@ -221,7 +232,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -221,7 +232,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
T.copy(B_dequantize_local, B_dequantize_prev_local) T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared) T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M]) T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
if split == 1: if split == 1:
return main return main
...@@ -229,23 +241,34 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -229,23 +241,34 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
return main_split return main_split
if tune: if tune:
@autotune( @autotune(
configs=get_configs(), configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"], keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10, warmup=10,
rep=10 rep=10)
) @jit(
@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None, profiler="auto") out_idx=[2],
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None): supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split) return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel() return kernel()
else: else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1): def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split) return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel return kernel
def ref_program(A, qB): def ref_program(A, qB):
dtypeC = "float16" dtypeC = "float16"
B = torch_convert(qB) B = torch_convert(qB)
...@@ -253,6 +276,7 @@ def ref_program(A, qB): ...@@ -253,6 +276,7 @@ def ref_program(A, qB):
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1) return C.transpose(0, 1)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M') parser.add_argument('--m', type=int, default=256, help='M')
...@@ -264,7 +288,9 @@ if __name__ == "__main__": ...@@ -264,7 +288,9 @@ if __name__ == "__main__":
total_flops = 2 * M * N * K total_flops = 2 * M * N * K
if (not args.tune): if (not args.tune):
program = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) program = matmul(
M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
mod, params = tilelang.lower(program) mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
...@@ -276,7 +302,8 @@ if __name__ == "__main__": ...@@ -276,7 +302,8 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_latency, best_config, ref_latency = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune) best_latency, best_config, ref_latency = matmul(
M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
...@@ -6,18 +6,15 @@ from tilelang import Profiler ...@@ -6,18 +6,15 @@ from tilelang import Profiler
import tilelang.language as T import tilelang.language as T
def matmul( def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"
):
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer((M, K), dtype), A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype), B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype), C: T.Buffer((M, N), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128
) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......
...@@ -9,8 +9,7 @@ import tilelang as TL ...@@ -9,8 +9,7 @@ import tilelang as TL
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter, TensorCoreIntrinEmitter,)
)
from tilelang.transform import simplify_prim_func from tilelang.transform import simplify_prim_func
torch.manual_seed(0) torch.manual_seed(0)
...@@ -180,6 +179,7 @@ def tl_matmul( ...@@ -180,6 +179,7 @@ def tl_matmul(
return main return main
M, N, K = 128, 128, 128 M, N, K = 128, 128, 128
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float16" in_dtype, out_dtype, accum_dtype = "float16", "float16", "float16"
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
......
...@@ -61,7 +61,6 @@ Fragment makeGemmFragmentC16x16CDNA() { ...@@ -61,7 +61,6 @@ Fragment makeGemmFragmentC16x16CDNA() {
return Fragment({i, j}, {index}, forward_thread, rep); return Fragment({i, j}, {index}, forward_thread, rep);
} }
Fragment makeGemmFragment8x8Transposed() { Fragment makeGemmFragment8x8Transposed() {
IterVar i = make_itervar("i", 8); IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 8); IterVar j = make_itervar("j", 8);
...@@ -80,57 +79,70 @@ Fragment makeGemmFragment8x16() { ...@@ -80,57 +79,70 @@ Fragment makeGemmFragment8x16() {
return Fragment({i, j}, {index}, forward_thread, rep); return Fragment({i, j}, {index}, forward_thread, rep);
} }
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n, const int warp_m, Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
const int warp_n) { const int warp_m, const int warp_n) {
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0); ICHECK(warp_m % 16 == 0);
ICHECK(warp_n % 16 == 0); ICHECK(warp_n % 16 == 0);
auto base_layout = makeGemmFragment8x8(); auto base_layout = makeGemmFragment8x8();
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); auto warp_layout =
auto block_layout = warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false); base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout =
warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false);
return block_layout; return block_layout;
} }
Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n, Fragment makeGemmFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) { const int element_size) {
if (element_size == 64) return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n); if (element_size == 64)
return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n);
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n; ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false); auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); auto warp_layout =
auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false); base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout =
warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false);
return block_layout; return block_layout;
} }
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n, Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) { const int element_size) {
if (element_size == 64) LOG(FATAL) << "Not supported"; if (element_size == 64)
LOG(FATAL) << "Not supported";
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n; ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false); auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true); auto warp_layout =
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
return block_layout; return block_layout;
} }
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m, Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_n, const int element_size) { const int warp_m, const int warp_n,
const int element_size) {
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
// ICHECK(block_n == warp_n); // ICHECK(block_n == warp_n);
ICHECK(warp_m % 16 == 0); ICHECK(warp_m % 16 == 0);
auto warp_layout = auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false, false); // 16 x N (1 warp) false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); // 16*Y x N (Y warp) auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
true, false); // 16*Y x N (Y warp)
return block_layout->Repeat({warp_m / 16, 1}, false, false); return block_layout->Repeat({warp_m / 16, 1}, false, false);
} }
Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k, Fragment makeGemmFragmentA(const int block_m, const int block_n,
const int warp_m, const int warp_n, const int element_size) { const int block_k, const int warp_m,
const int warp_n, const int element_size) {
// assume not transposed // assume not transposed
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
...@@ -140,13 +152,17 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block ...@@ -140,13 +152,17 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block
ICHECK(element_size == 8 || element_size == 16); ICHECK(element_size == 8 || element_size == 16);
if (element_size == 8) { if (element_size == 8) {
auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false); auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n); auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false); ->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
return block_layout; return block_layout;
} else if (element_size == 16) { } else if (element_size == 16) {
auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false); auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n); auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false); ->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
return block_layout; return block_layout;
} else { } else {
ICHECK(0); ICHECK(0);
...@@ -154,36 +170,45 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block ...@@ -154,36 +170,45 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block
} }
} }
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k, Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n, bool transposed) { const int block_k, const int warp_m,
const int warp_n, bool transposed) {
// assume not transposed // assume not transposed
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0); ICHECK(warp_m % 16 == 0);
ICHECK(block_k % 16 == 0); ICHECK(block_k % 16 == 0);
if (transposed) { if (transposed) {
auto base_layout = makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false); auto base_layout =
auto warp_layout = base_layout->Repeat({warp_m / 16, block_k / 16}, false, false); makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)->Replicate(block_n / warp_n); auto warp_layout =
base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
->Replicate(block_n / warp_n);
return block_layout; return block_layout;
} else { } else {
auto base_layout = makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false); auto base_layout =
auto warp_layout = base_layout->Repeat({warp_m / 16, block_k / 16}, false, false); makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
auto block_layout = auto warp_layout =
warp_layout->Repeat({block_m / warp_m, 1}, true, true)->Replicate(block_n / warp_n); base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
->Replicate(block_n / warp_n);
return block_layout; return block_layout;
} }
} }
Fragment makeGemmFragmentB(const int block_m, const int block_n,
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k, const int block_k, const int warp_m,
const int warp_m, const int warp_n) { const int warp_n) {
// transposed // transposed
ICHECK(warp_n % 8 == 0); ICHECK(warp_n % 8 == 0);
ICHECK(block_k % 16 == 0); ICHECK(block_k % 16 == 0);
auto base_layout = makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false); auto base_layout =
auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true); makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
auto block_layout = warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true); auto warp_layout = base_layout->Replicate(block_m / warp_m)
->Repeat({1, block_n / warp_n}, true);
auto block_layout =
warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
return block_layout; return block_layout;
} }
...@@ -195,33 +220,39 @@ Fragment makeGemmFragment32x32(int element_size) { ...@@ -195,33 +220,39 @@ Fragment makeGemmFragment32x32(int element_size) {
if (element_size == 16) { if (element_size == 16) {
PrimExpr thd = FloorMod(i, 4) + FloorDiv(FloorMod(i, 16), 8) * 4 + PrimExpr thd = FloorMod(i, 4) + FloorDiv(FloorMod(i, 16), 8) * 4 +
FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16; FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 + PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 +
FloorDiv(FloorMod(i, 8), 4) * 8 +
FloorDiv(FloorMod(j, 8), 4) * 16; FloorDiv(FloorMod(j, 8), 4) * 16;
return Fragment({i, j}, {idx}, thd, rep); return Fragment({i, j}, {idx}, thd, rep);
} else { } else {
PrimExpr thd = FloorMod(i, 2) + 2 * FloorDiv(FloorMod(j, 4), 2) + PrimExpr thd = FloorMod(i, 2) + 2 * FloorDiv(FloorMod(j, 4), 2) +
FloorDiv(FloorMod(i, 16), 8) * 4 + FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(FloorMod(i, 16), 8) * 4 +
FloorDiv(i, 16) * 16; FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) + FloorDiv(j, 16) * 4 + PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) +
FloorDiv(FloorMod(i, 8), 4) * 8 + FloorDiv(FloorMod(j, 8), 4) * 16; FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 +
FloorDiv(FloorMod(j, 8), 4) * 16;
return Fragment({i, j}, {idx}, thd, rep); return Fragment({i, j}, {idx}, thd, rep);
} }
} }
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m, Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_n, int element_size) { const int warp_m, const int warp_n,
int element_size) {
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 32 == 0); ICHECK(warp_m % 32 == 0);
ICHECK(warp_n % 32 == 0); ICHECK(warp_n % 32 == 0);
auto base_layout = makeGemmFragment32x32(element_size); auto base_layout = makeGemmFragment32x32(element_size);
auto warp_layout = base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false); auto warp_layout =
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true); base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true);
return block_layout; return block_layout;
} }
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k, Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
const int warp_m, const int warp_n) { const int block_k, const int warp_m,
const int warp_n) {
// assume not transposed // assume not transposed
ICHECK(block_m % warp_m == 0); ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0); ICHECK(block_n % warp_n == 0);
...@@ -231,17 +262,22 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int ...@@ -231,17 +262,22 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int
IterVar i = make_itervar("i", 32); IterVar i = make_itervar("i", 32);
IterVar j = make_itervar("j", 4); IterVar j = make_itervar("j", 4);
IterVar rep = make_itervar("rep", 2); IterVar rep = make_itervar("rep", 2);
PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) + FloorMod(i, 4) + 8 * rep; PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) +
FloorMod(i, 4) + 8 * rep;
PrimExpr idx = j + FloorDiv(FloorMod(i, 8), 4) * 4; PrimExpr idx = j + FloorDiv(FloorMod(i, 8), 4) * 4;
Fragment base_layout = Fragment({i, j}, {idx}, thd, rep); Fragment base_layout = Fragment({i, j}, {idx}, thd, rep);
auto warp_layout = base_layout->Repeat({warp_m / 32, block_k / 4}, false, false); auto warp_layout =
auto block_layout = warp_layout->Replicate(block_n / warp_n)->Repeat({block_m / warp_m, 1}, true); base_layout->Repeat({warp_m / 32, block_k / 4}, false, false);
auto block_layout = warp_layout->Replicate(block_n / warp_n)
->Repeat({block_m / warp_m, 1}, true);
return block_layout; return block_layout;
} }
PrimExpr xor2x2(const PrimExpr& i, const PrimExpr& j) { return FloorMod(i + j, 2); } PrimExpr xor2x2(const PrimExpr &i, const PrimExpr &j) {
return FloorMod(i + j, 2);
}
PrimExpr xor4x4(const PrimExpr& i, const PrimExpr& j) { PrimExpr xor4x4(const PrimExpr &i, const PrimExpr &j) {
PrimExpr i0 = FloorMod(i, 2); PrimExpr i0 = FloorMod(i, 2);
PrimExpr j0 = FloorMod(j, 2); PrimExpr j0 = FloorMod(j, 2);
PrimExpr i1 = FloorDiv(i, 2); PrimExpr i1 = FloorDiv(i, 2);
...@@ -249,7 +285,7 @@ PrimExpr xor4x4(const PrimExpr& i, const PrimExpr& j) { ...@@ -249,7 +285,7 @@ PrimExpr xor4x4(const PrimExpr& i, const PrimExpr& j) {
return 2 * xor2x2(i1, j1) + xor2x2(i0, j0); return 2 * xor2x2(i1, j1) + xor2x2(i0, j0);
} }
PrimExpr xor8x8(const PrimExpr& i, const PrimExpr j) { PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
PrimExpr i0 = FloorMod(i, 2); PrimExpr i0 = FloorMod(i, 2);
PrimExpr j0 = FloorMod(j, 2); PrimExpr j0 = FloorMod(j, 2);
PrimExpr i1 = FloorDiv(i, 2); PrimExpr i1 = FloorDiv(i, 2);
...@@ -291,8 +327,10 @@ Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) { ...@@ -291,8 +327,10 @@ Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index}); return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
} }
// Detail implementation please ref to bitblas::tl::mfma_layout::make_mfma_swizzle_layout // Detail implementation please ref to
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, int kPack=1) { // bitblas::tl::mfma_layout::make_mfma_swizzle_layout
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size,
int kPack = 1) {
const int numBanks = 32; const int numBanks = 32;
const int bankBitWidth = 32; const int bankBitWidth = 32;
const int SIMDWidth = 16; const int SIMDWidth = 16;
...@@ -352,7 +390,8 @@ Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) { ...@@ -352,7 +390,8 @@ Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) {
IterVar j = make_itervar("j", continuous); IterVar j = make_itervar("j", continuous);
int padded = continuous; int padded = continuous;
// Add 128 bits padding when the last dim is a multiple of 256 bits // Add 128 bits padding when the last dim is a multiple of 256 bits
if ((element_size * continuous) % 256 == 0) padded += 128 / element_size; if ((element_size * continuous) % 256 == 0)
padded += 128 / element_size;
return Layout(Array{i, j}, {i * padded + j}); return Layout(Array{i, j}, {i * padded + j});
} }
...@@ -363,14 +402,17 @@ Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) { ...@@ -363,14 +402,17 @@ Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) {
PrimExpr vec_contiguous_idx = FloorDiv(j, 4); PrimExpr vec_contiguous_idx = FloorDiv(j, 4);
PrimExpr vec_strided_within_tile = FloorMod(vec_contiguous_idx, 8); PrimExpr vec_strided_within_tile = FloorMod(vec_contiguous_idx, 8);
PrimExpr bit2 = FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) + PrimExpr bit2 =
FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) +
FloorDiv(vec_strided_within_tile, 4), FloorDiv(vec_strided_within_tile, 4),
2); 2);
PrimExpr bit1 = PrimExpr bit1 = xor2x2(FloorDiv(FloorMod(i, 8), 4),
xor2x2(FloorDiv(FloorMod(i, 8), 4), FloorDiv(FloorMod(vec_strided_within_tile, 4), 2)); FloorDiv(FloorMod(vec_strided_within_tile, 4), 2));
PrimExpr permuted_vec_contiguous = FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1; PrimExpr permuted_vec_contiguous =
FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1;
PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 + vec_contiguous_idx * stride * 4; PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 +
vec_contiguous_idx * stride * 4;
return Layout(Array{i, j}, {offset}); return Layout(Array{i, j}, {offset});
} }
...@@ -390,9 +432,11 @@ Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) { ...@@ -390,9 +432,11 @@ Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) {
FloorMod(tile_contiguous_residual, 2) * 4 + FloorMod(tile_contiguous_residual, 2) * 4 +
xor4x4(tile_strided_residual, permuted_strided_within_tile); xor4x4(tile_strided_residual, permuted_strided_within_tile);
PrimExpr element_strided = permuted_strided_within_tile + tile_strided_idx * 4; PrimExpr element_strided =
permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_contiguous = PrimExpr element_contiguous =
FloorMod(j, 8) + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8; FloorMod(j, 8) +
(permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
PrimExpr offset = element_strided * continuous + element_contiguous; PrimExpr offset = element_strided * continuous + element_contiguous;
return Layout(Array{i, j}, {offset}); return Layout(Array{i, j}, {offset});
} }
...@@ -413,21 +457,28 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) { ...@@ -413,21 +457,28 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
FloorDiv(tile_contiguous_residual, 4) * 4 + FloorDiv(tile_contiguous_residual, 4) * 4 +
xor4x4(tile_strided_residual, permuted_strided_within_tile); xor4x4(tile_strided_residual, permuted_strided_within_tile);
PrimExpr element_strided = permuted_strided_within_tile + tile_strided_idx * 4; PrimExpr element_strided =
permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_contiguous = PrimExpr element_contiguous =
FloorMod(j, 8) + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8; FloorMod(j, 8) +
(permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
PrimExpr offset = element_strided * continuous + element_contiguous; PrimExpr offset = element_strided * continuous + element_contiguous;
return Layout(Array{i, j}, {offset}); return Layout(Array{i, j}, {offset});
} }
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, int kfactor) { Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
if (kfactor == 2) return MakeGemmVoltaABLayoutCrosswise(stride, continuous); int kfactor) {
if (is_a && continuous % 64 == 0) return MakeGemmVoltaALayoutCongruous(stride, continuous); if (kfactor == 2)
if (!is_a && continuous % 64 == 0) return MakeGemmVoltaBLayoutCongruous(stride, continuous); return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
if (!is_a && continuous % 64 == 0)
return MakeGemmVoltaBLayoutCongruous(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, 16); return makeGemmABLayoutPadded(stride, continuous, 16);
} }
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor) { Layout makeGemmABLayout(int stride, int continuous, int element_size,
int kfactor) {
if (element_size == 64) { if (element_size == 64) {
if (kfactor == 1 && continuous % 16 == 0) // float64 KxN if (kfactor == 1 && continuous % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(stride, continuous); return makeGemmABLayoutF64_Kouter(stride, continuous);
...@@ -447,7 +498,8 @@ Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfacto ...@@ -447,7 +498,8 @@ Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfacto
} }
} }
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kPack) { Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kPack) {
int vector_size = 128 / element_size; int vector_size = 128 / element_size;
if (continuous % (vector_size * 4) == 0) if (continuous % (vector_size * 4) == 0)
return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack); return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
......
...@@ -20,7 +20,7 @@ namespace tl { ...@@ -20,7 +20,7 @@ namespace tl {
using namespace tir; using namespace tir;
static Var getPlaceholder(const std::string& s) { static Var getPlaceholder(const std::string &s) {
static std::unordered_map<std::string, Var> map; static std::unordered_map<std::string, Var> map;
if (map.find(s) == map.end()) { if (map.find(s) == map.end()) {
map[s] = Var(s); map[s] = Var(s);
...@@ -29,7 +29,9 @@ static Var getPlaceholder(const std::string& s) { ...@@ -29,7 +29,9 @@ static Var getPlaceholder(const std::string& s) {
} }
Var ReplicationPlaceholder() { return getPlaceholder("_rep"); } Var ReplicationPlaceholder() { return getPlaceholder("_rep"); }
Var InputPlaceholder(size_t idx) { return getPlaceholder(std::string{'_', char('i' + idx)}); } Var InputPlaceholder(size_t idx) {
return getPlaceholder(std::string{'_', char('i' + idx)});
}
Map<Var, Range> LayoutNode::getVarMap() const { Map<Var, Range> LayoutNode::getVarMap() const {
Map<Var, Range> map; Map<Var, Range> map;
...@@ -45,11 +47,13 @@ Map<Var, Range> FragmentNode::getVarMap() const { ...@@ -45,11 +47,13 @@ Map<Var, Range> FragmentNode::getVarMap() const {
return map; return map;
} }
LayoutNode::LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) { LayoutNode::LayoutNode(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index) {
input_size_ = input_size; input_size_ = input_size;
arith::Analyzer analyzer; arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer); UpdateAnalyzer(&analyzer);
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); }); forward_index_ = forward_index.Map(
[&](const PrimExpr &e) { return analyzer.Simplify(e); });
} }
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) { Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
...@@ -60,7 +64,8 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) { ...@@ -60,7 +64,8 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
CHECK(is_zero(forward_var[i]->dom->min)); CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent); input_size.push_back(forward_var[i]->dom->extent);
} }
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); }); forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<LayoutNode>(input_size, forward_index); auto n = make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n); data_ = std::move(n);
...@@ -71,13 +76,13 @@ Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) { ...@@ -71,13 +76,13 @@ Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
data_ = std::move(n); data_ = std::move(n);
} }
void LayoutNode::VisitAttrs(AttrVisitor* v) { void LayoutNode::VisitAttrs(AttrVisitor *v) {
v->Visit("input_size", &input_size_); v->Visit("input_size", &input_size_);
v->Visit("forward_index", &forward_index_); v->Visit("forward_index", &forward_index_);
} }
void LayoutNode::UpdateAnalyzer(arith::Analyzer* analyzer) const { void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
for (const auto& [var, dom] : getVarMap()) { for (const auto &[var, dom] : getVarMap()) {
analyzer->Bind(var, dom); analyzer->Bind(var, dom);
} }
} }
...@@ -99,66 +104,74 @@ Array<PrimExpr> LayoutNode::OutputShape() const { ...@@ -99,66 +104,74 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
return ret; return ret;
} }
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr>& vars) const { Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
if (vars.empty()) return forward_index_; if (vars.empty())
return forward_index_;
ICHECK_EQ(vars.size(), InputDim()); ICHECK_EQ(vars.size(), InputDim());
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) { for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]); vmap.Set(InputPlaceholder(i), vars[i]);
} }
return forward_index_.Map([&](const PrimExpr& e) { return Substitute(e, vmap); }); return forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
} }
Fragment FragmentNode::Repeat(const Array<PrimExpr>& repeats, bool repeat_on_thread, Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
bool repeat_on_thread,
bool lower_dim_first) const { bool lower_dim_first) const {
ICHECK_EQ(repeats.size(), InputDim()); ICHECK_EQ(repeats.size(), InputDim());
Array<PrimExpr> new_input_size; Array<PrimExpr> new_input_size;
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) { for (size_t i = 0; i < InputDim(); i++) {
new_input_size.push_back(input_size_[i] * repeats[i]); new_input_size.push_back(input_size_[i] * repeats[i]);
vmap.Set(InputPlaceholder(i), FloorMod(InputPlaceholder(i), InputShape()[i])); vmap.Set(InputPlaceholder(i),
FloorMod(InputPlaceholder(i), InputShape()[i]));
} }
PrimExpr repeats_index = 0, repeat_stride = 1; PrimExpr repeats_index = 0, repeat_stride = 1;
if (lower_dim_first) { if (lower_dim_first) {
for (int i = InputDim() - 1; i >= 0; i--) { for (int i = InputDim() - 1; i >= 0; i--) {
repeats_index += repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]); repeats_index +=
repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeat_stride *= repeats[i]; repeat_stride *= repeats[i];
} }
} else { } else {
for (size_t i = 0; i < InputDim(); i++) { for (size_t i = 0; i < InputDim(); i++) {
repeats_index += repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]); repeats_index +=
repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeat_stride *= repeats[i]; repeat_stride *= repeats[i];
} }
} }
if (repeat_on_thread) { if (repeat_on_thread) {
PrimExpr thread_size = ThreadExtent(); PrimExpr thread_size = ThreadExtent();
auto new_forward_index = auto new_forward_index = forward_index_.Map(
forward_index_.Map([&](const PrimExpr& e) { return Substitute(e, vmap); }); [&](const PrimExpr &e) { return Substitute(e, vmap); });
auto new_forward_thread = Substitute(forward_thread_, vmap) + thread_size * repeats_index; auto new_forward_thread =
return Fragment(new_input_size, new_forward_index, new_forward_thread, replicate_size_, Substitute(forward_thread_, vmap) + thread_size * repeats_index;
NullOpt); return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt);
} else { } else {
ICHECK(OutputDim() == 1); ICHECK(OutputDim() == 1);
PrimExpr frag_len = OutputShape()[0]; PrimExpr frag_len = OutputShape()[0];
Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) + Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) +
frag_len * repeats_index}; frag_len * repeats_index};
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap); PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
return Fragment(new_input_size, new_forward_index, new_forward_thread, replicate_size_, return Fragment(new_input_size, new_forward_index, new_forward_thread,
NullOpt); replicate_size_, NullOpt);
} }
} }
Fragment FragmentNode::Replicate(int repeats) const { Fragment FragmentNode::Replicate(int repeats) const {
ICHECK(repeats >= 1); ICHECK(repeats >= 1);
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), FloorMod(ReplicationPlaceholder(), ReplicateExtent())); vmap.Set(ReplicationPlaceholder(),
FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
PrimExpr new_forward_thread = PrimExpr new_forward_thread =
Substitute(forward_thread_, vmap) + Substitute(forward_thread_, vmap) +
ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent()); ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
return Fragment(input_size_, forward_index_, new_forward_thread, ReplicateExtent() * repeats, return Fragment(input_size_, forward_index_, new_forward_thread,
NullOpt); ReplicateExtent() * repeats, NullOpt);
} }
Fragment FragmentNode::DeReplicate() const { Fragment FragmentNode::DeReplicate() const {
...@@ -171,20 +184,22 @@ Fragment FragmentNode::DeReplicate() const { ...@@ -171,20 +184,22 @@ Fragment FragmentNode::DeReplicate() const {
if (rep_size && idx_size) { if (rep_size && idx_size) {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size); factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
} }
if (factor == 1) return GetRef<Fragment>(this); if (factor == 1)
return GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
ReplicationPlaceholder() * factor + FloorMod(forward_index_[0], factor)); FloorMod(forward_index_[0], factor));
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap); PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)}; Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
return Fragment(input_size_, new_forward_index, new_forward_thread, int(*rep_size) / factor, return Fragment(input_size_, new_forward_index, new_forward_thread,
NullOpt); int(*rep_size) / factor, NullOpt);
} }
Layout LayoutNode::Inverse() const { Layout LayoutNode::Inverse() const {
arith::Analyzer analyzer; arith::Analyzer analyzer;
arith::IterMapResult res = arith::DetectIterMap(forward_index_, getVarMap(), 1, arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1,
arith::IterMapLevel::Bijective, &analyzer); arith::IterMapLevel::Bijective, &analyzer);
ICHECK(res->errors.empty()) << res->errors; ICHECK(res->errors.empty()) << res->errors;
...@@ -208,21 +223,25 @@ Layout LayoutNode::Inverse() const { ...@@ -208,21 +223,25 @@ Layout LayoutNode::Inverse() const {
return Layout(outputs_shape, backward_index); return Layout(outputs_shape, backward_index);
} }
PrimExpr infer_fragment_index(const Map<Var, Range>& input_iters, const PrimExpr& forward_thread, PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
arith::Analyzer* analyzer) { const PrimExpr &forward_thread,
Array<arith::IterSplitExpr> splits = arith::Analyzer *analyzer) {
DivideUnusedIterators({forward_thread}, ToIterVars(input_iters), analyzer); Array<arith::IterSplitExpr> splits = DivideUnusedIterators(
{forward_thread}, ToIterVars(input_iters), analyzer);
Array<arith::IterSplitExpr> split_without_rep; Array<arith::IterSplitExpr> split_without_rep;
for (const auto& split : splits) { for (const auto &split : splits) {
CHECK(split->source->source.as<Var>()); CHECK(split->source->source.as<Var>());
if (split->source->source.as<Var>().value().same_as(ReplicationPlaceholder())) continue; if (split->source->source.as<Var>().value().same_as(
ReplicationPlaceholder()))
continue;
split_without_rep.push_back(split); split_without_rep.push_back(split);
} }
return MakeFlattenedExpression(split_without_rep); return MakeFlattenedExpression(split_without_rep);
} }
FragmentNode::FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, FragmentNode::FragmentNode(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size) { PrimExpr forward_thread, PrimExpr replicate_size) {
input_size_ = input_size; input_size_ = input_size;
replicate_size_ = replicate_size; replicate_size_ = replicate_size;
...@@ -230,9 +249,11 @@ FragmentNode::FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_i ...@@ -230,9 +249,11 @@ FragmentNode::FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_i
UpdateAnalyzer(&analyzer); UpdateAnalyzer(&analyzer);
forward_thread_ = analyzer.Simplify(forward_thread); forward_thread_ = analyzer.Simplify(forward_thread);
if (forward_index.empty()) { if (forward_index.empty()) {
forward_index = {infer_fragment_index(getVarMap(), forward_thread_, &analyzer)}; forward_index = {
infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
} }
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); }); forward_index_ = forward_index.Map(
[&](const PrimExpr &e) { return analyzer.Simplify(e); });
} }
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index, Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
...@@ -250,24 +271,28 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index, ...@@ -250,24 +271,28 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
replicate_size = thread_replicate->dom->extent; replicate_size = thread_replicate->dom->extent;
vmap.Set(thread_replicate->var, ReplicationPlaceholder()); vmap.Set(thread_replicate->var, ReplicationPlaceholder());
} }
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); }); forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap); forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, replicate_size); auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
data_ = std::move(n); data_ = std::move(n);
} }
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size, Optional<Var> replicate_var) { PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var) {
if (replicate_var.defined()) { if (replicate_var.defined()) {
forward_thread = forward_thread = Substitute(
Substitute(forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}}); forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
} }
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, replicate_size); auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
data_ = std::move(n); data_ = std::move(n);
} }
void FragmentNode::VisitAttrs(tvm::AttrVisitor* v) { void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v); LayoutNode::VisitAttrs(v);
v->Visit("forward_thread", &forward_thread_); v->Visit("forward_thread", &forward_thread_);
v->Visit("replicate_size", &replicate_size_); v->Visit("replicate_size", &replicate_size_);
...@@ -282,14 +307,15 @@ PrimExpr FragmentNode::ThreadExtent() const { ...@@ -282,14 +307,15 @@ PrimExpr FragmentNode::ThreadExtent() const {
return ist.max(); return ist.max();
} }
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr>& vars, PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
const Optional<PrimExpr>& rep_var) const { const Optional<PrimExpr> &rep_var) const {
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
ICHECK_EQ(vars.size(), InputDim()); ICHECK_EQ(vars.size(), InputDim());
for (size_t i = 0; i < InputDim(); i++) { for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]); vmap.Set(InputPlaceholder(i), vars[i]);
} }
if (rep_var.defined()) vmap.Set(ReplicationPlaceholder(), rep_var.value()); if (rep_var.defined())
vmap.Set(ReplicationPlaceholder(), rep_var.value());
return Substitute(forward_thread_, vmap); return Substitute(forward_thread_, vmap);
} }
...@@ -299,7 +325,8 @@ Layout FragmentNode::Inverse() const { ...@@ -299,7 +325,8 @@ Layout FragmentNode::Inverse() const {
input_size_copy.push_back(ReplicateExtent()); input_size_copy.push_back(ReplicateExtent());
auto forward_index_copy = forward_index_; auto forward_index_copy = forward_index_;
forward_index_copy.push_back( forward_index_copy.push_back(
Substitute(forward_thread_, {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}})); Substitute(forward_thread_,
{{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
auto fwd = Layout(input_size_copy, forward_index_copy); auto fwd = Layout(input_size_copy, forward_index_copy);
auto bwd = fwd->Inverse(); auto bwd = fwd->Inverse();
return bwd; return bwd;
...@@ -311,8 +338,9 @@ Fragment FragmentNode::CondenseReplicateVar() const { ...@@ -311,8 +338,9 @@ Fragment FragmentNode::CondenseReplicateVar() const {
input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()}); input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
PrimExpr new_forward_thread; PrimExpr new_forward_thread;
IterVar new_thread_replicate; IterVar new_thread_replicate;
std::tie(new_forward_thread, new_thread_replicate) = CompressIterator( std::tie(new_forward_thread, new_thread_replicate) =
forward_thread_, ToIterVars(input_iters), ReplicationPlaceholder(), &analyzer); CompressIterator(forward_thread_, ToIterVars(input_iters),
ReplicationPlaceholder(), &analyzer);
return Fragment(input_size_, forward_index_, new_forward_thread, return Fragment(input_size_, forward_index_, new_forward_thread,
new_thread_replicate->dom->extent, new_thread_replicate->var); new_thread_replicate->dom->extent, new_thread_replicate->var);
} }
...@@ -330,12 +358,14 @@ void FragmentNode::DebugOutput() const { ...@@ -330,12 +358,14 @@ void FragmentNode::DebugOutput() const {
LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_; LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_;
} }
bool LayoutNode::SEqualReduce(const LayoutNode* other, SEqualReducer equal) const { bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) && return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_); equal(this->forward_index_, other->forward_index_);
} }
bool FragmentNode::SEqualReduce(const FragmentNode* other, SEqualReducer equal) const { bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) && return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) && equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) && equal(this->ThreadExtent(), other->ThreadExtent()) &&
...@@ -346,7 +376,7 @@ bool FragmentNode::SEqualReduce(const FragmentNode* other, SEqualReducer equal) ...@@ -346,7 +376,7 @@ bool FragmentNode::SEqualReduce(const FragmentNode* other, SEqualReducer equal)
TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode); TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1])); *ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
}); });
...@@ -366,31 +396,32 @@ TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) { ...@@ -366,31 +396,32 @@ TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
return layout->GetForwardIndex(); return layout->GetForwardIndex();
}); });
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Fragment(args[0], args[1], args[2], args[3]); *ret = Fragment(args[0], args[1], args[2], args[3]);
}); });
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size").set_body_typed([](Fragment fragment) { TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
return fragment->ThreadExtent(); .set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });
});
TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) { TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
return fragment->GetForwardThread(); return fragment->GetForwardThread();
}); });
TVM_REGISTER_GLOBAL("tl.Fragment_repeat") TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
.set_body_typed([](Fragment fragment, Array<PrimExpr> repeats, bool repeat_on_thread, .set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
bool lower_dim_first) { bool repeat_on_thread, bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first); return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
}); });
TVM_REGISTER_GLOBAL("tl.Fragment_replicate").set_body_typed([](Fragment fragment, int repeats) { TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
.set_body_typed([](Fragment fragment, int repeats) {
return fragment->Replicate(repeats); return fragment->Replicate(repeats);
}); });
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var").set_body_typed([](Fragment fragment) { TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
.set_body_typed([](Fragment fragment) {
return fragment->CondenseReplicateVar(); return fragment->CondenseReplicateVar();
}); });
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout") TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
.set_body_typed([](int stride, int continuous, int element_size) { .set_body_typed([](int stride, int continuous, int element_size) {
......
...@@ -20,7 +20,7 @@ class Layout; ...@@ -20,7 +20,7 @@ class Layout;
class Fragment; class Fragment;
class LayoutNode : public Object { class LayoutNode : public Object {
public: public:
LayoutNode() = default; LayoutNode() = default;
LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index); LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
...@@ -34,21 +34,21 @@ class LayoutNode : public Object { ...@@ -34,21 +34,21 @@ class LayoutNode : public Object {
Array<PrimExpr> GetForwardIndex() const { return forward_index_; } Array<PrimExpr> GetForwardIndex() const { return forward_index_; }
virtual Array<PrimExpr> Forward(const Array<PrimExpr>& vars) const; virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
virtual Layout Inverse() const; virtual Layout Inverse() const;
virtual void DebugOutput() const; virtual void DebugOutput() const;
static constexpr bool _type_has_method_sequal_reduce = true; static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "tl.Layout"; static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode* other, SEqualReducer equal) const; bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor* v); void VisitAttrs(tvm::AttrVisitor *v);
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
protected: protected:
virtual Map<Var, Range> getVarMap() const; virtual Map<Var, Range> getVarMap() const;
void UpdateAnalyzer(arith::Analyzer* analyzer) const; void UpdateAnalyzer(arith::Analyzer *analyzer) const;
Array<PrimExpr> forward_index_; Array<PrimExpr> forward_index_;
Array<PrimExpr> input_size_; Array<PrimExpr> input_size_;
}; };
...@@ -57,7 +57,7 @@ class LayoutNode : public Object { ...@@ -57,7 +57,7 @@ class LayoutNode : public Object {
* \brief Layout reference class. * \brief Layout reference class.
*/ */
class Layout : public ObjectRef { class Layout : public ObjectRef {
public: public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index); TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index); TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
...@@ -65,10 +65,10 @@ class Layout : public ObjectRef { ...@@ -65,10 +65,10 @@ class Layout : public ObjectRef {
}; };
class FragmentNode : public LayoutNode { class FragmentNode : public LayoutNode {
public: public:
FragmentNode() = default; FragmentNode() = default;
FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, PrimExpr forward_thread, FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr replicate_size); PrimExpr forward_thread, PrimExpr replicate_size);
PrimExpr GetForwardThread() const { return forward_thread_; } PrimExpr GetForwardThread() const { return forward_thread_; }
...@@ -78,9 +78,10 @@ class FragmentNode : public LayoutNode { ...@@ -78,9 +78,10 @@ class FragmentNode : public LayoutNode {
PrimExpr ReplicateExtent() const { return replicate_size_; }; PrimExpr ReplicateExtent() const { return replicate_size_; };
PrimExpr ForwardThread(const Array<PrimExpr>& vars, const Optional<PrimExpr>& rep_var) const; PrimExpr ForwardThread(const Array<PrimExpr> &vars,
const Optional<PrimExpr> &rep_var) const;
Fragment Repeat(const Array<PrimExpr>& repeats, bool repeat_on_thread, Fragment Repeat(const Array<PrimExpr> &repeats, bool repeat_on_thread,
bool lower_dim_first = true) const; bool lower_dim_first = true) const;
Fragment Replicate(int repeats) const; Fragment Replicate(int repeats) const;
...@@ -91,12 +92,12 @@ class FragmentNode : public LayoutNode { ...@@ -91,12 +92,12 @@ class FragmentNode : public LayoutNode {
void DebugOutput() const final; void DebugOutput() const final;
void VisitAttrs(tvm::AttrVisitor* v); void VisitAttrs(tvm::AttrVisitor *v);
bool SEqualReduce(const FragmentNode* other, SEqualReducer equal) const; bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char* _type_key = "tl.Fragment"; static constexpr const char *_type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode); TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
protected: protected:
Map<Var, Range> getVarMap() const final; Map<Var, Range> getVarMap() const final;
PrimExpr forward_thread_; PrimExpr forward_thread_;
PrimExpr replicate_size_; PrimExpr replicate_size_;
...@@ -106,12 +107,13 @@ class FragmentNode : public LayoutNode { ...@@ -106,12 +107,13 @@ class FragmentNode : public LayoutNode {
* \brief Fragment reference class. * \brief Fragment reference class.
*/ */
class Fragment : public Layout { class Fragment : public Layout {
public: public:
TVM_DLL Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index, TVM_DLL Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
PrimExpr forward_thread, IterVar thread_replicate); PrimExpr forward_thread, IterVar thread_replicate);
TVM_DLL Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, TVM_DLL Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size, Optional<Var> replicate_var); PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode); TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
}; };
...@@ -121,38 +123,49 @@ Var ReplicationPlaceholder(); ...@@ -121,38 +123,49 @@ Var ReplicationPlaceholder();
Fragment makeGemmFragment8x8(); Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed(); Fragment makeGemmFragment8x8Transposed();
Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n, Fragment makeGemmFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size); const int element_size);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n, Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size); const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m, Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size); const int warp_n, const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k, Fragment makeGemmFragmentB(const int block_m, const int block_n,
const int warp_m, const int warp_n, const int element_size); const int block_k, const int warp_m,
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k, const int warp_n);
const int warp_m, const int warp_n);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k, Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n, bool transposed = false); const int block_k, const int warp_m,
const int warp_n, bool transposed = false);
// Default Memory Layout // Default Memory Layout
Layout makeGemmLayoutLinear(int stride, int continuous); Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size); Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor); Layout makeGemmABLayout(int stride, int continuous, int element_size,
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kfactor); int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor);
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m, Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_n, const int element_size); const int warp_m, const int warp_n,
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k, const int element_size);
const int warp_m, const int warp_n); Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, int kfactor); const int block_k, const int warp_m,
const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor);
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
namespace attr { namespace attr {
// BlockAttr, Containing the layout for all the buffers in the block // BlockAttr, Containing the layout for all the buffers in the block
constexpr const char* kLayoutMap = "layout_map"; constexpr const char *kLayoutMap = "layout_map";
} // namespace attr } // namespace attr
} // namespace tl } // namespace tl
......
...@@ -34,20 +34,23 @@ PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const { ...@@ -34,20 +34,23 @@ PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const {
return low + high * base; return low + high * base;
} }
bool SwizzlePattern::operator==(const SwizzlePattern& other) const { bool SwizzlePattern::operator==(const SwizzlePattern &other) const {
return std::tie(base_, bits_, shift_) == std::tie(other.base_, other.bits_, other.shift_); return std::tie(base_, bits_, shift_) ==
std::tie(other.base_, other.bits_, other.shift_);
} }
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) SwizzlePattern pattern)
: pattern_(pattern) { : pattern_(pattern) {
input_size_ = input_size; input_size_ = input_size;
arith::Analyzer analyzer; arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer); UpdateAnalyzer(&analyzer);
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); }); forward_index_ = forward_index.Map(
[&](const PrimExpr &e) { return analyzer.Simplify(e); });
} }
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr>& vars) const { Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr> &vars) const {
auto expr_list = LayoutNode::Forward(vars); auto expr_list = LayoutNode::Forward(vars);
auto expr = expr_list.back(); auto expr = expr_list.back();
expr_list.pop_back(); expr_list.pop_back();
...@@ -57,8 +60,8 @@ Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr>& vars) const { ...@@ -57,8 +60,8 @@ Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr>& vars) const {
void SwizzledLayoutNode::DebugOutput() const { void SwizzledLayoutNode::DebugOutput() const {
LayoutNode::DebugOutput(); LayoutNode::DebugOutput();
std::cout << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " " std::cout << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits()
<< pattern_.Shift(); << " " << pattern_.Shift();
} }
Layout SwizzledLayoutNode::Inverse() const { Layout SwizzledLayoutNode::Inverse() const {
...@@ -66,7 +69,8 @@ Layout SwizzledLayoutNode::Inverse() const { ...@@ -66,7 +69,8 @@ Layout SwizzledLayoutNode::Inverse() const {
return {}; return {};
} }
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forward_index, SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) { SwizzlePattern pattern) {
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
Array<PrimExpr> input_size; Array<PrimExpr> input_size;
...@@ -75,23 +79,29 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forwa ...@@ -75,23 +79,29 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forwa
CHECK(is_zero(forward_var[i]->dom->min)); CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent); input_size.push_back(forward_var[i]->dom->extent);
} }
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); }); forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern); auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
data_ = std::move(n); data_ = std::move(n);
} }
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) { SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern); auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
data_ = std::move(n); data_ = std::move(n);
} }
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor* v) { LayoutNode::VisitAttrs(v); } void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v);
}
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode* other, SEqualReducer equal) const { bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) && return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_) && pattern_ == other->pattern_; equal(this->forward_index_, other->forward_index_) &&
pattern_ == other->pattern_;
} }
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode); TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
......
...@@ -19,16 +19,16 @@ namespace tl { ...@@ -19,16 +19,16 @@ namespace tl {
* \brief Swizzle pattern * \brief Swizzle pattern
*/ */
class SwizzlePattern { class SwizzlePattern {
public: public:
SwizzlePattern() = default; SwizzlePattern() = default;
SwizzlePattern(int bits, int base, int shift); SwizzlePattern(int bits, int base, int shift);
PrimExpr swizzle(PrimExpr expr) const; PrimExpr swizzle(PrimExpr expr) const;
int Bits() const { return bits_; } int Bits() const { return bits_; }
int Base() const { return base_; } int Base() const { return base_; }
int Shift() const { return shift_; } int Shift() const { return shift_; }
bool operator==(const SwizzlePattern& other) const; bool operator==(const SwizzlePattern &other) const;
private: private:
int bits_; int bits_;
int base_; int base_;
int shift_; int shift_;
...@@ -38,21 +38,21 @@ class SwizzlePattern { ...@@ -38,21 +38,21 @@ class SwizzlePattern {
* \brief Layout with swizzle * \brief Layout with swizzle
*/ */
class SwizzledLayoutNode : public LayoutNode { class SwizzledLayoutNode : public LayoutNode {
public: public:
SwizzledLayoutNode() = default; SwizzledLayoutNode() = default;
SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern); SwizzlePattern pattern);
Array<PrimExpr> Forward(const Array<PrimExpr>& vars) const final; Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const final;
Layout Inverse() const final; Layout Inverse() const final;
void DebugOutput() const final; void DebugOutput() const final;
static constexpr const char* _type_key = "tl.SwizzledLayout"; static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode* other, SEqualReducer equal) const; bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor* v); void VisitAttrs(tvm::AttrVisitor *v);
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode); TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
private: private:
SwizzlePattern pattern_; SwizzlePattern pattern_;
}; };
...@@ -60,11 +60,11 @@ class SwizzledLayoutNode : public LayoutNode { ...@@ -60,11 +60,11 @@ class SwizzledLayoutNode : public LayoutNode {
* \brief SwizzledLayout reference class. * \brief SwizzledLayout reference class.
*/ */
class SwizzledLayout : public Layout { class SwizzledLayout : public Layout {
public: public:
TVM_DLL SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forward_index, TVM_DLL SwizzledLayout(Array<IterVar> forward_var,
SwizzlePattern pattern); Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
SwizzlePattern pattern); Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode); TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
}; };
......
...@@ -18,9 +18,9 @@ namespace tl { ...@@ -18,9 +18,9 @@ namespace tl {
using namespace tir; using namespace tir;
using namespace arith; using namespace arith;
bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { bool CanProveDivisible(const PrimExpr &lhs, const PrimExpr &rhs) {
const auto* clhs = lhs.as<IntImmNode>(); const auto *clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>(); const auto *crhs = rhs.as<IntImmNode>();
if (crhs && crhs->value == 0) { if (crhs && crhs->value == 0) {
return false; return false;
} else if (clhs && crhs) { } else if (clhs && crhs) {
...@@ -33,20 +33,22 @@ bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { ...@@ -33,20 +33,22 @@ bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
/*! /*!
* \brief Collector that collects the outgoing split reference of each IterMark. * \brief Collector that collects the outgoing split reference of each IterMark.
* *
* These out-going splits can then be used to check if the iterators are independent. * These out-going splits can then be used to check if the iterators are
* independent.
*/ */
class IterMarkSplitCollector { class IterMarkSplitCollector {
public: public:
// mark all IterMarks that are visited. // mark all IterMarks that are visited.
std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_; std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
// each iter mark to its outgoing splits that are referenced. // each iter mark to its outgoing splits that are referenced.
std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual> std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash,
ObjectPtrEqual>
mark2splits_; mark2splits_;
/*! /*!
* \brief Collect all mark2splits recursively from indices. * \brief Collect all mark2splits recursively from indices.
* \param indices The iterator of interest. * \param indices The iterator of interest.
*/ */
void Collect(const Array<IterSumExpr>& indices) { void Collect(const Array<IterSumExpr> &indices) {
for (IterSumExpr sum_expr : indices) { for (IterSumExpr sum_expr : indices) {
for (IterSplitExpr split : sum_expr->args) { for (IterSplitExpr split : sum_expr->args) {
this->CollectInternal(split->source); this->CollectInternal(split->source);
...@@ -55,10 +57,11 @@ class IterMarkSplitCollector { ...@@ -55,10 +57,11 @@ class IterMarkSplitCollector {
} }
} }
void CollectInternal(const IterMark& mark) { void CollectInternal(const IterMark &mark) {
if (visited_.count(mark)) return; if (visited_.count(mark))
return;
visited_.insert(mark); visited_.insert(mark);
if (auto* op = mark->source.as<IterSumExprNode>()) { if (auto *op = mark->source.as<IterSumExprNode>()) {
for (IterSplitExpr split : op->args) { for (IterSplitExpr split : op->args) {
this->CollectInternal(split->source); this->CollectInternal(split->source);
mark2splits_[split->source].push_back(split); mark2splits_[split->source].push_back(split);
...@@ -67,9 +70,9 @@ class IterMarkSplitCollector { ...@@ -67,9 +70,9 @@ class IterMarkSplitCollector {
} }
}; };
Array<IterSplitExpr> get_unused_iters(const IterMark& mark, Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
const std::vector<IterSplitExpr>& splits, const std::vector<IterSplitExpr> &splits,
Analyzer* analyzer) { Analyzer *analyzer) {
PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
std::vector<bool> used(splits.size(), false); std::vector<bool> used(splits.size(), false);
std::vector<IterSplitExpr> results; std::vector<IterSplitExpr> results;
...@@ -78,19 +81,24 @@ Array<IterSplitExpr> get_unused_iters(const IterMark& mark, ...@@ -78,19 +81,24 @@ Array<IterSplitExpr> get_unused_iters(const IterMark& mark,
size_t j = 0; size_t j = 0;
size_t lowest = splits.size(); size_t lowest = splits.size();
for (; j < splits.size(); ++j) { for (; j < splits.size(); ++j) {
if (used[j]) continue; if (used[j])
if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) { continue;
if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor,
expected_lower_factor)) {
break; break;
} }
if (lowest == splits.size() || if (lowest == splits.size() ||
CanProveDivisible(splits[lowest]->lower_factor, splits[j]->lower_factor)) { CanProveDivisible(splits[lowest]->lower_factor,
splits[j]->lower_factor)) {
lowest = j; lowest = j;
} }
} }
if (j == splits.size()) { if (j == splits.size()) {
ICHECK(lowest != splits.size()); ICHECK(lowest != splits.size());
ICHECK(CanProveDivisible(splits[lowest]->lower_factor, expected_lower_factor)); ICHECK(CanProveDivisible(splits[lowest]->lower_factor,
results.emplace_back(mark, expected_lower_factor, expected_lower_factor));
results.emplace_back(
mark, expected_lower_factor,
FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1); FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1);
expected_lower_factor = splits[lowest]->lower_factor; expected_lower_factor = splits[lowest]->lower_factor;
} else { } else {
...@@ -99,36 +107,40 @@ Array<IterSplitExpr> get_unused_iters(const IterMark& mark, ...@@ -99,36 +107,40 @@ Array<IterSplitExpr> get_unused_iters(const IterMark& mark,
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
} }
} }
bool match_full_iter = analyzer->CanProveEqual(expected_lower_factor, mark->extent); bool match_full_iter =
analyzer->CanProveEqual(expected_lower_factor, mark->extent);
if (!match_full_iter) { if (!match_full_iter) {
results.emplace_back(mark, expected_lower_factor, FloorDiv(mark->extent, expected_lower_factor), results.emplace_back(mark, expected_lower_factor,
1); FloorDiv(mark->extent, expected_lower_factor), 1);
} }
return results; return results;
} }
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs, Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
const Array<IterVar> input_iters, Analyzer* analyzer) { const Array<IterVar> input_iters,
auto iter_sum = exprs.Map( Analyzer *analyzer) {
[&](const auto& e) { return NormalizeToIterSum(e, ToVMap(input_iters), analyzer); }); auto iter_sum = exprs.Map([&](const auto &e) {
return NormalizeToIterSum(e, ToVMap(input_iters), analyzer);
});
IterMarkSplitCollector collector; IterMarkSplitCollector collector;
collector.Collect(iter_sum); collector.Collect(iter_sum);
Array<IterSplitExpr> results; Array<IterSplitExpr> results;
for (const IterMark& mark : collector.visited_) { for (const IterMark &mark : collector.visited_) {
ICHECK(mark->source.as<Var>()) << "Not a normalized iterator: " << mark; ICHECK(mark->source.as<Var>()) << "Not a normalized iterator: " << mark;
} }
for (const IterVar& iter : input_iters) { for (const IterVar &iter : input_iters) {
IterMark iv_mark; IterMark iv_mark;
for (const IterMark& mark : collector.visited_) { for (const IterMark &mark : collector.visited_) {
if (mark->source.as<Var>().same_as(iter->var)) { if (mark->source.as<Var>().same_as(iter->var)) {
iv_mark = mark; iv_mark = mark;
break; break;
} }
} }
if (iv_mark.defined()) { if (iv_mark.defined()) {
auto splits = get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer); auto splits =
get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
// Put the small axis last // Put the small axis last
results.insert(results.end(), splits.rbegin(), splits.rend()); results.insert(results.end(), splits.rbegin(), splits.rend());
} else if (!is_one(iter->dom->extent)) { } else if (!is_one(iter->dom->extent)) {
...@@ -140,12 +152,12 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs, ...@@ -140,12 +152,12 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs,
return results; return results;
} }
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits) { PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr> &splits) {
Array<arith::IterSplitExpr> lists; Array<arith::IterSplitExpr> lists;
PrimExpr scale = 1; PrimExpr scale = 1;
for (int i = splits.size() - 1; i >= 0; i--) { for (int i = splits.size() - 1; i >= 0; i--) {
auto scaled_split = auto scaled_split = arith::IterSplitExpr(
arith::IterSplitExpr(splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale); splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale);
lists.push_back(scaled_split); lists.push_back(scaled_split);
scale *= splits[i]->extent; scale *= splits[i]->extent;
} }
...@@ -153,45 +165,47 @@ PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits) { ...@@ -153,45 +165,47 @@ PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits) {
} }
class IterSumMutator { class IterSumMutator {
public: public:
IterSumMutator(const Map<IterSplitExpr, IterSplitExpr>& replace_map) IterSumMutator(const Map<IterSplitExpr, IterSplitExpr> &replace_map)
: replace_map_(replace_map) {} : replace_map_(replace_map) {}
// override the original mutate function. // override the original mutate function.
IterSumExpr Mutate(const IterSumExpr& iter_sum) { IterSumExpr Mutate(const IterSumExpr &iter_sum) {
Array<IterSplitExpr> args; Array<IterSplitExpr> args;
for (const auto& split : iter_sum->args) { for (const auto &split : iter_sum->args) {
if (replace_map_.count(split)) { if (replace_map_.count(split)) {
args.push_back(replace_map_[split]); args.push_back(replace_map_[split]);
} else { } else {
auto split_ = auto split_ = IterSplitExpr(Mutate(split->source), split->lower_factor,
IterSplitExpr(Mutate(split->source), split->lower_factor, split->extent, split->scale); split->extent, split->scale);
args.push_back(split_); args.push_back(split_);
} }
} }
return IterSumExpr(args, iter_sum->base); return IterSumExpr(args, iter_sum->base);
} }
IterMark Mutate(const IterMark& mark) { IterMark Mutate(const IterMark &mark) {
if (auto* op = mark->source.as<IterSumExprNode>()) { if (auto *op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent); return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
} else { } else {
return mark; return mark;
} }
} }
private: private:
Map<IterSplitExpr, IterSplitExpr> replace_map_; Map<IterSplitExpr, IterSplitExpr> replace_map_;
}; };
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr, std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr &expr,
const Array<IterVar> input_iters, const Var& var, const Array<IterVar> input_iters,
arith::Analyzer* analyzer) { const Var &var,
auto iter_sum = arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer); arith::Analyzer *analyzer) {
auto iter_sum =
arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer);
IterMarkSplitCollector collector; IterMarkSplitCollector collector;
collector.Collect({iter_sum}); collector.Collect({iter_sum});
IterMark mark; IterMark mark;
for (const IterMark& m : collector.visited_) { for (const IterMark &m : collector.visited_) {
ICHECK(m->source.as<Var>()) << "Not a normalized iterator: " << mark; ICHECK(m->source.as<Var>()) << "Not a normalized iterator: " << mark;
if (m->source.as<Var>().value().same_as(var)) { if (m->source.as<Var>().value().same_as(var)) {
mark = m; mark = m;
...@@ -204,7 +218,7 @@ std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr, ...@@ -204,7 +218,7 @@ std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
} }
PrimExpr extent = 1; PrimExpr extent = 1;
for (const auto& split : splits) { for (const auto &split : splits) {
extent *= split->extent; extent *= split->extent;
} }
extent = analyzer->Simplify(extent); extent = analyzer->Simplify(extent);
...@@ -214,29 +228,31 @@ std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr, ...@@ -214,29 +228,31 @@ std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
auto new_mark = IterMark(new_var, extent); auto new_mark = IterMark(new_var, extent);
PrimExpr scale = 1; PrimExpr scale = 1;
Map<IterSplitExpr, IterSplitExpr> replace_map; Map<IterSplitExpr, IterSplitExpr> replace_map;
for (const auto& split : splits) { for (const auto &split : splits) {
auto rescaled = arith::IterSplitExpr(new_mark, scale, split->extent, split->scale); auto rescaled =
arith::IterSplitExpr(new_mark, scale, split->extent, split->scale);
replace_map.Set(split, rescaled); replace_map.Set(split, rescaled);
scale *= split->extent; scale *= split->extent;
} }
IterSumMutator mutator(replace_map); IterSumMutator mutator(replace_map);
PrimExpr reaplced = analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum))); PrimExpr reaplced =
analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum)));
return {reaplced, new_iter_var}; return {reaplced, new_iter_var};
} }
Array<IterVar> ToIterVars(const Map<Var, Range>& vmap) { Array<IterVar> ToIterVars(const Map<Var, Range> &vmap) {
Array<IterVar> result; Array<IterVar> result;
for (const auto& [var, range] : vmap) { for (const auto &[var, range] : vmap) {
result.push_back(IterVar(range, var, IterVarType::kDataPar)); result.push_back(IterVar(range, var, IterVarType::kDataPar));
} }
return result; return result;
} }
Map<Var, Range> ToVMap(const Array<IterVar>& ivs) { Map<Var, Range> ToVMap(const Array<IterVar> &ivs) {
Map<Var, Range> result; Map<Var, Range> result;
for (const auto& iv : ivs) { for (const auto &iv : ivs) {
result.Set(iv->var, iv->dom); result.Set(iv->var, iv->dom);
} }
return result; return result;
......
...@@ -23,36 +23,40 @@ using namespace tir; ...@@ -23,36 +23,40 @@ using namespace tir;
* If the expr is (x // 2) and x is in Range(4), * If the expr is (x // 2) and x is in Range(4),
* than the result should be (x % 2) * than the result should be (x % 2)
*/ */
Array<arith::IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs, Array<arith::IterSplitExpr>
DivideUnusedIterators(const Array<PrimExpr> &exprs,
const Array<IterVar> input_iters, const Array<IterVar> input_iters,
arith::Analyzer* analyzer); arith::Analyzer *analyzer);
/*! /*!
* \brief Compress the iterator var, remove the unused part of the var not present in the expr * \brief Compress the iterator var, remove the unused part of the var not
* present in the expr
* *
* Returns the compressed IterVar as well as the Updated iter sum expression. * Returns the compressed IterVar as well as the Updated iter sum expression.
*/ */
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr, std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr &expr,
const Array<IterVar> input_iters, const Var& var, const Array<IterVar> input_iters,
arith::Analyzer* analyzer); const Var &var,
arith::Analyzer *analyzer);
/*! /*!
* \brief Convert the iter splits returned by DivideUnusedIterators into flattened expression * \brief Convert the iter splits returned by DivideUnusedIterators into
* flattened expression
* *
*/ */
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits); PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr> &splits);
/*! /*!
* \brief Convert an Array of IterVar to a Map object * \brief Convert an Array of IterVar to a Map object
* *
*/ */
Map<Var, Range> ToVMap(const Array<IterVar>& ivs); Map<Var, Range> ToVMap(const Array<IterVar> &ivs);
/*! /*!
* \brief Convert a Map object to an Array of IterVar * \brief Convert a Map object to an Array of IterVar
* *
*/ */
Array<IterVar> ToIterVars(const Map<Var, Range>& vmap); Array<IterVar> ToIterVars(const Map<Var, Range> &vmap);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -13,76 +13,90 @@ ...@@ -13,76 +13,90 @@
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../target/cuda.h" #include "../target/cuda.h"
#include "../target/utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
#define TIR_DEFINE_TL_BUILTIN(OpName) \ #define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op& OpName() { \ const Op &OpName() { \
static const Op& op = Op::Get("tl." #OpName); \ static const Op &op = Op::Get("tl." #OpName); \
return op; \ return op; \
} \ } \
TVM_REGISTER_OP("tl." #OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
TIR_DEFINE_TL_BUILTIN(CreateListofMBarrierOp) TIR_DEFINE_TL_BUILTIN(CreateListofMBarrierOp)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(CreateTMADescriptorOp) TIR_DEFINE_TL_BUILTIN(CreateTMADescriptorOp)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(CreateTMAIm2ColDescriptorOp) TIR_DEFINE_TL_BUILTIN(CreateTMAIm2ColDescriptorOp)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(GetMBarrierOp) TIR_DEFINE_TL_BUILTIN(GetMBarrierOp)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(TMALoadOp).set_num_inputs(-1).set_attr<TCallEffectKind>( TIR_DEFINE_TL_BUILTIN(TMALoadOp).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque)); "TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMALoadIm2ColOp).set_num_inputs(-1).set_attr<TCallEffectKind>( TIR_DEFINE_TL_BUILTIN(TMALoadIm2ColOp)
"TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMAStoreOp) TIR_DEFINE_TL_BUILTIN(TMAStoreOp)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(MBarrierWaitParity) TIR_DEFINE_TL_BUILTIN(MBarrierWaitParity)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(MBarrierExpectTX) TIR_DEFINE_TL_BUILTIN(MBarrierExpectTX)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(LDMatrixOp) TIR_DEFINE_TL_BUILTIN(LDMatrixOp)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(STMatrixOp) TIR_DEFINE_TL_BUILTIN(STMatrixOp)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SyncThreadsPartialOp) TIR_DEFINE_TL_BUILTIN(SyncThreadsPartialOp)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp) TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp)
.set_num_inputs(0) .set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SetMaxNReg) TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(WaitWgmma) TIR_DEFINE_TL_BUILTIN(WaitWgmma).set_num_inputs(1).set_attr<TCallEffectKind>(
.set_num_inputs(1) "TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr<TCallEffectKind>( TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure)); "TCallEffectKind", Integer(CallEffectKind::kPure));
......
...@@ -18,21 +18,23 @@ namespace tl { ...@@ -18,21 +18,23 @@ namespace tl {
/*! /*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load * \brief tvm intrinsics for TMADescriptor creation for tiled load
* *
* CuTensorMap* CreateTMADescriptorOp(data_type, rank, global_addr, global_shape..., * CuTensorMap* CreateTMADescriptorOp(data_type, rank, global_addr,
* global_stride..., smem_box..., smem_stride..., interleave, swizzle, l2_promotion, oob_fill) * global_shape..., global_stride..., smem_box..., smem_stride..., interleave,
* swizzle, l2_promotion, oob_fill)
* *
*/ */
const Op& CreateTMADescriptorOp(); const Op &CreateTMADescriptorOp();
/*! /*!
* \brief tvm intrinsics for TMADescriptor creation for image to column load * \brief tvm intrinsics for TMADescriptor creation for image to column load
* *
* CuTensorMap* CreateTMAIm2ColDescriptorOp(data_type, rank, global_addr, global_shape..., * CuTensorMap* CreateTMAIm2ColDescriptorOp(data_type, rank, global_addr,
* global_stride..., elem_stride..., lower_corner..., upper_corner..., smem_box_pixel, smem_box_channel, * global_shape..., global_stride..., elem_stride..., lower_corner...,
* interleave, swizzle, l2_promotion, oob_fill) * upper_corner..., smem_box_pixel, smem_box_channel, interleave, swizzle,
* l2_promotion, oob_fill)
* *
*/ */
const Op& CreateTMAIm2ColDescriptorOp(); const Op &CreateTMAIm2ColDescriptorOp();
/*! /*!
* \brief Create a list of mbarrier with num_threads * \brief Create a list of mbarrier with num_threads
...@@ -40,7 +42,7 @@ const Op& CreateTMAIm2ColDescriptorOp(); ...@@ -40,7 +42,7 @@ const Op& CreateTMAIm2ColDescriptorOp();
* CreateListofMBarrierOp(num_threads0, num_threads1, ...) * CreateListofMBarrierOp(num_threads0, num_threads1, ...)
* *
*/ */
const Op& CreateListofMBarrierOp(); const Op &CreateListofMBarrierOp();
/*! /*!
* \brief Get the mbarrier with barrier_id * \brief Get the mbarrier with barrier_id
...@@ -48,31 +50,35 @@ const Op& CreateListofMBarrierOp(); ...@@ -48,31 +50,35 @@ const Op& CreateListofMBarrierOp();
* int64_t* GetMBarrier(barrier_id) * int64_t* GetMBarrier(barrier_id)
* *
*/ */
const Op& GetMBarrierOp(); const Op &GetMBarrierOp();
/*! /*!
* \brief tvm intrinsics for loading data from global tensor descriptor to shared memory * \brief tvm intrinsics for loading data from global tensor descriptor to
* shared memory
* *
* TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...) * TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
* *
*/ */
const Op& TMALoadOp(); const Op &TMALoadOp();
/*! /*!
* \brief tvm intrinsics for loading image from global tensor to columns in shared memory * \brief tvm intrinsics for loading image from global tensor to columns in
* shared memory
* *
* TMALoadIm2ColOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ..., image_offset, ...) * TMALoadIm2ColOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...,
* image_offset, ...)
* *
*/ */
const Op& TMALoadIm2ColOp(); const Op &TMALoadIm2ColOp();
/*! /*!
* \brief tvm intrinsics for storing data from shared memory to global tensor descriptor * \brief tvm intrinsics for storing data from shared memory to global tensor
* descriptor
* *
* TMAStoreOp(descriptor, smem_data, coord_0, coord_1, ...) * TMAStoreOp(descriptor, smem_data, coord_0, coord_1, ...)
* *
*/ */
const Op& TMAStoreOp(); const Op &TMAStoreOp();
/*! /*!
* \brief tvm intrinsics for mbarrier wait with parity bit * \brief tvm intrinsics for mbarrier wait with parity bit
...@@ -80,7 +86,7 @@ const Op& TMAStoreOp(); ...@@ -80,7 +86,7 @@ const Op& TMAStoreOp();
* MBarrierWaitParity(mbarrier, parity) * MBarrierWaitParity(mbarrier, parity)
* *
*/ */
const Op& MBarrierWaitParity(); const Op &MBarrierWaitParity();
/*! /*!
* \brief tvm intrinsics for mbarrier expect tx * \brief tvm intrinsics for mbarrier expect tx
...@@ -88,7 +94,7 @@ const Op& MBarrierWaitParity(); ...@@ -88,7 +94,7 @@ const Op& MBarrierWaitParity();
* MBarrierExpectTX(mbarrier, transaction_bytes) * MBarrierExpectTX(mbarrier, transaction_bytes)
* *
*/ */
const Op& MBarrierExpectTX(); const Op &MBarrierExpectTX();
/*! /*!
* \brief tvm intrinsics for ldmatrix * \brief tvm intrinsics for ldmatrix
...@@ -96,7 +102,7 @@ const Op& MBarrierExpectTX(); ...@@ -96,7 +102,7 @@ const Op& MBarrierExpectTX();
* LDMatrixOp(transposed, num, shared_addr, local_addr) * LDMatrixOp(transposed, num, shared_addr, local_addr)
* *
*/ */
const Op& LDMatrixOp(); const Op &LDMatrixOp();
/*! /*!
* \brief tvm intrinsics for stmatrix * \brief tvm intrinsics for stmatrix
...@@ -104,7 +110,7 @@ const Op& LDMatrixOp(); ...@@ -104,7 +110,7 @@ const Op& LDMatrixOp();
* STMatrixOp(transposed, num, shared_addr, int32_values...) * STMatrixOp(transposed, num, shared_addr, int32_values...)
* *
*/ */
const Op& STMatrixOp(); const Op &STMatrixOp();
/*! /*!
* \brief Pack two b16 values into a b32 value * \brief Pack two b16 values into a b32 value
...@@ -112,7 +118,7 @@ const Op& STMatrixOp(); ...@@ -112,7 +118,7 @@ const Op& STMatrixOp();
* int32 PackB16Op(b16_value, b16_value) * int32 PackB16Op(b16_value, b16_value)
* *
*/ */
const Op& PackB16Op(); const Op &PackB16Op();
/*! /*!
* \brief Similar to __syncthreads(), but can be used to sync partial threads * \brief Similar to __syncthreads(), but can be used to sync partial threads
...@@ -120,7 +126,7 @@ const Op& PackB16Op(); ...@@ -120,7 +126,7 @@ const Op& PackB16Op();
* SyncThreadsPartialOp(num_partial_threads or mbarrier) * SyncThreadsPartialOp(num_partial_threads or mbarrier)
* *
*/ */
const Op& SyncThreadsPartialOp(); const Op &SyncThreadsPartialOp();
/*! /*!
* \brief Issue a shared memory fence for async operations * \brief Issue a shared memory fence for async operations
...@@ -128,7 +134,7 @@ const Op& SyncThreadsPartialOp(); ...@@ -128,7 +134,7 @@ const Op& SyncThreadsPartialOp();
* FenceProxyAsync() * FenceProxyAsync()
* *
*/ */
const Op& FenceProxyAsyncOp(); const Op &FenceProxyAsyncOp();
/*! /*!
* \brief Set reg hint for warp-specialized branches * \brief Set reg hint for warp-specialized branches
...@@ -136,7 +142,7 @@ const Op& FenceProxyAsyncOp(); ...@@ -136,7 +142,7 @@ const Op& FenceProxyAsyncOp();
* SetMaxNReg(num_reg, is_inc) * SetMaxNReg(num_reg, is_inc)
* *
*/ */
const Op& SetMaxNReg(); const Op &SetMaxNReg();
/*! /*!
* \brief Wait for the previous wgmma to finish * \brief Wait for the previous wgmma to finish
...@@ -144,7 +150,7 @@ const Op& SetMaxNReg(); ...@@ -144,7 +150,7 @@ const Op& SetMaxNReg();
* WaitWgmma(num_mma) * WaitWgmma(num_mma)
* *
*/ */
const Op& WaitWgmma(); const Op &WaitWgmma();
/*! /*!
* \brief tvm intrinsic for amd matrix core mfma instructions. * \brief tvm intrinsic for amd matrix core mfma instructions.
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../target/cuda.h" #include "../target/cuda.h"
#include "../target/utils.h"
#include "builtin.h" #include "builtin.h"
namespace tvm { namespace tvm {
...@@ -83,18 +83,20 @@ static int to_CUtensorMapDataType(DataType dtype) { ...@@ -83,18 +83,20 @@ static int to_CUtensorMapDataType(DataType dtype) {
return static_cast<int>(tp); return static_cast<int>(tp);
} }
template <typename T> template <typename T> static Array<T> ReverseArray(Array<T> array) {
static Array<T> ReverseArray(Array<T> array) {
return Array<T>{array.rbegin(), array.rend()}; return Array<T>{array.rbegin(), array.rend()};
} }
Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (!TargetIsHopper(T.target)) return Stmt(); if (!TargetIsHopper(T.target))
return Stmt();
bool is_load; bool is_load;
if (src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared")) { if (src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared")) {
// Use the Hopper TMA bulk copy instructions // Use the Hopper TMA bulk copy instructions
is_load = true; is_load = true;
} else if (dst.scope() == "global" && (src.scope() == "shared.dyn" || src.scope() == "shared")) { } else if (dst.scope() == "global" &&
(src.scope() == "shared.dyn" || src.scope() == "shared")) {
is_load = false; is_load = false;
} else { } else {
return Stmt(); return Stmt();
...@@ -107,7 +109,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -107,7 +109,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
shared_tensor = T.buffer_remap[shared_tensor]; shared_tensor = T.buffer_remap[shared_tensor];
} }
if (T.layout_map.count(global_tensor)) { if (T.layout_map.count(global_tensor)) {
ICHECK(T.layout_map.count(global_tensor) == 0) << "Cannot support global layout."; ICHECK(T.layout_map.count(global_tensor) == 0)
<< "Cannot support global layout.";
} }
TMADesc desc; TMADesc desc;
...@@ -124,7 +127,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -124,7 +127,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
auto global_range = is_load ? src_range : dst_range; auto global_range = is_load ? src_range : dst_range;
desc.global_addr = global_tensor->data; desc.global_addr = global_tensor->data;
desc.global_shape = ReverseArray(global_tensor->shape); desc.global_shape = ReverseArray(global_tensor->shape);
Array<PrimExpr> global_coords = ReverseArray(global_range.Map([](Range r) { return r->min; })); Array<PrimExpr> global_coords =
ReverseArray(global_range.Map([](Range r) { return r->min; }));
if (!global_tensor->strides.empty()) { if (!global_tensor->strides.empty()) {
desc.global_stride = ReverseArray(global_tensor->strides); desc.global_stride = ReverseArray(global_tensor->strides);
} else { } else {
...@@ -139,11 +143,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -139,11 +143,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// The first stride element should be 1 // The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes // Make global stride in bytes
desc.global_stride = desc.global_stride = desc.global_stride.Map(
desc.global_stride.Map([&](PrimExpr e) { return e * global_tensor->dtype.bytes(); }); [&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
// Smem Box // Smem Box
desc.smem_box = ReverseArray(global_range.Map([](Range r) { return r->extent; })); desc.smem_box =
ReverseArray(global_range.Map([](Range r) { return r->extent; }));
desc.smem_stride = Array<PrimExpr>(desc.rank, PrimExpr(1)); desc.smem_stride = Array<PrimExpr>(desc.rank, PrimExpr(1));
// L2 & OOB // L2 & OOB
...@@ -159,12 +164,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -159,12 +164,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
auto stride = as_const_int(shared_layout->InputShape()[0]); auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]); auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr); ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(*stride, *continuous, if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous,
shared_tensor->dtype.bits()))) { shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()( } else if (StructuralEqual()(
shared_layout, shared_layout,
makeFullBankSwizzleLayout(*stride, *continuous, shared_tensor->dtype.bits()))) { makeFullBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else { } else {
ICHECK(0) << "Cannot detect TMA layout."; ICHECK(0) << "Cannot detect TMA layout.";
...@@ -182,12 +189,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -182,12 +189,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK((*inner_box_dim) % instruction_dim == 0); ICHECK((*inner_box_dim) % instruction_dim == 0);
desc.smem_box.Set(0, PrimExpr(instruction_dim)); desc.smem_box.Set(0, PrimExpr(instruction_dim));
Call create_descriptor = Call(DataType::Handle(), CreateTMADescriptorOp(), desc.EncodeCallArgs()); Call create_descriptor =
Call(DataType::Handle(), CreateTMADescriptorOp(), desc.EncodeCallArgs());
Array<PrimExpr> args; Array<PrimExpr> args;
args.reserve(desc.rank + 3); args.reserve(desc.rank + 3);
args.push_back(create_descriptor); args.push_back(create_descriptor);
if (is_load) args.push_back(0); // mbarrier id placeholder if (is_load)
args.push_back(0); // mbarrier id placeholder
auto op = is_load ? TMALoadOp() : TMAStoreOp(); auto op = is_load ? TMALoadOp() : TMAStoreOp();
Stmt tma_copy; Stmt tma_copy;
...@@ -196,18 +205,22 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -196,18 +205,22 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
Var loop_var("i"); Var loop_var("i");
int loop_extent = (*inner_box_dim) / instruction_dim; int loop_extent = (*inner_box_dim) / instruction_dim;
PrimExpr total_elements = 1; PrimExpr total_elements = 1;
for (auto e : desc.smem_box) total_elements *= e; for (auto e : desc.smem_box)
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1, total_elements *= e;
PrimExpr shared_addr =
shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
total_elements * loop_var, total_elements); total_elements * loop_var, total_elements);
args.push_back(shared_addr); args.push_back(shared_addr);
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var); global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords) args.push_back(coord); for (auto coord : global_coords)
args.push_back(coord);
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args))); Evaluate(Call(DataType::Handle(), op, args)));
} else { } else {
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1); PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1);
args.push_back(shared_addr); args.push_back(shared_addr);
for (auto coord : global_coords) args.push_back(coord); for (auto coord : global_coords)
args.push_back(coord);
tma_copy = Evaluate(Call(DataType::Handle(), op, args)); tma_copy = Evaluate(Call(DataType::Handle(), op, args));
} }
tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy); tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy);
...@@ -222,10 +235,14 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const { ...@@ -222,10 +235,14 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
args.push_back(data_type); args.push_back(data_type);
args.push_back(static_cast<int>(rank)); args.push_back(static_cast<int>(rank));
args.push_back(global_addr); args.push_back(global_addr);
for (auto e : global_shape) args.push_back(e); for (auto e : global_shape)
for (auto e : global_stride) args.push_back(e); args.push_back(e);
for (auto e : smem_box) args.push_back(e); for (auto e : global_stride)
for (auto e : smem_stride) args.push_back(e); args.push_back(e);
for (auto e : smem_box)
args.push_back(e);
for (auto e : smem_stride)
args.push_back(e);
args.push_back(interleave); args.push_back(interleave);
args.push_back(swizzle); args.push_back(swizzle);
args.push_back(l2_promotion); args.push_back(l2_promotion);
...@@ -247,9 +264,11 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -247,9 +264,11 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
padding = args[7].as<IntImm>().value()->value; padding = args[7].as<IntImm>().value()->value;
} }
Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target)); ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared")); ICHECK(src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared"));
ICHECK(src->shape.size() == 4); ICHECK(src->shape.size() == 4);
ICHECK(dst->shape.size() == 2); ICHECK(dst->shape.size() == 2);
ICHECK(src->dtype == dst->dtype); ICHECK(src->dtype == dst->dtype);
...@@ -277,7 +296,8 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const ...@@ -277,7 +296,8 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const
// The first stride element should be 1 // The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes // Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { return e * src->dtype.bytes(); }); desc.global_stride = desc.global_stride.Map(
[&](PrimExpr e) { return e * src->dtype.bytes(); });
desc.elem_stride = {1, stride, stride, 1}; desc.elem_stride = {1, stride, stride, 1};
desc.lower_corner = {-padding, -padding}; desc.lower_corner = {-padding, -padding};
desc.upper_corner = {-padding, -padding}; desc.upper_corner = {-padding, -padding};
...@@ -294,9 +314,11 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const ...@@ -294,9 +314,11 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const
auto continuous = as_const_int(shared_layout->InputShape()[1]); auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr); ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, if (StructuralEqual()(shared_layout,
makeHalfBankSwizzleLayout(*stride, *continuous, dst->dtype.bits()))) { makeHalfBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(*stride, *continuous, } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) { dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else { } else {
...@@ -304,27 +326,42 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const ...@@ -304,27 +326,42 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const
} }
} }
Call create_desc = Call(DataType::Handle(), CreateTMAIm2ColDescriptorOp(), desc.EncodeCallArgs()); Call create_desc = Call(DataType::Handle(), CreateTMAIm2ColDescriptorOp(),
desc.EncodeCallArgs());
Array<PrimExpr> global_coords; // c, w, h, n Array<PrimExpr> global_coords; // c, w, h, n
Array<PrimExpr> image_offset; // w, h Array<PrimExpr> image_offset; // w, h
global_coords.reserve(desc.rank); global_coords.reserve(desc.rank);
ICHECK(analyzer->CanProveEqual(FloorMod(desc.global_shape[0], desc.smem_box_channel), 0)) ICHECK(analyzer->CanProveEqual(
FloorMod(desc.global_shape[0], desc.smem_box_channel), 0))
<< "Currently can only support divisible channel case"; << "Currently can only support divisible channel case";
global_coords.push_back(FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0])); global_coords.push_back(
FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back( image_offset.push_back(
dilation * FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]), kernel)); dilation *
image_offset.push_back(dilation * FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]),
FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0] * kernel)); kernel));
image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel,
PrimExpr h_dim = FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1, stride) + 1; desc.global_shape[0] * kernel));
PrimExpr w_dim = FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1, stride) + 1;
global_coords.push_back(stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding); PrimExpr h_dim =
FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
1;
PrimExpr w_dim =
FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
1;
global_coords.push_back(
stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
global_coords.push_back(
stride *
FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) -
padding);
global_coords.push_back( global_coords.push_back(
stride * FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) - padding); FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
global_coords.push_back(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args; Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 1); args.reserve(desc.rank * 2 + 1);
...@@ -333,11 +370,14 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const ...@@ -333,11 +370,14 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst; auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
auto shared_addr = dst_buffer.access_ptr(2); auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr); args.push_back(shared_addr);
for (auto coord : global_coords) args.push_back(coord); for (auto coord : global_coords)
for (auto offset : image_offset) args.push_back(offset); args.push_back(coord);
for (auto offset : image_offset)
args.push_back(offset);
Stmt tma_copy = Stmt tma_copy =
IfThenElse(EQ(T.thread_var, 0), Evaluate(Call(DataType::Handle(), TMALoadIm2ColOp(), args))); IfThenElse(EQ(T.thread_var, 0),
Evaluate(Call(DataType::Handle(), TMALoadIm2ColOp(), args)));
return tma_copy; return tma_copy;
} }
...@@ -348,11 +388,16 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const { ...@@ -348,11 +388,16 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
args.push_back(data_type); args.push_back(data_type);
args.push_back(static_cast<int>(rank)); args.push_back(static_cast<int>(rank));
args.push_back(global_addr); args.push_back(global_addr);
for (auto e : global_shape) args.push_back(e); for (auto e : global_shape)
for (auto e : global_stride) args.push_back(e); args.push_back(e);
for (auto e : elem_stride) args.push_back(e); for (auto e : global_stride)
for (auto e : lower_corner) args.push_back(e); args.push_back(e);
for (auto e : upper_corner) args.push_back(e); for (auto e : elem_stride)
args.push_back(e);
for (auto e : lower_corner)
args.push_back(e);
for (auto e : upper_corner)
args.push_back(e);
args.push_back(smem_box_pixel); args.push_back(smem_box_pixel);
args.push_back(smem_box_channel); args.push_back(smem_box_channel);
args.push_back(interleave); args.push_back(interleave);
...@@ -365,7 +410,8 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const { ...@@ -365,7 +410,8 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(8) .set_num_inputs(8)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -49,12 +49,12 @@ struct TMAIm2ColDesc { ...@@ -49,12 +49,12 @@ struct TMAIm2ColDesc {
}; };
class Conv2DIm2ColOp : public Operator { class Conv2DIm2ColOp : public Operator {
public: public:
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap); Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op& Get(); static const Op &Get();
private: private:
Buffer src, dst; Buffer src, dst;
int stride, padding, dilation, kernel; int stride, padding, dilation, kernel;
PrimExpr nhw_step, c_step; PrimExpr nhw_step, c_step;
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "builtin.h" #include "builtin.h"
namespace tvm { namespace tvm {
...@@ -37,7 +37,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) { ...@@ -37,7 +37,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
} }
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3){ if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]); coalesced_width = Downcast<IntImm>(args[2]);
} }
} }
...@@ -46,17 +46,20 @@ Array<IterVar> Copy::MakeIterVars() const { ...@@ -46,17 +46,20 @@ Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
size_t idx = 0; size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) { for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent)) continue; if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}); Var var = Var(std::string{char('i' + idx)});
idx++; idx++;
loop_vars.push_back({Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
} }
return loop_vars; return loop_vars;
} }
// ivs: itervars returned by MakeIterVars() // ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices // src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar>& ivs, int src_dst) const { Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices; Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0; size_t idx = 0;
...@@ -72,14 +75,16 @@ Array<PrimExpr> Copy::MakeIndices(const Array<IterVar>& ivs, int src_dst) const ...@@ -72,14 +75,16 @@ Array<PrimExpr> Copy::MakeIndices(const Array<IterVar>& ivs, int src_dst) const
return indices; return indices;
} }
PrimExpr Copy::MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& ivs, PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
Array<PrimExpr> extents, int src_dst) const { const Array<IterVar> &ivs, Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list; Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0; size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) { for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent)) continue; if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i]; PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond); cond_list.push_back(cond);
...@@ -94,14 +99,16 @@ PrimExpr Copy::MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& iv ...@@ -94,14 +99,16 @@ PrimExpr Copy::MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& iv
return {}; return {};
else { else {
PrimExpr cond = cond_list[0]; PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++) cond = And(cond, cond_list[i]); for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond; return cond;
} }
} }
For Copy::MakeSIMTLoop(arith::Analyzer* analyzer) const { For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars(); Array<IterVar> loop_vars = MakeIterVars();
for (const auto& iv : loop_vars) analyzer->Bind(iv->var, iv->dom); for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0); Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1); Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
...@@ -110,44 +117,52 @@ For Copy::MakeSIMTLoop(arith::Analyzer* analyzer) const { ...@@ -110,44 +117,52 @@ For Copy::MakeSIMTLoop(arith::Analyzer* analyzer) const {
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
PrimExpr value = BufferLoad(src, src_indices); PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype) value = Cast(dst->dtype, value); if (src->dtype != dst->dtype)
if (src_predicate.defined()) value = if_then_else(src_predicate, value, make_zero(dst->dtype)); value = Cast(dst->dtype, value);
if (src_predicate.defined())
value = if_then_else(src_predicate, value, make_zero(dst->dtype));
Stmt body = BufferStore(dst, value, dst_indices); Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined()) body = IfThenElse(dst_predicate, body); if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body);
for (int i = loop_vars.size() - 1; i >= 0; i--) { for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {}; Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()){ if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width); annotations.Set("coalesced_width", coalesced_width);
} }
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, ForKind::kParallel, body, NullOpt, annotations); body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, NullOpt, annotations);
} }
return Downcast<For>(body); return Downcast<For>(body);
} }
Stmt Copy::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer); Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined()) return ldsm_stmt; if (ldsm_stmt.defined())
return ldsm_stmt;
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer); Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined()) return bulk_copy_stmt; if (bulk_copy_stmt.defined())
return bulk_copy_stmt;
auto simt_loop = MakeSIMTLoop(analyzer); auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop)); auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = std::make_unique<ParallelOp>(fused_loop); auto par_op = std::make_unique<ParallelOp>(fused_loop);
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap}, InferLevel::kFree); par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
auto thread_loop = InferLevel::kFree);
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop); auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) { if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
} }
return vectorized_thread_loop; return vectorized_thread_loop;
} }
Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
// Check buffer scope // Check buffer scope
bool is_ldmatrix; bool is_ldmatrix;
if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" && if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
...@@ -162,21 +177,26 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -162,21 +177,26 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// Check no predicates // Check no predicates
Array<IterVar> loop_vars = MakeIterVars(); Array<IterVar> loop_vars = MakeIterVars();
if (loop_vars.size() < 2) return Stmt(); if (loop_vars.size() < 2)
for (const auto& iv : loop_vars) analyzer->Bind(iv->var, iv->dom); return Stmt();
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
if (src_predicate.defined() || dst_predicate.defined()) return Stmt(); if (src_predicate.defined() || dst_predicate.defined())
return Stmt();
Buffer shared_tensor = is_ldmatrix ? src : dst; Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src; Buffer local_tensor = is_ldmatrix ? dst : src;
Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0); Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]); Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
Array<PrimExpr> local_indices_transformed = local_layout->Forward(local_indices); Array<PrimExpr> local_indices_transformed =
local_layout->Forward(local_indices);
local_tensor = T.buffer_remap[local_tensor]; local_tensor = T.buffer_remap[local_tensor];
// currently only support 1-d case // currently only support 1-d case
if (local_layout->OutputDim() != 1) return Stmt(); if (local_layout->OutputDim() != 1)
return Stmt();
Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1); Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices; Array<PrimExpr> shared_indices_transformed = shared_indices;
...@@ -193,27 +213,32 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -193,27 +213,32 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
IterVar row_var = loop_vars[loop_vars.size() - 2]; IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map = PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32); FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
PrimExpr matrix_8x8_thread_map = PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
makeGemmFragment8x8()->ForwardThread({FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt); {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr matrix_8x8_thread_map_trans = makeGemmFragment8x8Transposed()->ForwardThread( PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt); {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr local_indices_flattened = local_tensor.OffsetOf(local_indices_transformed).back(); PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) && if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var, col_var->dom->extent, 2, IndiceCanVectorize(local_indices_flattened, col_var->var,
analyzer)) { col_var->dom->extent, 2, analyzer)) {
is_transposed = false; is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans, local_layout_thread_map) && } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
IndiceCanVectorize(local_indices_flattened, row_var->var, row_var->dom->extent, 2, local_layout_thread_map) &&
analyzer)) { IndiceCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
is_transposed = true; is_transposed = true;
} else { } else {
return Stmt(); return Stmt();
} }
// Check shared_layout is 16 bytes continuous // Check shared_layout is 16 bytes continuous
if (shared_tensor->dtype.bytes() != 2) return Stmt(); if (shared_tensor->dtype.bytes() != 2)
PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices_transformed).back(); return Stmt();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var, loop_vars.back()->dom->extent, 8, PrimExpr flattened_indice =
analyzer)) shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer))
return Stmt(); return Stmt();
// Can only support local_range to be a full range // Can only support local_range to be a full range
...@@ -232,7 +257,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -232,7 +257,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
num = 2; num = 2;
Array<PrimExpr> args; Array<PrimExpr> args;
const Op& op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp(); const Op &op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
args.push_back(static_cast<int>(is_transposed)); args.push_back(static_cast<int>(is_transposed));
args.push_back(num); args.push_back(num);
...@@ -240,52 +265,60 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -240,52 +265,60 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// if not transpose // if not transpose
// coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4)) // coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4))
// if transpose // if transpose
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread % 8 / 2) // coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread
// % 8 / 2)
Var local_iter("i"); Var local_iter("i");
Layout inv = local_layout->Inverse(); Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords; Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32; PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed) if (!is_transposed)
shared_coords = shared_coords = inv->Forward(
inv->Forward({local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num), {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4}); warp + FloorMod(T.thread_var, 8) * 4});
else else
shared_coords = shared_coords = inv->Forward(
inv->Forward({local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) + {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2), FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)}); warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove rep shared_coords.pop_back(); // remove rep
if (shared_layout.defined()) shared_coords = shared_layout->Forward(shared_coords); if (shared_layout.defined())
shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr( PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1, shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num)); is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
args.push_back(shared_addr); args.push_back(shared_addr);
if (is_ldmatrix) { if (is_ldmatrix) {
// Can only support same dtype for ldmatrx // Can only support same dtype for ldmatrx
if (local_tensor->dtype != shared_tensor->dtype) return Stmt(); if (local_tensor->dtype != shared_tensor->dtype)
PrimExpr local_addr = return Stmt();
local_tensor.access_ptr(2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num)); PrimExpr local_addr = local_tensor.access_ptr(
2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
args.push_back(local_addr); args.push_back(local_addr);
} else { } else {
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
PrimExpr value0 = BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i}); PrimExpr value0 =
PrimExpr value1 = BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1}); BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
if (local_tensor->dtype != shared_tensor->dtype) { if (local_tensor->dtype != shared_tensor->dtype) {
value0 = Cast(shared_tensor->dtype, value0); value0 = Cast(shared_tensor->dtype, value0);
value1 = Cast(shared_tensor->dtype, value1); value1 = Cast(shared_tensor->dtype, value1);
} }
PrimExpr value_packed = Call(DataType::Int(32), PackB16Op(), {value0, value1}); PrimExpr value_packed =
Call(DataType::Int(32), PackB16Op(), {value0, value1});
args.push_back(value_packed); args.push_back(value_packed);
} }
} }
auto body = Evaluate(Call(DataType::Handle(), op, args)); auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node = For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body); For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node); for_node = LoopPragmaUnroll(for_node);
return for_node; return for_node;
} }
LayoutMap Copy::InferLayout(const LayoutInferArgs& T, InferLevel level) { LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// Use parallel op to infer the layout // Use parallel op to infer the layout
if (par_op_ == nullptr) { if (par_op_ == nullptr) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
...@@ -303,7 +336,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { ...@@ -303,7 +336,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
} }
} }
For Fill::MakeSIMTLoop(arith::Analyzer* analyzer) const { For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
int ndim = dst->shape.size(); int ndim = dst->shape.size();
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
Array<PrimExpr> dst_indices; Array<PrimExpr> dst_indices;
...@@ -314,22 +347,26 @@ For Fill::MakeSIMTLoop(arith::Analyzer* analyzer) const { ...@@ -314,22 +347,26 @@ For Fill::MakeSIMTLoop(arith::Analyzer* analyzer) const {
} }
Stmt body = BufferStore(dst, value, dst_indices); Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) { for (int i = ndim - 1; i >= 0; i--) {
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, ForKind::kParallel, body); body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body);
} }
return Downcast<For>(body); return Downcast<For>(body);
} }
Stmt Fill::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") { if (dst.scope() == "local.fragment") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer)); auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.block_size, T.layout_map}, InferLevel::kFree); par_op->InferLayout({T.target, T.block_size, T.layout_map},
par_op->InferLayout({T.target, T.block_size, T.layout_map}, InferLevel::kFree); InferLevel::kFree);
auto thread_loop = par_op->InferLayout({T.target, T.block_size, T.layout_map},
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop); auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) { if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
} }
return vectorized_thread_loop; return vectorized_thread_loop;
} else if (dst.scope() == "local") { } else if (dst.scope() == "local") {
...@@ -339,16 +376,17 @@ Stmt Fill::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -339,16 +376,17 @@ Stmt Fill::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
} else { } else {
LOG(FATAL) << "Unsupported scope " << dst.scope(); LOG(FATAL) << "Unsupported scope " << dst.scope();
} }
} }
TIR_REGISTER_TL_OP(Copy, copy) TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(3) .set_num_inputs(3)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_REGISTER_TL_OP(Fill, fill) TIR_REGISTER_TL_OP(Fill, fill)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -19,25 +19,25 @@ namespace tl { ...@@ -19,25 +19,25 @@ namespace tl {
using namespace tir; using namespace tir;
class Copy : public Operator { class Copy : public Operator {
public: public:
Copy(Array<PrimExpr> args, BufferMap vmap); Copy(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op& Get(); static const Op &Get();
protected: protected:
Stmt LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const; Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const; Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
For MakeSIMTLoop(arith::Analyzer* analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const; Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars() // ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices // src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> MakeIndices(const Array<IterVar>& ivs, int src_dst) const; Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
PrimExpr MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& ivs, PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const; Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_; Array<PrimExpr> args_;
...@@ -50,13 +50,13 @@ class Copy : public Operator { ...@@ -50,13 +50,13 @@ class Copy : public Operator {
}; };
class Fill : public Operator { class Fill : public Operator {
public: public:
Fill(Array<PrimExpr> args, BufferMap vmap); Fill(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op& Get(); static const Op &Get();
private: private:
For MakeSIMTLoop(arith::Analyzer* analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst; tir::Buffer dst;
PrimExpr value; PrimExpr value;
}; };
......
...@@ -52,11 +52,13 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { ...@@ -52,11 +52,13 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
} }
} }
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target) const { std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps,
Target target) const {
int m_warp = 1, n_warp = 1; int m_warp = 1, n_warp = 1;
if (TargetIsHopper(target)) { if (TargetIsHopper(target)) {
ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads."; ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
if (this->policy == GemmWarpPolicy::kFullRow || this->policy == GemmWarpPolicy::kSquare) { if (this->policy == GemmWarpPolicy::kFullRow ||
this->policy == GemmWarpPolicy::kSquare) {
m_warp = num_warps; m_warp = num_warps;
ICHECK(this->M % num_warps == 0); ICHECK(this->M % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kFullCol) { } else if (this->policy == GemmWarpPolicy::kFullCol) {
...@@ -100,14 +102,15 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target) con ...@@ -100,14 +102,15 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target) con
return {m_warp, n_warp}; return {m_warp, n_warp};
} }
Stmt Gemm::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32; int warp_size = 32;
if (TargetIsCDNA(T.target)) { if (TargetIsCDNA(T.target)) {
warp_size = 64; warp_size = 64;
} }
ICHECK(T.block_size % warp_size == 0); ICHECK(T.block_size % warp_size == 0);
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target); auto [warp_m, warp_n] =
ComputeWarpPartition(T.block_size / warp_size, T.target);
std::stringstream ss; std::stringstream ss;
std::string op_name = "tl::gemm_ss"; std::string op_name = "tl::gemm_ss";
if (A.scope() == "local.fragment") { if (A.scope() == "local.fragment") {
...@@ -137,19 +140,23 @@ Stmt Gemm::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -137,19 +140,23 @@ Stmt Gemm::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
return Evaluate(new_call); return Evaluate(new_call);
} }
LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (completed_) return {}; if (completed_)
return {};
LayoutMap results; LayoutMap results;
ICHECK(C.scope() == "local.fragment"); ICHECK(C.scope() == "local.fragment");
if (TargetIsVolta(T.target)) { if (TargetIsVolta(T.target)) {
const int warp_size = 32; const int warp_size = 32;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target); auto [warp_m, warp_n] =
auto fragment = makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment); results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") { if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]), results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
true, trans_A ? 1 : 2)); *as_const_int(A->shape[1]), true,
trans_A ? 1 : 2));
} else if (A.scope() == "local.fragment") { } else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false); ICHECK(trans_A == false);
results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n)); results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
...@@ -158,25 +165,31 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -158,25 +165,31 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
} }
ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn"); ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]), results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
false, trans_B ? 2 : 1)); *as_const_int(B->shape[1]), false,
trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) { } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
const int warp_size = 32; const int warp_size = 32;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target); auto [warp_m, warp_n] =
auto fragment = makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment); results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") { if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]), results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
*as_const_int(A->shape[1]),
A->dtype.bits(), trans_A ? 1 : 2)); A->dtype.bits(), trans_A ? 1 : 2));
} else if (A.scope() == "local.fragment") { } else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false); ICHECK(trans_A == false);
results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits())); results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits()));
} else { } else {
ICHECK(0); ICHECK(0);
} }
if (B.scope() == "shared" || B.scope() == "shared.dyn") { if (B.scope() == "shared" || B.scope() == "shared.dyn") {
results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]), results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
*as_const_int(B->shape[1]),
B->dtype.bits(), trans_B ? 2 : 1)); B->dtype.bits(), trans_B ? 2 : 1));
} else if (B.scope() == "local.fragment") { } else if (B.scope() == "local.fragment") {
ICHECK(trans_B == false); ICHECK(trans_B == false);
...@@ -186,18 +199,23 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -186,18 +199,23 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
} }
} else if (TargetIsHopper(T.target)) { } else if (TargetIsHopper(T.target)) {
const int warp_size = 32; const int warp_size = 32;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target); auto [warp_m, warp_n] =
auto fragment = makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits()); ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment =
makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment); results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") { if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]), results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
*as_const_int(A->shape[1]),
A->dtype.bits(), trans_A ? 1 : 2)); A->dtype.bits(), trans_A ? 1 : 2));
} else { } else {
ICHECK(trans_A == false); ICHECK(trans_A == false);
results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits())); results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits()));
} }
if (B.scope() == "shared" || B.scope() == "shared.dyn") { if (B.scope() == "shared" || B.scope() == "shared.dyn") {
results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]), results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
*as_const_int(B->shape[1]),
B->dtype.bits(), trans_B ? 2 : 1)); B->dtype.bits(), trans_B ? 2 : 1));
} else { } else {
ICHECK(0) << "WGMMA only support B in shared."; ICHECK(0) << "WGMMA only support B in shared.";
...@@ -206,9 +224,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -206,9 +224,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
ICHECK(trans_B == true) << "Currently only support Transpose B for CDNA"; ICHECK(trans_B == true) << "Currently only support Transpose B for CDNA";
const int warp_size = 64; const int warp_size = 64;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target); auto [warp_m, warp_n] =
ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment = makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits()); auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment); results.Set(C, fragment);
...@@ -216,24 +236,29 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -216,24 +236,29 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
// Make Linear Memory Access Layout // Make Linear Memory Access Layout
// auto shared_layout = // auto shared_layout =
// makeGemmLayoutLinear(*as_const_int(A->shape[0]), *as_const_int(A->shape[1])); // makeGemmLayoutLinear(*as_const_int(A->shape[0]),
// *as_const_int(A->shape[1]));
// Make Swizzle or Pad Layout // Make Swizzle or Pad Layout
auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]), auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
*as_const_int(A->shape[1]),
A->dtype.bits(), kPack); A->dtype.bits(), kPack);
results.Set(A, shared_layout); results.Set(A, shared_layout);
} else if (A.scope() == "local.fragment") { } else if (A.scope() == "local.fragment") {
results.Set(A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A)); results.Set(
A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
} else { } else {
ICHECK(0); ICHECK(0);
} }
if (B.scope() == "shared" || B.scope() == "shared.dyn") { if (B.scope() == "shared" || B.scope() == "shared.dyn") {
// Make Linear Memory Access Layout // Make Linear Memory Access Layout
// auto shared_layout = // auto shared_layout =
// makeGemmLayoutLinear(*as_const_int(B->shape[0]), *as_const_int(B->shape[1])); // makeGemmLayoutLinear(*as_const_int(B->shape[0]),
// *as_const_int(B->shape[1]));
// Make Swizzle or Pad Layout // Make Swizzle or Pad Layout
auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]), auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
*as_const_int(B->shape[1]),
B->dtype.bits(), kPack); B->dtype.bits(), kPack);
results.Set(B, shared_layout); results.Set(B, shared_layout);
...@@ -251,7 +276,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -251,7 +276,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
TIR_REGISTER_TL_OP(Gemm, gemm) TIR_REGISTER_TL_OP(Gemm, gemm)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -18,18 +18,18 @@ namespace tl { ...@@ -18,18 +18,18 @@ namespace tl {
using namespace tir; using namespace tir;
class Gemm : public Operator { class Gemm : public Operator {
public: public:
Gemm(Array<PrimExpr> args, BufferMap vmap); Gemm(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op& Get(); static const Op &Get();
enum class GemmWarpPolicy { enum class GemmWarpPolicy {
kSquare = 0, kSquare = 0,
kFullRow = 1, kFullRow = 1,
kFullCol = 2, kFullCol = 2,
} policy; } policy;
private: private:
std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const; std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const;
Array<PrimExpr> call_args; Array<PrimExpr> call_args;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment