"docs/vscode:/vscode.git/clone" did not exist on "ea6938aea589b034c2320964bf066ba6dd33b12c"
Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
......@@ -201,7 +201,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:
- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilzing magic layout transformation and intrins to accelerate dequantize gemm.
- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilizes magic layout transformations and intrinsics to accelerate dequantize GEMM.
- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax, and we also provide an example of auto tuning.
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col.
......@@ -7,7 +7,7 @@
- **Python Version**: >= 3.8
- **CUDA Version**: >= 11.0
The easiest way to install TileLang is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal.
The easiest way to install TileLang is directly from PyPI using pip. To install the latest version, run the following command in your terminal.
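(The command itself is elided from this hunk; presumably it is the standard PyPI install, i.e. `pip install tilelang`, assuming the package is published under that name.)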
**Note**: TileLang wheels are currently supported only on Ubuntu 20.04 or later, as we build the wheel files on that platform, and only for CUDA >= 11.0 with Python >= 3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/microsoft/TileLang/blob/main/docs/Installation.md#building-from-source).**
......@@ -12,6 +12,7 @@ import torch
import argparse
from functools import partial
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
......@@ -24,12 +25,15 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
val_f16 = tir.reinterpret(
"float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16
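For readers decoding the bit trick above: the 4-bit value packs a sign bit and a 3-bit exponent (effectively bias 7, no mantissa; note the `| 8` rebias to float16's bias of 15 = 7 + 8). A pure-Python sketch of the same conversion, as an editorial illustration (`fp4_to_fp16` is a hypothetical helper, not part of this file):

    import struct

    def fp4_to_fp16(f4: int) -> float:
        s, e_f4 = f4 >> 3, f4 & 7                   # sign bit, 3-bit exponent (bias 7)
        e_f16 = e_f4 | 8                            # rebiased exponent field for fp16 (bias 15)
        bits = ((e_f16 | (s << 5)) << 10) & 0xFFFF  # sign<<15 | exponent<<10, mantissa = 0
        return struct.unpack('<e', bits.to_bytes(2, 'little'))[0]

    assert fp4_to_fp16(0b0111) == 1.0   # 2 ** (7 - 7)
    assert fp4_to_fp16(0b1111) == -1.0  # same magnitude, sign bit set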
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
......@@ -46,7 +50,7 @@ def torch_convert(tensor):
val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
......@@ -55,6 +59,7 @@ def torch_convert(tensor):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
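Note the shape arithmetic: the input is an (N, K) uint8 tensor holding two 4-bit values per byte, so torch_convert returns an (N, 2K) float16 tensor, matching num_elems_per_byte = 8 // num_bits = 2 used in the kernels below.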
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -64,18 +69,15 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
@T.prim_func
def main(
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((N, K), in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(
T.ceildiv(K, block_K),
num_stages=1
):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
......@@ -89,12 +91,13 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
return main
def test_fp4_fp16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
program = test_convert(
N,
K,
K,
block_N,
block_K,
"float16",
......@@ -109,6 +112,7 @@ def test_fp4_fp16_convert_close():
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass")
def get_configs():
block_M = [128]
block_N = [128, 256]
......@@ -118,13 +122,19 @@ def get_configs():
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [
{'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'threads': c[4], 'split': c[5]}
for c in _configs
]
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
return configs
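The same list of dicts could be built with zip instead of positional indexing, which is arguably tidier (an editorial suggestion, not part of the commit):

    keys = ('block_M', 'block_N', 'block_K', 'num_stages', 'threads', 'split')
    configs = [dict(zip(keys, c)) for c in _configs]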
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -142,10 +152,13 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
):
SplitC = T.alloc_buffer(
[split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype
)
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -154,12 +167,10 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
......@@ -175,7 +186,8 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
......@@ -183,14 +195,15 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
for i, j in T.Parallel(block_N, block_M):
acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
T.copy(acc, Ct[bx * block_N, by * block_M])
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -199,12 +212,10 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -221,31 +232,43 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
if split == 1:
return main
else:
return main_split
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10,
rep=10
)
@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None, profiler="auto")
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
rep=10)
@jit(
out_idx=[2],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel
def ref_program(A, qB):
dtypeC = "float16"
B = torch_convert(qB)
......@@ -253,6 +276,7 @@ def ref_program(A, qB):
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
......@@ -264,7 +288,9 @@ if __name__ == "__main__":
total_flops = 2 * M * N * K
if (not args.tune):
program = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
program = matmul(
M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
......@@ -276,7 +302,8 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, ref_latency = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
best_latency, best_config, ref_latency = matmul(
M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......@@ -6,18 +6,15 @@ from tilelang import Profiler
import tilelang.language as T
def matmul(
M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"
):
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128
) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -9,8 +9,7 @@ import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
......@@ -180,6 +179,7 @@ def tl_matmul(
return main
M, N, K = 128, 128, 128
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float16"
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
......@@ -61,7 +61,6 @@ Fragment makeGemmFragmentC16x16CDNA() {
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragment8x8Transposed() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 8);
......@@ -80,57 +79,70 @@ Fragment makeGemmFragment8x16() {
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n, const int warp_m,
const int warp_n) {
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
const int warp_m, const int warp_n) {
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0);
ICHECK(warp_n % 16 == 0);
auto base_layout = makeGemmFragment8x8();
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout = warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false);
auto warp_layout =
base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout =
warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false);
return block_layout;
}
Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n,
Fragment makeGemmFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64) return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n);
if (element_size == 64)
return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n);
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false);
auto warp_layout =
base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout =
warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false);
return block_layout;
}
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64) LOG(FATAL) << "Not supported";
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64)
LOG(FATAL) << "Not supported";
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true);
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto warp_layout =
base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
return block_layout;
}
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m,
const int warp_n, const int element_size) {
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
ICHECK(block_m % warp_m == 0);
// ICHECK(block_n == warp_n);
ICHECK(warp_m % 16 == 0);
auto warp_layout =
makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false, false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); // 16*Y x N (Y warp)
auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
true, false); // 16*Y x N (Y warp)
return block_layout->Repeat({warp_m / 16, 1}, false, false);
}
Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, const int element_size) {
Fragment makeGemmFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
......@@ -140,13 +152,17 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block
ICHECK(element_size == 8 || element_size == 16);
if (element_size == 8) {
auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n);
auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
return block_layout;
} else if (element_size == 16) {
auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n);
auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
return block_layout;
} else {
ICHECK(0);
......@@ -154,36 +170,45 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block
}
}
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, bool transposed) {
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, bool transposed) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0);
ICHECK(block_k % 16 == 0);
if (transposed) {
auto base_layout = makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)->Replicate(block_n / warp_n);
auto base_layout =
makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
auto warp_layout =
base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
->Replicate(block_n / warp_n);
return block_layout;
} else {
auto base_layout = makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, 1}, true, true)->Replicate(block_n / warp_n);
auto base_layout =
makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
auto warp_layout =
base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
->Replicate(block_n / warp_n);
return block_layout;
}
}
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n) {
Fragment makeGemmFragmentB(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n) {
// transposed
ICHECK(warp_n % 8 == 0);
ICHECK(block_k % 16 == 0);
auto base_layout = makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true);
auto block_layout = warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
auto base_layout =
makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
auto warp_layout = base_layout->Replicate(block_m / warp_m)
->Repeat({1, block_n / warp_n}, true);
auto block_layout =
warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
return block_layout;
}
......@@ -195,33 +220,39 @@ Fragment makeGemmFragment32x32(int element_size) {
if (element_size == 16) {
PrimExpr thd = FloorMod(i, 4) + FloorDiv(FloorMod(i, 16), 8) * 4 +
FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 +
PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 +
FloorDiv(FloorMod(i, 8), 4) * 8 +
FloorDiv(FloorMod(j, 8), 4) * 16;
return Fragment({i, j}, {idx}, thd, rep);
} else {
PrimExpr thd = FloorMod(i, 2) + 2 * FloorDiv(FloorMod(j, 4), 2) +
FloorDiv(FloorMod(i, 16), 8) * 4 + FloorDiv(FloorMod(j, 16), 8) * 8 +
FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) + FloorDiv(j, 16) * 4 +
FloorDiv(FloorMod(i, 8), 4) * 8 + FloorDiv(FloorMod(j, 8), 4) * 16;
FloorDiv(FloorMod(i, 16), 8) * 4 +
FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) +
FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 +
FloorDiv(FloorMod(j, 8), 4) * 16;
return Fragment({i, j}, {idx}, thd, rep);
}
}
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m,
const int warp_n, int element_size) {
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
int element_size) {
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 32 == 0);
ICHECK(warp_n % 32 == 0);
auto base_layout = makeGemmFragment32x32(element_size);
auto warp_layout = base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true);
auto warp_layout =
base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true);
return block_layout;
}
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n) {
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
......@@ -231,17 +262,22 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int
IterVar i = make_itervar("i", 32);
IterVar j = make_itervar("j", 4);
IterVar rep = make_itervar("rep", 2);
PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) + FloorMod(i, 4) + 8 * rep;
PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) +
FloorMod(i, 4) + 8 * rep;
PrimExpr idx = j + FloorDiv(FloorMod(i, 8), 4) * 4;
Fragment base_layout = Fragment({i, j}, {idx}, thd, rep);
auto warp_layout = base_layout->Repeat({warp_m / 32, block_k / 4}, false, false);
auto block_layout = warp_layout->Replicate(block_n / warp_n)->Repeat({block_m / warp_m, 1}, true);
auto warp_layout =
base_layout->Repeat({warp_m / 32, block_k / 4}, false, false);
auto block_layout = warp_layout->Replicate(block_n / warp_n)
->Repeat({block_m / warp_m, 1}, true);
return block_layout;
}
PrimExpr xor2x2(const PrimExpr& i, const PrimExpr& j) { return FloorMod(i + j, 2); }
PrimExpr xor2x2(const PrimExpr &i, const PrimExpr &j) {
return FloorMod(i + j, 2);
}
PrimExpr xor4x4(const PrimExpr& i, const PrimExpr& j) {
PrimExpr xor4x4(const PrimExpr &i, const PrimExpr &j) {
PrimExpr i0 = FloorMod(i, 2);
PrimExpr j0 = FloorMod(j, 2);
PrimExpr i1 = FloorDiv(i, 2);
......@@ -249,7 +285,7 @@ PrimExpr xor4x4(const PrimExpr& i, const PrimExpr& j) {
return 2 * xor2x2(i1, j1) + xor2x2(i0, j0);
}
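These helpers compose bitwise XOR out of 1-bit pieces: xor2x2 is addition mod 2 (i.e. XOR of single bits), and xor4x4 assembles the high and low result bits so that it equals bitwise XOR on 2-bit indices; xor8x8 below repeats the construction once more. A quick Python check of that claim (editorial sketch mirroring the C++ above):

    def xor2x2(i, j):
        return (i + j) % 2  # 1-bit XOR

    def xor4x4(i, j):
        i0, j0, i1, j1 = i % 2, j % 2, i // 2, j // 2
        return 2 * xor2x2(i1, j1) + xor2x2(i0, j0)

    assert all(xor4x4(i, j) == i ^ j for i in range(4) for j in range(4))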
PrimExpr xor8x8(const PrimExpr& i, const PrimExpr j) {
PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
PrimExpr i0 = FloorMod(i, 2);
PrimExpr j0 = FloorMod(j, 2);
PrimExpr i1 = FloorDiv(i, 2);
......@@ -291,8 +327,10 @@ Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}
// Detail implementation please ref to bitblas::tl::mfma_layout::make_mfma_swizzle_layout
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, int kPack=1) {
// For implementation details, please refer to
// bitblas::tl::mfma_layout::make_mfma_swizzle_layout
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size,
int kPack = 1) {
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
......@@ -352,7 +390,8 @@ Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) {
IterVar j = make_itervar("j", continuous);
int padded = continuous;
// Add 128 bits padding when the last dim is a multiple of 256 bits
if ((element_size * continuous) % 256 == 0) padded += 128 / element_size;
if ((element_size * continuous) % 256 == 0)
padded += 128 / element_size;
return Layout(Array{i, j}, {i * padded + j});
}
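As a worked example of the padding rule: for float16 (element_size = 16) with continuous = 128, a row is 16 * 128 = 2048 bits, a multiple of 256, so the padded row length becomes 128 + 128 / 16 = 136 elements; the resulting non-power-of-two stride is the usual trick for avoiding shared-memory bank conflicts.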
......@@ -363,14 +402,17 @@ Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) {
PrimExpr vec_contiguous_idx = FloorDiv(j, 4);
PrimExpr vec_strided_within_tile = FloorMod(vec_contiguous_idx, 8);
PrimExpr bit2 = FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) +
FloorDiv(vec_strided_within_tile, 4),
2);
PrimExpr bit1 =
xor2x2(FloorDiv(FloorMod(i, 8), 4), FloorDiv(FloorMod(vec_strided_within_tile, 4), 2));
PrimExpr permuted_vec_contiguous = FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1;
PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 + vec_contiguous_idx * stride * 4;
PrimExpr bit2 =
FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) +
FloorDiv(vec_strided_within_tile, 4),
2);
PrimExpr bit1 = xor2x2(FloorDiv(FloorMod(i, 8), 4),
FloorDiv(FloorMod(vec_strided_within_tile, 4), 2));
PrimExpr permuted_vec_contiguous =
FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1;
PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 +
vec_contiguous_idx * stride * 4;
return Layout(Array{i, j}, {offset});
}
......@@ -390,9 +432,11 @@ Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) {
FloorMod(tile_contiguous_residual, 2) * 4 +
xor4x4(tile_strided_residual, permuted_strided_within_tile);
PrimExpr element_strided = permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_strided =
permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_contiguous =
FloorMod(j, 8) + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
FloorMod(j, 8) +
(permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
PrimExpr offset = element_strided * continuous + element_contiguous;
return Layout(Array{i, j}, {offset});
}
......@@ -413,30 +457,37 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
FloorDiv(tile_contiguous_residual, 4) * 4 +
xor4x4(tile_strided_residual, permuted_strided_within_tile);
PrimExpr element_strided = permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_strided =
permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_contiguous =
FloorMod(j, 8) + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
FloorMod(j, 8) +
(permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
PrimExpr offset = element_strided * continuous + element_contiguous;
return Layout(Array{i, j}, {offset});
}
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, int kfactor) {
if (kfactor == 2) return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0) return MakeGemmVoltaALayoutCongruous(stride, continuous);
if (!is_a && continuous % 64 == 0) return MakeGemmVoltaBLayoutCongruous(stride, continuous);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor) {
if (kfactor == 2)
return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
if (!is_a && continuous % 64 == 0)
return MakeGemmVoltaBLayoutCongruous(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, 16);
}
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor) {
Layout makeGemmABLayout(int stride, int continuous, int element_size,
int kfactor) {
if (element_size == 64) {
if (kfactor == 1 && continuous % 16 == 0) // float64 KxN
if (kfactor == 1 && continuous % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(stride, continuous);
if (kfactor == 2 && continuous % 16 == 0) // float64 NxK
if (kfactor == 2 && continuous % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (kfactor == 1 && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(stride, continuous, element_size);
else if (continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(stride, continuous, element_size);
......@@ -447,7 +498,8 @@ Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfacto
}
}
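In short, makeGemmABLayout picks the most aggressive swizzle the row width allows: with element_size = 16, vector_size = 128 / 16 = 8, so rows that are a multiple of 64 elements get the full-bank swizzle, and the elided branches presumably fall back to the half-bank and padded layouts declared in layout.h.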
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kPack) {
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kPack) {
int vector_size = 128 / element_size;
if (continuous % (vector_size * 4) == 0)
return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
......@@ -455,5 +507,5 @@ Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kP
return makeGemmABLayoutPadded(stride, continuous, element_size);
}
}
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -20,7 +20,7 @@ namespace tl {
using namespace tir;
static Var getPlaceholder(const std::string& s) {
static Var getPlaceholder(const std::string &s) {
static std::unordered_map<std::string, Var> map;
if (map.find(s) == map.end()) {
map[s] = Var(s);
......@@ -29,7 +29,9 @@ static Var getPlaceholder(const std::string& s) {
}
Var ReplicationPlaceholder() { return getPlaceholder("_rep"); }
Var InputPlaceholder(size_t idx) { return getPlaceholder(std::string{'_', char('i' + idx)}); }
Var InputPlaceholder(size_t idx) {
return getPlaceholder(std::string{'_', char('i' + idx)});
}
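Concretely, InputPlaceholder(0) yields the cached Var "_i", InputPlaceholder(1) yields "_j", and so on, while ReplicationPlaceholder() is "_rep"; the static map inside getPlaceholder guarantees that the same Var object is returned for a given name on every call.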
Map<Var, Range> LayoutNode::getVarMap() const {
Map<Var, Range> map;
......@@ -45,11 +47,13 @@ Map<Var, Range> FragmentNode::getVarMap() const {
return map;
}
LayoutNode::LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
LayoutNode::LayoutNode(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index) {
input_size_ = input_size;
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); });
forward_index_ = forward_index.Map(
[&](const PrimExpr &e) { return analyzer.Simplify(e); });
}
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
......@@ -60,7 +64,8 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent);
}
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
......@@ -71,13 +76,13 @@ Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
data_ = std::move(n);
}
void LayoutNode::VisitAttrs(AttrVisitor* v) {
void LayoutNode::VisitAttrs(AttrVisitor *v) {
v->Visit("input_size", &input_size_);
v->Visit("forward_index", &forward_index_);
}
void LayoutNode::UpdateAnalyzer(arith::Analyzer* analyzer) const {
for (const auto& [var, dom] : getVarMap()) {
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
for (const auto &[var, dom] : getVarMap()) {
analyzer->Bind(var, dom);
}
}
......@@ -99,66 +104,74 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
return ret;
}
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr>& vars) const {
if (vars.empty()) return forward_index_;
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
if (vars.empty())
return forward_index_;
ICHECK_EQ(vars.size(), InputDim());
Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]);
}
return forward_index_.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
return forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
}
Fragment FragmentNode::Repeat(const Array<PrimExpr>& repeats, bool repeat_on_thread,
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
bool repeat_on_thread,
bool lower_dim_first) const {
ICHECK_EQ(repeats.size(), InputDim());
Array<PrimExpr> new_input_size;
Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) {
new_input_size.push_back(input_size_[i] * repeats[i]);
vmap.Set(InputPlaceholder(i), FloorMod(InputPlaceholder(i), InputShape()[i]));
vmap.Set(InputPlaceholder(i),
FloorMod(InputPlaceholder(i), InputShape()[i]));
}
PrimExpr repeats_index = 0, repeat_stride = 1;
if (lower_dim_first) {
for (int i = InputDim() - 1; i >= 0; i--) {
repeats_index += repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeats_index +=
repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeat_stride *= repeats[i];
}
} else {
for (size_t i = 0; i < InputDim(); i++) {
repeats_index += repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeats_index +=
repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeat_stride *= repeats[i];
}
}
if (repeat_on_thread) {
PrimExpr thread_size = ThreadExtent();
auto new_forward_index =
forward_index_.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
auto new_forward_thread = Substitute(forward_thread_, vmap) + thread_size * repeats_index;
return Fragment(new_input_size, new_forward_index, new_forward_thread, replicate_size_,
NullOpt);
auto new_forward_index = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
auto new_forward_thread =
Substitute(forward_thread_, vmap) + thread_size * repeats_index;
return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt);
} else {
ICHECK(OutputDim() == 1);
PrimExpr frag_len = OutputShape()[0];
Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) +
frag_len * repeats_index};
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
return Fragment(new_input_size, new_forward_index, new_forward_thread, replicate_size_,
NullOpt);
return Fragment(new_input_size, new_forward_index, new_forward_thread,
replicate_size_, NullOpt);
}
}
Fragment FragmentNode::Replicate(int repeats) const {
ICHECK(repeats >= 1);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
vmap.Set(ReplicationPlaceholder(),
FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
PrimExpr new_forward_thread =
Substitute(forward_thread_, vmap) +
ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
return Fragment(input_size_, forward_index_, new_forward_thread, ReplicateExtent() * repeats,
NullOpt);
return Fragment(input_size_, forward_index_, new_forward_thread,
ReplicateExtent() * repeats, NullOpt);
}
Fragment FragmentNode::DeReplicate() const {
......@@ -171,21 +184,23 @@ Fragment FragmentNode::DeReplicate() const {
if (rep_size && idx_size) {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
}
if (factor == 1) return GetRef<Fragment>(this);
if (factor == 1)
return GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(),
ReplicationPlaceholder() * factor + FloorMod(forward_index_[0], factor));
vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
FloorMod(forward_index_[0], factor));
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
return Fragment(input_size_, new_forward_index, new_forward_thread, int(*rep_size) / factor,
NullOpt);
return Fragment(input_size_, new_forward_index, new_forward_thread,
int(*rep_size) / factor, NullOpt);
}
Layout LayoutNode::Inverse() const {
arith::Analyzer analyzer;
arith::IterMapResult res = arith::DetectIterMap(forward_index_, getVarMap(), 1,
arith::IterMapLevel::Bijective, &analyzer);
arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1,
arith::IterMapLevel::Bijective, &analyzer);
ICHECK(res->errors.empty()) << res->errors;
auto outputs_shape = OutputShape();
......@@ -208,21 +223,25 @@ Layout LayoutNode::Inverse() const {
return Layout(outputs_shape, backward_index);
}
PrimExpr infer_fragment_index(const Map<Var, Range>& input_iters, const PrimExpr& forward_thread,
arith::Analyzer* analyzer) {
Array<arith::IterSplitExpr> splits =
DivideUnusedIterators({forward_thread}, ToIterVars(input_iters), analyzer);
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
const PrimExpr &forward_thread,
arith::Analyzer *analyzer) {
Array<arith::IterSplitExpr> splits = DivideUnusedIterators(
{forward_thread}, ToIterVars(input_iters), analyzer);
Array<arith::IterSplitExpr> split_without_rep;
for (const auto& split : splits) {
for (const auto &split : splits) {
CHECK(split->source->source.as<Var>());
if (split->source->source.as<Var>().value().same_as(ReplicationPlaceholder())) continue;
if (split->source->source.as<Var>().value().same_as(
ReplicationPlaceholder()))
continue;
split_without_rep.push_back(split);
}
return MakeFlattenedExpression(split_without_rep);
}
FragmentNode::FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
FragmentNode::FragmentNode(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size) {
input_size_ = input_size;
replicate_size_ = replicate_size;
......@@ -230,9 +249,11 @@ FragmentNode::FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_i
UpdateAnalyzer(&analyzer);
forward_thread_ = analyzer.Simplify(forward_thread);
if (forward_index.empty()) {
forward_index = {infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
forward_index = {
infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
}
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); });
forward_index_ = forward_index.Map(
[&](const PrimExpr &e) { return analyzer.Simplify(e); });
}
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
......@@ -250,24 +271,28 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
replicate_size = thread_replicate->dom->extent;
vmap.Set(thread_replicate->var, ReplicationPlaceholder());
}
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, replicate_size);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
data_ = std::move(n);
}
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size, Optional<Var> replicate_var) {
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var) {
if (replicate_var.defined()) {
forward_thread =
Substitute(forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
forward_thread = Substitute(
forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
}
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, replicate_size);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
data_ = std::move(n);
}
void FragmentNode::VisitAttrs(tvm::AttrVisitor* v) {
void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v);
v->Visit("forward_thread", &forward_thread_);
v->Visit("replicate_size", &replicate_size_);
......@@ -282,14 +307,15 @@ PrimExpr FragmentNode::ThreadExtent() const {
return ist.max();
}
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr>& vars,
const Optional<PrimExpr>& rep_var) const {
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
const Optional<PrimExpr> &rep_var) const {
Map<Var, PrimExpr> vmap;
ICHECK_EQ(vars.size(), InputDim());
for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]);
}
if (rep_var.defined()) vmap.Set(ReplicationPlaceholder(), rep_var.value());
if (rep_var.defined())
vmap.Set(ReplicationPlaceholder(), rep_var.value());
return Substitute(forward_thread_, vmap);
}
......@@ -299,7 +325,8 @@ Layout FragmentNode::Inverse() const {
input_size_copy.push_back(ReplicateExtent());
auto forward_index_copy = forward_index_;
forward_index_copy.push_back(
Substitute(forward_thread_, {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
Substitute(forward_thread_,
{{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
auto fwd = Layout(input_size_copy, forward_index_copy);
auto bwd = fwd->Inverse();
return bwd;
......@@ -311,8 +338,9 @@ Fragment FragmentNode::CondenseReplicateVar() const {
input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
PrimExpr new_forward_thread;
IterVar new_thread_replicate;
std::tie(new_forward_thread, new_thread_replicate) = CompressIterator(
forward_thread_, ToIterVars(input_iters), ReplicationPlaceholder(), &analyzer);
std::tie(new_forward_thread, new_thread_replicate) =
CompressIterator(forward_thread_, ToIterVars(input_iters),
ReplicationPlaceholder(), &analyzer);
return Fragment(input_size_, forward_index_, new_forward_thread,
new_thread_replicate->dom->extent, new_thread_replicate->var);
}
......@@ -330,12 +358,14 @@ void FragmentNode::DebugOutput() const {
LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_;
}
bool LayoutNode::SEqualReduce(const LayoutNode* other, SEqualReducer equal) const {
bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode* other, SEqualReducer equal) const {
bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
......@@ -346,7 +376,7 @@ bool FragmentNode::SEqualReduce(const FragmentNode* other, SEqualReducer equal)
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
});
......@@ -366,36 +396,37 @@ TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
return layout->GetForwardIndex();
});
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Fragment(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size").set_body_typed([](Fragment fragment) {
return fragment->ThreadExtent();
});
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
.set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });
TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
return fragment->GetForwardThread();
});
TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
.set_body_typed([](Fragment fragment, Array<PrimExpr> repeats, bool repeat_on_thread,
bool lower_dim_first) {
.set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
bool repeat_on_thread, bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
});
TVM_REGISTER_GLOBAL("tl.Fragment_replicate").set_body_typed([](Fragment fragment, int repeats) {
return fragment->Replicate(repeats);
});
TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
.set_body_typed([](Fragment fragment, int repeats) {
return fragment->Replicate(repeats);
});
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var").set_body_typed([](Fragment fragment) {
return fragment->CondenseReplicateVar();
});
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
.set_body_typed([](Fragment fragment) {
return fragment->CondenseReplicateVar();
});
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
.set_body_typed([](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, element_size, 0);
});
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -20,7 +20,7 @@ class Layout;
class Fragment;
class LayoutNode : public Object {
public:
public:
LayoutNode() = default;
LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
......@@ -34,21 +34,21 @@ class LayoutNode : public Object {
Array<PrimExpr> GetForwardIndex() const { return forward_index_; }
virtual Array<PrimExpr> Forward(const Array<PrimExpr>& vars) const;
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
virtual Layout Inverse() const;
virtual void DebugOutput() const;
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode* other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor* v);
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor *v);
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
protected:
protected:
virtual Map<Var, Range> getVarMap() const;
void UpdateAnalyzer(arith::Analyzer* analyzer) const;
void UpdateAnalyzer(arith::Analyzer *analyzer) const;
Array<PrimExpr> forward_index_;
Array<PrimExpr> input_size_;
};
......@@ -57,7 +57,7 @@ class LayoutNode : public Object {
* \brief Layout reference class.
*/
class Layout : public ObjectRef {
public:
public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
......@@ -65,10 +65,10 @@ class Layout : public ObjectRef {
};
class FragmentNode : public LayoutNode {
public:
public:
FragmentNode() = default;
FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, PrimExpr forward_thread,
PrimExpr replicate_size);
FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size);
PrimExpr GetForwardThread() const { return forward_thread_; }
......@@ -78,9 +78,10 @@ class FragmentNode : public LayoutNode {
PrimExpr ReplicateExtent() const { return replicate_size_; };
PrimExpr ForwardThread(const Array<PrimExpr>& vars, const Optional<PrimExpr>& rep_var) const;
PrimExpr ForwardThread(const Array<PrimExpr> &vars,
const Optional<PrimExpr> &rep_var) const;
Fragment Repeat(const Array<PrimExpr>& repeats, bool repeat_on_thread,
Fragment Repeat(const Array<PrimExpr> &repeats, bool repeat_on_thread,
bool lower_dim_first = true) const;
Fragment Replicate(int repeats) const;
......@@ -91,12 +92,12 @@ class FragmentNode : public LayoutNode {
void DebugOutput() const final;
void VisitAttrs(tvm::AttrVisitor* v);
bool SEqualReduce(const FragmentNode* other, SEqualReducer equal) const;
static constexpr const char* _type_key = "tl.Fragment";
void VisitAttrs(tvm::AttrVisitor *v);
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
protected:
protected:
Map<Var, Range> getVarMap() const final;
PrimExpr forward_thread_;
PrimExpr replicate_size_;
......@@ -106,12 +107,13 @@ class FragmentNode : public LayoutNode {
* \brief Fragment reference class.
*/
class Fragment : public Layout {
public:
public:
TVM_DLL Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
PrimExpr forward_thread, IterVar thread_replicate);
TVM_DLL Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size, Optional<Var> replicate_var);
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
};
......@@ -121,41 +123,52 @@ Var ReplicationPlaceholder();
Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed();
Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n,
Fragment makeGemmFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m,
const int warp_n, const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, const int element_size);
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, bool transposed = false);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size);
Fragment makeGemmFragmentB(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, bool transposed = false);
// Default Memory Layout
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kfactor);
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m,
const int warp_n, const int element_size);
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, int kfactor);
Layout makeGemmABLayout(int stride, int continuous, int element_size,
int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor);
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor);
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
namespace attr {
// BlockAttr, Containing the layout for all the buffers in the block
constexpr const char* kLayoutMap = "layout_map";
} // namespace attr
constexpr const char *kLayoutMap = "layout_map";
} // namespace attr
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_LAYOUT_H_
#endif // TVM_TL_LAYOUT_LAYOUT_H_
......@@ -34,20 +34,23 @@ PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const {
return low + high * base;
}
bool SwizzlePattern::operator==(const SwizzlePattern& other) const {
return std::tie(base_, bits_, shift_) == std::tie(other.base_, other.bits_, other.shift_);
bool SwizzlePattern::operator==(const SwizzlePattern &other) const {
return std::tie(base_, bits_, shift_) ==
std::tie(other.base_, other.bits_, other.shift_);
}
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern)
: pattern_(pattern) {
input_size_ = input_size;
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); });
forward_index_ = forward_index.Map(
[&](const PrimExpr &e) { return analyzer.Simplify(e); });
}
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr>& vars) const {
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr> &vars) const {
auto expr_list = LayoutNode::Forward(vars);
auto expr = expr_list.back();
expr_list.pop_back();
......@@ -57,8 +60,8 @@ Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr>& vars) const {
void SwizzledLayoutNode::DebugOutput() const {
LayoutNode::DebugOutput();
std::cout << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " "
<< pattern_.Shift();
std::cout << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits()
<< " " << pattern_.Shift();
}
Layout SwizzledLayoutNode::Inverse() const {
......@@ -66,7 +69,8 @@ Layout SwizzledLayoutNode::Inverse() const {
return {};
}
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
Map<Var, PrimExpr> vmap;
Array<PrimExpr> input_size;
......@@ -75,26 +79,32 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forwa
CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent);
}
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
data_ = std::move(n);
}
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
data_ = std::move(n);
}
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor* v) { LayoutNode::VisitAttrs(v); }
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor *v) {
LayoutNode::VisitAttrs(v);
}
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode* other, SEqualReducer equal) const {
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_) && pattern_ == other->pattern_;
equal(this->forward_index_, other->forward_index_) &&
pattern_ == other->pattern_;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -19,16 +19,16 @@ namespace tl {
* \brief Swizzle pattern
*/
class SwizzlePattern {
public:
public:
SwizzlePattern() = default;
SwizzlePattern(int bits, int base, int shift);
PrimExpr swizzle(PrimExpr expr) const;
int Bits() const { return bits_; }
int Base() const { return base_; }
int Shift() const { return shift_; }
bool operator==(const SwizzlePattern& other) const;
bool operator==(const SwizzlePattern &other) const;
private:
private:
int bits_;
int base_;
int shift_;
......@@ -38,21 +38,21 @@ class SwizzlePattern {
* \brief Layout with swizzle
*/
class SwizzledLayoutNode : public LayoutNode {
public:
public:
SwizzledLayoutNode() = default;
SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern);
Array<PrimExpr> Forward(const Array<PrimExpr>& vars) const final;
Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const final;
Layout Inverse() const final;
void DebugOutput() const final;
static constexpr const char* _type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode* other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor* v);
static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor *v);
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
private:
private:
SwizzlePattern pattern_;
};
......@@ -60,16 +60,16 @@ class SwizzledLayoutNode : public LayoutNode {
* \brief SwizzledLayout reference class.
*/
class SwizzledLayout : public Layout {
public:
TVM_DLL SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern);
public:
TVM_DLL SwizzledLayout(Array<IterVar> forward_var,
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
};
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
\ No newline at end of file
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
\ No newline at end of file
......@@ -18,9 +18,9 @@ namespace tl {
using namespace tir;
using namespace arith;
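// Illustrative expectations for CanProveDivisible (constant path below,
// presumably via constant folding): CanProveDivisible(8, 4) -> true,
// CanProveDivisible(8, 3) -> false, and any rhs == 0 -> false.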
bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
bool CanProveDivisible(const PrimExpr &lhs, const PrimExpr &rhs) {
const auto *clhs = lhs.as<IntImmNode>();
const auto *crhs = rhs.as<IntImmNode>();
if (crhs && crhs->value == 0) {
return false;
} else if (clhs && crhs) {
......@@ -33,20 +33,22 @@ bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
/*!
* \brief Collector that collects the outgoing split reference of each IterMark.
*
* These out-going splits can then be used to check if the iterators are independent.
 * These outgoing splits can then be used to check if the iterators are
* independent.
*/
class IterMarkSplitCollector {
public:
public:
// mark all IterMarks that are visited.
std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
// each iter mark to its outgoing splits that are referenced.
std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash,
ObjectPtrEqual>
mark2splits_;
/*!
* \brief Collect all mark2splits recursively from indices.
* \param indices The iterator of interest.
*/
void Collect(const Array<IterSumExpr>& indices) {
void Collect(const Array<IterSumExpr> &indices) {
for (IterSumExpr sum_expr : indices) {
for (IterSplitExpr split : sum_expr->args) {
this->CollectInternal(split->source);
......@@ -55,10 +57,11 @@ class IterMarkSplitCollector {
}
}
void CollectInternal(const IterMark& mark) {
if (visited_.count(mark)) return;
void CollectInternal(const IterMark &mark) {
if (visited_.count(mark))
return;
visited_.insert(mark);
if (auto* op = mark->source.as<IterSumExprNode>()) {
if (auto *op = mark->source.as<IterSumExprNode>()) {
for (IterSplitExpr split : op->args) {
this->CollectInternal(split->source);
mark2splits_[split->source].push_back(split);
......@@ -67,9 +70,9 @@ class IterMarkSplitCollector {
}
};
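// Usage sketch (editorial; assumes `indices` is an Array<IterSumExpr> of
// normalized iterator maps):
//   IterMarkSplitCollector collector;
//   collector.Collect(indices);
//   // collector.mark2splits_[mark] now lists every outgoing split of `mark`.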
Array<IterSplitExpr> get_unused_iters(const IterMark& mark,
const std::vector<IterSplitExpr>& splits,
Analyzer* analyzer) {
Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
const std::vector<IterSplitExpr> &splits,
Analyzer *analyzer) {
PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
std::vector<bool> used(splits.size(), false);
std::vector<IterSplitExpr> results;
......@@ -78,20 +81,25 @@ Array<IterSplitExpr> get_unused_iters(const IterMark& mark,
size_t j = 0;
size_t lowest = splits.size();
for (; j < splits.size(); ++j) {
if (used[j]) continue;
if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) {
if (used[j])
continue;
if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor,
expected_lower_factor)) {
break;
}
if (lowest == splits.size() ||
CanProveDivisible(splits[lowest]->lower_factor, splits[j]->lower_factor)) {
CanProveDivisible(splits[lowest]->lower_factor,
splits[j]->lower_factor)) {
lowest = j;
}
}
if (j == splits.size()) {
ICHECK(lowest != splits.size());
ICHECK(CanProveDivisible(splits[lowest]->lower_factor, expected_lower_factor));
results.emplace_back(mark, expected_lower_factor,
FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1);
ICHECK(CanProveDivisible(splits[lowest]->lower_factor,
expected_lower_factor));
results.emplace_back(
mark, expected_lower_factor,
FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1);
expected_lower_factor = splits[lowest]->lower_factor;
} else {
used[j] = true;
......@@ -99,36 +107,40 @@ Array<IterSplitExpr> get_unused_iters(const IterMark& mark,
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}
}
bool match_full_iter = analyzer->CanProveEqual(expected_lower_factor, mark->extent);
bool match_full_iter =
analyzer->CanProveEqual(expected_lower_factor, mark->extent);
if (!match_full_iter) {
results.emplace_back(mark, expected_lower_factor, FloorDiv(mark->extent, expected_lower_factor),
1);
results.emplace_back(mark, expected_lower_factor,
FloorDiv(mark->extent, expected_lower_factor), 1);
}
return results;
}
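// Example (editorial): for a mark x with extent 4 where only the split x // 2
// is referenced, the uncovered low bits are returned as a split with
// lower_factor 1 and extent 2, i.e. the (x % 2) part of the iteration space.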
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs,
const Array<IterVar> input_iters, Analyzer* analyzer) {
auto iter_sum = exprs.Map(
[&](const auto& e) { return NormalizeToIterSum(e, ToVMap(input_iters), analyzer); });
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
const Array<IterVar> input_iters,
Analyzer *analyzer) {
auto iter_sum = exprs.Map([&](const auto &e) {
return NormalizeToIterSum(e, ToVMap(input_iters), analyzer);
});
IterMarkSplitCollector collector;
collector.Collect(iter_sum);
Array<IterSplitExpr> results;
for (const IterMark& mark : collector.visited_) {
for (const IterMark &mark : collector.visited_) {
ICHECK(mark->source.as<Var>()) << "Not a normalized iterator: " << mark;
}
for (const IterVar& iter : input_iters) {
for (const IterVar &iter : input_iters) {
IterMark iv_mark;
for (const IterMark& mark : collector.visited_) {
for (const IterMark &mark : collector.visited_) {
if (mark->source.as<Var>().same_as(iter->var)) {
iv_mark = mark;
break;
}
}
if (iv_mark.defined()) {
auto splits = get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
auto splits =
get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
// Put the small axis last
results.insert(results.end(), splits.rbegin(), splits.rend());
} else if (!is_one(iter->dom->extent)) {
......@@ -140,12 +152,12 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs,
return results;
}
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits) {
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr> &splits) {
Array<arith::IterSplitExpr> lists;
PrimExpr scale = 1;
for (int i = splits.size() - 1; i >= 0; i--) {
auto scaled_split =
arith::IterSplitExpr(splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale);
auto scaled_split = arith::IterSplitExpr(
splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale);
lists.push_back(scaled_split);
scale *= splits[i]->extent;
}
......@@ -153,45 +165,47 @@ PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits) {
}
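// Example (editorial): for splits {a (extent 4), b (extent 2)}, scales are
// assigned right-to-left as b -> 1 and a -> 2, so the flattened expression is
// a * 2 + b, a row-major linearization of the split iteration space.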
class IterSumMutator {
public:
IterSumMutator(const Map<IterSplitExpr, IterSplitExpr>& replace_map)
public:
IterSumMutator(const Map<IterSplitExpr, IterSplitExpr> &replace_map)
: replace_map_(replace_map) {}
// override the original mutate function.
IterSumExpr Mutate(const IterSumExpr& iter_sum) {
IterSumExpr Mutate(const IterSumExpr &iter_sum) {
Array<IterSplitExpr> args;
for (const auto& split : iter_sum->args) {
for (const auto &split : iter_sum->args) {
if (replace_map_.count(split)) {
args.push_back(replace_map_[split]);
} else {
auto split_ =
IterSplitExpr(Mutate(split->source), split->lower_factor, split->extent, split->scale);
auto split_ = IterSplitExpr(Mutate(split->source), split->lower_factor,
split->extent, split->scale);
args.push_back(split_);
}
}
return IterSumExpr(args, iter_sum->base);
}
IterMark Mutate(const IterMark& mark) {
if (auto* op = mark->source.as<IterSumExprNode>()) {
IterMark Mutate(const IterMark &mark) {
if (auto *op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
} else {
return mark;
}
}
private:
private:
Map<IterSplitExpr, IterSplitExpr> replace_map_;
};
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
const Array<IterVar> input_iters, const Var& var,
arith::Analyzer* analyzer) {
auto iter_sum = arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer);
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr &expr,
const Array<IterVar> input_iters,
const Var &var,
arith::Analyzer *analyzer) {
auto iter_sum =
arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer);
IterMarkSplitCollector collector;
collector.Collect({iter_sum});
IterMark mark;
for (const IterMark& m : collector.visited_) {
for (const IterMark &m : collector.visited_) {
ICHECK(m->source.as<Var>()) << "Not a normalized iterator: " << m;
if (m->source.as<Var>().value().same_as(var)) {
mark = m;
......@@ -204,7 +218,7 @@ std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
}
PrimExpr extent = 1;
for (const auto& split : splits) {
for (const auto &split : splits) {
extent *= split->extent;
}
extent = analyzer->Simplify(extent);
......@@ -214,33 +228,35 @@ std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
auto new_mark = IterMark(new_var, extent);
PrimExpr scale = 1;
Map<IterSplitExpr, IterSplitExpr> replace_map;
for (const auto& split : splits) {
auto rescaled = arith::IterSplitExpr(new_mark, scale, split->extent, split->scale);
for (const auto &split : splits) {
auto rescaled =
arith::IterSplitExpr(new_mark, scale, split->extent, split->scale);
replace_map.Set(split, rescaled);
scale *= split->extent;
}
IterSumMutator mutator(replace_map);
PrimExpr replaced = analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum)));
PrimExpr replaced =
    analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum)));
return {replaced, new_iter_var};
}
Array<IterVar> ToIterVars(const Map<Var, Range>& vmap) {
Array<IterVar> ToIterVars(const Map<Var, Range> &vmap) {
Array<IterVar> result;
for (const auto& [var, range] : vmap) {
for (const auto &[var, range] : vmap) {
result.push_back(IterVar(range, var, IterVarType::kDataPar));
}
return result;
}
Map<Var, Range> ToVMap(const Array<IterVar>& ivs) {
Map<Var, Range> ToVMap(const Array<IterVar> &ivs) {
Map<Var, Range> result;
for (const auto& iv : ivs) {
for (const auto &iv : ivs) {
result.Set(iv->var, iv->dom);
}
return result;
}
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -23,38 +23,42 @@ using namespace tir;
* If the expr is (x // 2) and x is in Range(4),
 * then the result should be (x % 2)
*/
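// Usage sketch (editorial, mirrors the example above):
//   arith::Analyzer an;
//   IterVar x(Range(0, 4), Var("x"), IterVarType::kDataPar);
//   auto unused = DivideUnusedIterators({floordiv(x->var, 2)}, {x}, &an);
//   // `unused` holds the (x % 2) split: source x, lower_factor 1, extent 2.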
Array<arith::IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs,
const Array<IterVar> input_iters,
arith::Analyzer* analyzer);
Array<arith::IterSplitExpr>
DivideUnusedIterators(const Array<PrimExpr> &exprs,
const Array<IterVar> input_iters,
arith::Analyzer *analyzer);
/*!
* \brief Compress the iterator var, remove the unused part of the var not present in the expr
 * \brief Compress the iterator var, removing the parts of the var that are
 * not present in the expr
 *
 * Returns the compressed IterVar as well as the updated iter sum expression.
*/
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
const Array<IterVar> input_iters, const Var& var,
arith::Analyzer* analyzer);
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr &expr,
const Array<IterVar> input_iters,
const Var &var,
arith::Analyzer *analyzer);
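// Usage sketch (editorial): compress x (extent 8) when only x // 2 is used:
//   auto [new_expr, new_iv] =
//       CompressIterator(floordiv(x->var, 2), {x}, x->var, &an);
//   // new_iv would have extent 4 and new_expr is rewritten in terms of new_iv.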
/*!
* \brief Convert the iter splits returned by DivideUnusedIterators into flattened expression
* \brief Convert the iter splits returned by DivideUnusedIterators into
* flattened expression
*
*/
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits);
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr> &splits);
/*!
* \brief Convert an Array of IterVar to a Map object
*
*/
Map<Var, Range> ToVMap(const Array<IterVar>& ivs);
Map<Var, Range> ToVMap(const Array<IterVar> &ivs);
/*!
* \brief Convert a Map object to an Array of IterVar
*
*/
Array<IterVar> ToIterVars(const Map<Var, Range>& vmap);
Array<IterVar> ToIterVars(const Map<Var, Range> &vmap);
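// Editorial note: ToVMap and ToIterVars are near-inverses; ToIterVars rebuilds
// kDataPar IterVars over the same Var -> Range bindings, so a round trip
// preserves the bindings (though Map iteration order may differ).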
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_UTILS_H_
#endif // TVM_TL_LAYOUT_UTILS_H_
......@@ -13,78 +13,92 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../target/cuda.h"
#include "../target/utils.h"
namespace tvm {
namespace tl {
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op& OpName() { \
static const Op& op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
TIR_DEFINE_TL_BUILTIN(CreateListofMBarrierOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(CreateTMADescriptorOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(CreateTMAIm2ColDescriptorOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(GetMBarrierOp)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(TMALoadOp).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMALoadIm2ColOp).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMALoadIm2ColOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMAStoreOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(MBarrierWaitParity)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(MBarrierExpectTX)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(LDMatrixOp)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(STMatrixOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SyncThreadsPartialOp)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(WaitWgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(WaitWgmma).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -18,21 +18,23 @@ namespace tl {
/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
* CuTensorMap* CreateTMADescriptorOp(data_type, rank, global_addr, global_shape...,
* global_stride..., smem_box..., smem_stride..., interleave, swizzle, l2_promotion, oob_fill)
* CuTensorMap* CreateTMADescriptorOp(data_type, rank, global_addr,
* global_shape..., global_stride..., smem_box..., smem_stride..., interleave,
* swizzle, l2_promotion, oob_fill)
*
*/
const Op& CreateTMADescriptorOp();
const Op &CreateTMADescriptorOp();
/*!
* \brief tvm intrinsics for TMADescriptor creation for image to column load
*
* CuTensorMap* CreateTMAIm2ColDescriptorOp(data_type, rank, global_addr, global_shape...,
 * global_stride..., elem_stride..., lower_corner..., upper_corner..., smem_box_pixel, smem_box_channel,
* interleave, swizzle, l2_promotion, oob_fill)
* CuTensorMap* CreateTMAIm2ColDescriptorOp(data_type, rank, global_addr,
* global_shape..., global_stride..., elem_stride..., lower_corner...,
 * upper_corner..., smem_box_pixel, smem_box_channel, interleave, swizzle,
* l2_promotion, oob_fill)
*
*/
const Op& CreateTMAIm2ColDescriptorOp();
const Op &CreateTMAIm2ColDescriptorOp();
/*!
* \brief Create a list of mbarrier with num_threads
......@@ -40,7 +42,7 @@ const Op& CreateTMAIm2ColDescriptorOp();
* GetMBarrier(num_threads0, num_threads1, ...)
*
*/
const Op& CreateListofMBarrierOp();
const Op &CreateListofMBarrierOp();
/*!
* \brief Get the mbarrier with barrier_id
......@@ -48,31 +50,35 @@ const Op& CreateListofMBarrierOp();
* int64_t* GetMBarrier(barrier_id)
*
*/
const Op& GetMBarrierOp();
const Op &GetMBarrierOp();
/*!
* \brief tvm intrinsics for loading data from global tensor descriptor to shared memory
* \brief tvm intrinsics for loading data from global tensor descriptor to
* shared memory
*
* TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
*
*/
const Op& TMALoadOp();
const Op &TMALoadOp();
/*!
* \brief tvm intrinsics for loading image from global tensor to columns in shared memory
* \brief tvm intrinsics for loading image from global tensor to columns in
* shared memory
*
* TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ..., image_offset, ...)
* TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...,
* image_offset, ...)
*
*/
const Op& TMALoadIm2ColOp();
const Op &TMALoadIm2ColOp();
/*!
* \brief tvm intrinsics for storing data from shared memory to global tensor descriptor
* \brief tvm intrinsics for storing data from shared memory to global tensor
* descriptor
*
* TMAStoreOp(descriptor, smem_data, coord_0, coord_1, ...)
*
*/
const Op& TMAStoreOp();
const Op &TMAStoreOp();
/*!
* \brief tvm intrinsics for mbarrier wait with parity bit
......@@ -80,7 +86,7 @@ const Op& TMAStoreOp();
* MBarrierWaitParity(mbarrier, parity)
*
*/
const Op& MBarrierWaitParity();
const Op &MBarrierWaitParity();
/*!
* \brief tvm intrinsics for mbarrier expect tx
......@@ -88,7 +94,7 @@ const Op& MBarrierWaitParity();
* MBarrierExpectTX(mbarrier, transaction_bytes)
*
*/
const Op& MBarrierExpectTX();
const Op &MBarrierExpectTX();
/*!
* \brief tvm intrinsics for ldmatrix
......@@ -96,7 +102,7 @@ const Op& MBarrierExpectTX();
* LDMatrixOp(transposed, num, shared_addr, local_addr)
*
*/
const Op& LDMatrixOp();
const Op &LDMatrixOp();
/*!
* \brief tvm intrinsics for stmatrix
......@@ -104,7 +110,7 @@ const Op& LDMatrixOp();
 * STMatrixOp(transposed, num, shared_addr, int32_values...)
*
*/
const Op& STMatrixOp();
const Op &STMatrixOp();
/*!
* \brief Pack two b16 value into a b32 value
......@@ -112,7 +118,7 @@ const Op& STMatrixOp();
* int32 PackB16Op(b16_value, b16_value)
*
*/
const Op& PackB16Op();
const Op &PackB16Op();
/*!
* \brief Similar to __syncthreads(), but can be used to sync partial threads
......@@ -120,7 +126,7 @@ const Op& PackB16Op();
* SyncThreadsPartialOp(num_partial_threads or mbarrier)
*
*/
const Op& SyncThreadsPartialOp();
const Op &SyncThreadsPartialOp();
/*!
* \brief Issue a shared memory fence for async operations
......@@ -128,7 +134,7 @@ const Op& SyncThreadsPartialOp();
* FenceProxyAsync()
*
*/
const Op& FenceProxyAsyncOp();
const Op &FenceProxyAsyncOp();
/*!
 * \brief Set register hint for warp-specialized branches
......@@ -136,7 +142,7 @@ const Op& FenceProxyAsyncOp();
* SetMaxNRegInc(num_reg, is_inc)
*
*/
const Op& SetMaxNReg();
const Op &SetMaxNReg();
/*!
* \brief Wait the previous wgmma to finish
......@@ -144,7 +150,7 @@ const Op& SetMaxNReg();
* WaitWgmma(num_mma)
*
*/
const Op& WaitWgmma();
const Op &WaitWgmma();
/*!
* \brief tvm intrinsic for amd matrix core mfma instructions.
......@@ -194,7 +200,7 @@ TVM_DLL const Op &tvm_rdna_wmma();
*/
TVM_DLL const Op &tvm_rdna_wmma_store();
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_BUILTIN_H_
\ No newline at end of file
#endif // TVM_TL_OP_BUILTIN_H_
\ No newline at end of file
......@@ -13,8 +13,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../target/cuda.h"
#include "../target/utils.h"
#include "builtin.h"
namespace tvm {
......@@ -26,56 +26,56 @@ static int to_CUtensorMapDataType(DataType dtype) {
CUtensorMapDataType tp;
if (dtype.is_float()) {
switch (dtype.bits()) {
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
}
} else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_int()) {
switch (dtype.bits()) {
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_INT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_INT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_INT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_INT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
}
} else if (dtype.is_uint()) {
switch (dtype.bits()) {
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
}
} else {
ICHECK(0) << dtype;
......@@ -83,18 +83,20 @@ static int to_CUtensorMapDataType(DataType dtype) {
return static_cast<int>(tp);
}
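// Examples (from the switch above): DataType::Float(16) maps to
// CU_TENSOR_MAP_DATA_TYPE_FLOAT16, DataType::Int(16) maps to
// CU_TENSOR_MAP_DATA_TYPE_UINT16 (there is no signed 16-bit entry), and any
// unsupported width trips the ICHECK.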
template <typename T>
static Array<T> ReverseArray(Array<T> array) {
template <typename T> static Array<T> ReverseArray(Array<T> array) {
return Array<T>{array.rbegin(), array.rend()};
}
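// e.g. ReverseArray({m, n, k}) -> {k, n, m}; used below because Buffer shapes
// are outermost-first while the TMA descriptor presumably expects
// innermost-first ordering.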
Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
if (!TargetIsHopper(T.target)) return Stmt();
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (!TargetIsHopper(T.target))
return Stmt();
bool is_load;
if (src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared")) {
if (src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared")) {
// Use the Hopper TMA bulk copy instructions
is_load = true;
} else if (dst.scope() == "global" && (src.scope() == "shared.dyn" || src.scope() == "shared")) {
} else if (dst.scope() == "global" &&
(src.scope() == "shared.dyn" || src.scope() == "shared")) {
is_load = false;
} else {
return Stmt();
......@@ -107,7 +109,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
shared_tensor = T.buffer_remap[shared_tensor];
}
if (T.layout_map.count(global_tensor)) {
ICHECK(T.layout_map.count(global_tensor) == 0) << "Cannot support global layout.";
ICHECK(T.layout_map.count(global_tensor) == 0)
<< "Cannot support global layout.";
}
TMADesc desc;
......@@ -124,7 +127,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
auto global_range = is_load ? src_range : dst_range;
desc.global_addr = global_tensor->data;
desc.global_shape = ReverseArray(global_tensor->shape);
Array<PrimExpr> global_coords = ReverseArray(global_range.Map([](Range r) { return r->min; }));
Array<PrimExpr> global_coords =
ReverseArray(global_range.Map([](Range r) { return r->min; }));
if (!global_tensor->strides.empty()) {
desc.global_stride = ReverseArray(global_tensor->strides);
} else {
......@@ -139,11 +143,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride =
desc.global_stride.Map([&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
desc.global_stride = desc.global_stride.Map(
[&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
// Smem Box
desc.smem_box = ReverseArray(global_range.Map([](Range r) { return r->extent; }));
desc.smem_box =
ReverseArray(global_range.Map([](Range r) { return r->extent; }));
desc.smem_stride = Array<PrimExpr>(desc.rank, PrimExpr(1));
// L2 & OOB
......@@ -159,12 +164,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(
shared_layout,
makeFullBankSwizzleLayout(*stride, *continuous, shared_tensor->dtype.bits()))) {
makeFullBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else {
ICHECK(0) << "Cannot detect TMA layout.";
......@@ -182,12 +189,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK((*inner_box_dim) % instruction_dim == 0);
desc.smem_box.Set(0, PrimExpr(instruction_dim));
Call create_descriptor = Call(DataType::Handle(), CreateTMADescriptorOp(), desc.EncodeCallArgs());
Call create_descriptor =
Call(DataType::Handle(), CreateTMADescriptorOp(), desc.EncodeCallArgs());
Array<PrimExpr> args;
args.reserve(desc.rank + 3);
args.push_back(create_descriptor);
if (is_load) args.push_back(0); // mbarrier id placeholder
if (is_load)
args.push_back(0); // mbarrier id placeholder
auto op = is_load ? TMALoadOp() : TMAStoreOp();
Stmt tma_copy;
......@@ -196,18 +205,22 @@ Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
Var loop_var("i");
int loop_extent = (*inner_box_dim) / instruction_dim;
PrimExpr total_elements = 1;
for (auto e : desc.smem_box) total_elements *= e;
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
total_elements * loop_var, total_elements);
for (auto e : desc.smem_box)
total_elements *= e;
PrimExpr shared_addr =
shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
total_elements * loop_var, total_elements);
args.push_back(shared_addr);
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords) args.push_back(coord);
for (auto coord : global_coords)
args.push_back(coord);
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
} else {
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1);
args.push_back(shared_addr);
for (auto coord : global_coords) args.push_back(coord);
for (auto coord : global_coords)
args.push_back(coord);
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy);
......@@ -222,10 +235,14 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
args.push_back(data_type);
args.push_back(static_cast<int>(rank));
args.push_back(global_addr);
for (auto e : global_shape) args.push_back(e);
for (auto e : global_stride) args.push_back(e);
for (auto e : smem_box) args.push_back(e);
for (auto e : smem_stride) args.push_back(e);
for (auto e : global_shape)
args.push_back(e);
for (auto e : global_stride)
args.push_back(e);
for (auto e : smem_box)
args.push_back(e);
for (auto e : smem_stride)
args.push_back(e);
args.push_back(interleave);
args.push_back(swizzle);
args.push_back(l2_promotion);
......@@ -247,9 +264,11 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
padding = args[7].as<IntImm>().value()->value;
}
Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared"));
ICHECK(src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared"));
ICHECK(src->shape.size() == 4);
ICHECK(dst->shape.size() == 2);
ICHECK(src->dtype == dst->dtype);
......@@ -277,7 +296,8 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const
// The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { return e * src->dtype.bytes(); });
desc.global_stride = desc.global_stride.Map(
[&](PrimExpr e) { return e * src->dtype.bytes(); });
desc.elem_stride = {1, stride, stride, 1};
desc.lower_corner = {-padding, -padding};
desc.upper_corner = {-padding, -padding};
......@@ -294,50 +314,70 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const
auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout,
makeHalfBankSwizzleLayout(*stride, *continuous, dst->dtype.bits()))) {
makeHalfBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else {
ICHECK(0) << "Cannot detect TMA layout.";
}
}
Call create_desc = Call(DataType::Handle(), CreateTMAIm2ColDescriptorOp(), desc.EncodeCallArgs());
Call create_desc = Call(DataType::Handle(), CreateTMAIm2ColDescriptorOp(),
desc.EncodeCallArgs());
Array<PrimExpr> global_coords; // c, w, h, n
Array<PrimExpr> image_offset; // w, h
Array<PrimExpr> global_coords; // c, w, h, n
Array<PrimExpr> image_offset; // w, h
global_coords.reserve(desc.rank);
ICHECK(analyzer->CanProveEqual(FloorMod(desc.global_shape[0], desc.smem_box_channel), 0))
ICHECK(analyzer->CanProveEqual(
FloorMod(desc.global_shape[0], desc.smem_box_channel), 0))
<< "Currently can only support divisible channel case";
global_coords.push_back(FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
global_coords.push_back(
FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back(
dilation * FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]), kernel));
image_offset.push_back(dilation *
FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0] * kernel));
PrimExpr h_dim = FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1, stride) + 1;
PrimExpr w_dim = FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1, stride) + 1;
global_coords.push_back(stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
dilation *
FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]),
kernel));
image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel,
desc.global_shape[0] * kernel));
PrimExpr h_dim =
FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
1;
PrimExpr w_dim =
FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
1;
global_coords.push_back(
stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
global_coords.push_back(
stride *
FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) -
padding);
global_coords.push_back(
stride * FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) - padding);
global_coords.push_back(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 1);
args.push_back(create_desc);
args.push_back(0); // mbar placeholder
args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr);
for (auto coord : global_coords) args.push_back(coord);
for (auto offset : image_offset) args.push_back(offset);
for (auto coord : global_coords)
args.push_back(coord);
for (auto offset : image_offset)
args.push_back(offset);
Stmt tma_copy =
IfThenElse(EQ(T.thread_var, 0), Evaluate(Call(DataType::Handle(), TMALoadIm2ColOp(), args)));
IfThenElse(EQ(T.thread_var, 0),
Evaluate(Call(DataType::Handle(), TMALoadIm2ColOp(), args)));
return tma_copy;
}
......@@ -348,11 +388,16 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
args.push_back(data_type);
args.push_back(static_cast<int>(rank));
args.push_back(global_addr);
for (auto e : global_shape) args.push_back(e);
for (auto e : global_stride) args.push_back(e);
for (auto e : elem_stride) args.push_back(e);
for (auto e : lower_corner) args.push_back(e);
for (auto e : upper_corner) args.push_back(e);
for (auto e : global_shape)
args.push_back(e);
for (auto e : global_stride)
args.push_back(e);
for (auto e : elem_stride)
args.push_back(e);
for (auto e : lower_corner)
args.push_back(e);
for (auto e : upper_corner)
args.push_back(e);
args.push_back(smem_box_pixel);
args.push_back(smem_box_channel);
args.push_back(interleave);
......@@ -365,7 +410,8 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(8)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -37,7 +37,7 @@ struct TMAIm2ColDesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride, elem_stride; // rank
Array<PrimExpr> lower_corner, upper_corner; // rank - 2
Array<PrimExpr> lower_corner, upper_corner; // rank - 2
PrimExpr global_addr;
int smem_box_pixel, smem_box_channel;
int swizzle;
......@@ -49,18 +49,18 @@ struct TMAIm2ColDesc {
};
class Conv2DIm2ColOp : public Operator {
public:
public:
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
static const Op& Get();
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get();
private:
private:
Buffer src, dst;
int stride, padding, dilation, kernel;
PrimExpr nhw_step, c_step;
};
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_BULK_COPY_H_
\ No newline at end of file
#endif // TVM_TL_OP_BULK_COPY_H_
\ No newline at end of file
......@@ -14,9 +14,9 @@
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "builtin.h"
namespace tvm {
......@@ -37,7 +37,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3){
if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]);
}
}
......@@ -46,17 +46,20 @@ Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent)) continue;
if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)});
idx++;
loop_vars.push_back({Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
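// Example (editorial): for src_range {[0, 1), [0, M), [0, N)} the unit-extent
// axis is skipped and two loop vars "i" (extent M) and "j" (extent N) are made.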
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar>& ivs, int src_dst) const {
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
......@@ -72,14 +75,16 @@ Array<PrimExpr> Copy::MakeIndices(const Array<IterVar>& ivs, int src_dst) const
return indices;
}
PrimExpr Copy::MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& ivs,
Array<PrimExpr> extents, int src_dst) const {
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent)) continue;
if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
......@@ -94,14 +99,16 @@ PrimExpr Copy::MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& iv
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++) cond = And(cond, cond_list[i]);
for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond;
}
}
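// Example (editorial): for a range [r0, r0 + e0) checked against extent E0,
// the guard r0 + i < E0 is emitted only when the analyzer cannot prove it
// under symbolic bounds; provable conditions are dropped.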
For Copy::MakeSIMTLoop(arith::Analyzer* analyzer) const {
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
for (const auto& iv : loop_vars) analyzer->Bind(iv->var, iv->dom);
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
......@@ -110,44 +117,52 @@ For Copy::MakeSIMTLoop(arith::Analyzer* analyzer) const {
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype) value = Cast(dst->dtype, value);
if (src_predicate.defined()) value = if_then_else(src_predicate, value, make_zero(dst->dtype));
if (src->dtype != dst->dtype)
value = Cast(dst->dtype, value);
if (src_predicate.defined())
value = if_then_else(src_predicate, value, make_zero(dst->dtype));
Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined()) body = IfThenElse(dst_predicate, body);
if (dst_predicate.defined())
body = IfThenElse(dst_predicate, body);
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()){
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, ForKind::kParallel, body, NullOpt, annotations);
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, NullOpt, annotations);
}
return Downcast<For>(body);
}
Stmt Copy::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined()) return ldsm_stmt;
if (ldsm_stmt.defined())
return ldsm_stmt;
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined()) return bulk_copy_stmt;
if (bulk_copy_stmt.defined())
return bulk_copy_stmt;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = std::make_unique<ParallelOp>(fused_loop);
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap}, InferLevel::kFree);
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout());
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop);
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
}
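// Editorial note on the fallback order above: try ldmatrix/stmatrix first,
// then Hopper TMA bulk copy, and finally a fused, thread-partitioned and
// vectorized SIMT loop guarded by the parallel op's predicate.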
Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
// Check buffer scope
bool is_ldmatrix;
if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
......@@ -162,21 +177,26 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// Check no predicates
Array<IterVar> loop_vars = MakeIterVars();
if (loop_vars.size() < 2) return Stmt();
for (const auto& iv : loop_vars) analyzer->Bind(iv->var, iv->dom);
if (loop_vars.size() < 2)
return Stmt();
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
if (src_predicate.defined() || dst_predicate.defined()) return Stmt();
if (src_predicate.defined() || dst_predicate.defined())
return Stmt();
Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src;
Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
Array<PrimExpr> local_indices_transformed = local_layout->Forward(local_indices);
Array<PrimExpr> local_indices_transformed =
local_layout->Forward(local_indices);
local_tensor = T.buffer_remap[local_tensor];
// currently only support 1-d case
if (local_layout->OutputDim() != 1) return Stmt();
if (local_layout->OutputDim() != 1)
return Stmt();
Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices;
......@@ -193,27 +213,32 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
PrimExpr matrix_8x8_thread_map =
makeGemmFragment8x8()->ForwardThread({FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr matrix_8x8_thread_map_trans = makeGemmFragment8x8Transposed()->ForwardThread(
PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr local_indices_flattened = local_tensor.OffsetOf(local_indices_transformed).back();
PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var, col_var->dom->extent, 2,
analyzer)) {
IndiceCanVectorize(local_indices_flattened, col_var->var,
col_var->dom->extent, 2, analyzer)) {
is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var, row_var->dom->extent, 2,
analyzer)) {
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
is_transposed = true;
} else {
return Stmt();
}
// Check shared_layout is 16 bytes continuous
if (shared_tensor->dtype.bytes() != 2) return Stmt();
PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var, loop_vars.back()->dom->extent, 8,
analyzer))
if (shared_tensor->dtype.bytes() != 2)
return Stmt();
PrimExpr flattened_indice =
shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer))
return Stmt();
// Only a full local_range is supported
......@@ -232,7 +257,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
num = 2;
Array<PrimExpr> args;
const Op& op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
const Op &op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);
......@@ -240,52 +265,60 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// if not transpose
// coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4)
// if transpose
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread % 8 / 2)
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread
// % 8 / 2)
Var local_iter("i");
Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed)
shared_coords =
inv->Forward({local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
else
shared_coords =
inv->Forward({local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove rep
if (shared_layout.defined()) shared_coords = shared_layout->Forward(shared_coords);
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove the replication dimension
if (shared_layout.defined())
shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1, shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
args.push_back(shared_addr);
if (is_ldmatrix) {
// Can only support the same dtype for ldmatrix
if (local_tensor->dtype != shared_tensor->dtype) return Stmt();
PrimExpr local_addr =
local_tensor.access_ptr(2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
if (local_tensor->dtype != shared_tensor->dtype)
return Stmt();
PrimExpr local_addr = local_tensor.access_ptr(
2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
args.push_back(local_addr);
} else {
for (int i = 0; i < num; i++) {
PrimExpr value0 = BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 = BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
PrimExpr value0 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 =
BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
if (local_tensor->dtype != shared_tensor->dtype) {
value0 = Cast(shared_tensor->dtype, value0);
value1 = Cast(shared_tensor->dtype, value1);
}
PrimExpr value_packed = Call(DataType::Int(32), PackB16Op(), {value0, value1});
PrimExpr value_packed =
Call(DataType::Int(32), PackB16Op(), {value0, value1});
args.push_back(value_packed);
}
}
auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node = For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
return for_node;
}
LayoutMap Copy::InferLayout(const LayoutInferArgs& T, InferLevel level) {
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// Use parallel op to infer the layout
if (par_op_ == nullptr) {
arith::Analyzer analyzer;
......@@ -303,7 +336,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
For Fill::MakeSIMTLoop(arith::Analyzer* analyzer) const {
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
int ndim = dst->shape.size();
Array<IterVar> loop_vars;
Array<PrimExpr> dst_indices;
......@@ -314,22 +347,26 @@ For Fill::MakeSIMTLoop(arith::Analyzer* analyzer) const {
}
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, ForKind::kParallel, body);
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body);
}
return Downcast<For>(body);
}
Stmt Fill::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.block_size, T.layout_map}, InferLevel::kFree);
par_op->InferLayout({T.target, T.block_size, T.layout_map}, InferLevel::kFree);
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout());
par_op->InferLayout({T.target, T.block_size, T.layout_map},
InferLevel::kFree);
par_op->InferLayout({T.target, T.block_size, T.layout_map},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop);
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
} else if (dst.scope() == "local") {
......@@ -339,16 +376,17 @@ Stmt Fill::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
}
}
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(3)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_REGISTER_TL_OP(Fill, fill)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -19,25 +19,25 @@ namespace tl {
using namespace tir;
class Copy : public Operator {
public:
public:
Copy(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op& Get();
static const Op &Get();
protected:
Stmt LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const;
protected:
Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
For MakeSIMTLoop(arith::Analyzer* analyzer) const;
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> MakeIndices(const Array<IterVar>& ivs, int src_dst) const;
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
PrimExpr MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& ivs,
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_;
......@@ -50,18 +50,18 @@ class Copy : public Operator {
};
class Fill : public Operator {
public:
public:
Fill(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
static const Op& Get();
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
static const Op &Get();
private:
For MakeSIMTLoop(arith::Analyzer* analyzer) const;
private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst;
PrimExpr value;
};
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_ELEM_H_
\ No newline at end of file
#endif // TVM_TL_OP_ELEM_H_
\ No newline at end of file
@@ -42,7 +42,7 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  if (args.size() > 9) {
    kPack = args[9].as<IntImm>().value()->value;

@@ -52,11 +52,13 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  }
}

std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps,
                                               Target target) const {
  int m_warp = 1, n_warp = 1;
  if (TargetIsHopper(target)) {
    ICHECK(num_warps % 4 == 0) << "Warp Group MMA requires 128*N threads.";
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
      m_warp = num_warps;
      ICHECK(this->M % num_warps == 0);
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
@@ -100,14 +102,15 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target) con
  return {m_warp, n_warp};
}

Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }
  ICHECK(T.block_size % warp_size == 0);
  auto [warp_m, warp_n] =
      ComputeWarpPartition(T.block_size / warp_size, T.target);

  std::stringstream ss;
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
@@ -137,19 +140,23 @@ Stmt Gemm::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
  return Evaluate(new_call);
}

LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
                                           *as_const_int(A->shape[1]), true,
                                           trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));

@@ -158,25 +165,31 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
                                         *as_const_int(B->shape[1]), false,
                                         trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
                                      *as_const_int(A->shape[1]),
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
                                      *as_const_int(B->shape[1]),
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      ICHECK(trans_B == false);
@@ -186,18 +199,23 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]),
                                      *as_const_int(A->shape[1]),
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      ICHECK(trans_A == false);
      results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                       A->dtype.bits()));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]),
                                      *as_const_int(B->shape[1]),
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only supports B in shared memory.";
@@ -206,35 +224,42 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
    ICHECK(trans_B == true) << "Currently only transposed B is supported for CDNA.";
    const int warp_size = 64;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(T.block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment);

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      // Make a linear memory access layout:
      // auto shared_layout = makeGemmLayoutLinear(
      //     *as_const_int(A->shape[0]), *as_const_int(A->shape[1]));
      // Make a swizzle or pad layout:
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
                                                *as_const_int(A->shape[1]),
                                                A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      results.Set(
          A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      // Make a linear memory access layout:
      // auto shared_layout = makeGemmLayoutLinear(
      //     *as_const_int(B->shape[0]), *as_const_int(B->shape[1]));
      // Make a swizzle or pad layout:
      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
                                                *as_const_int(B->shape[1]),
                                                B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
@@ -251,7 +276,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
\ No newline at end of file
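To make the partition logic above concrete, here is a small standalone sketch mirroring the kFullRow/kFullCol/kSquare choices for the simple divisible cases visible in the hunks. The kSquare heuristic here is an invented stand-in, and the real ComputeWarpPartition also enforces Hopper's warp-group constraint and handles non-divisible shapes, so treat this purely as an illustration.

#include <cstdio>
#include <utility>

// Illustrative re-statement of the warp-partition policies seen above,
// covering only the easy divisible cases. The kSquare balancing loop is a
// made-up approximation, not the code from this commit.
enum class GemmWarpPolicy { kSquare = 0, kFullRow = 1, kFullCol = 2 };

std::pair<int, int> PartitionWarps(int num_warps, int M, int N,
                                   GemmWarpPolicy policy) {
  int m_warp = 1, n_warp = 1;
  switch (policy) {
  case GemmWarpPolicy::kFullRow: // all warps tile the M dimension
    m_warp = num_warps;
    break;
  case GemmWarpPolicy::kFullCol: // all warps tile the N dimension
    n_warp = num_warps;
    break;
  case GemmWarpPolicy::kSquare: { // keep per-warp M and N tiles balanced
    int m = 1;
    while (m * 2 <= num_warps && num_warps % (m * 2) == 0 &&
           M / (m * 2) >= N / (num_warps / (m * 2)))
      m *= 2;
    m_warp = m;
    n_warp = num_warps / m;
    break;
  }
  }
  return {m_warp, n_warp};
}

int main() {
  // A 256-thread block with 32-thread warps gives 8 warps; kFullRow then
  // assigns all 8 warps along M, matching the m_warp = num_warps branch.
  auto [m_warp, n_warp] = PartitionWarps(/*num_warps=*/256 / 32, 128, 128,
                                         GemmWarpPolicy::kFullRow);
  std::printf("m_warp=%d n_warp=%d\n", m_warp, n_warp); // m_warp=8 n_warp=1
}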
@@ -18,18 +18,18 @@ namespace tl {

using namespace tir;

class Gemm : public Operator {
public:
  Gemm(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  static const Op &Get();

  enum class GemmWarpPolicy {
    kSquare = 0,
    kFullRow = 1,
    kFullCol = 2,
  } policy;

private:
  std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const;

  Array<PrimExpr> call_args;

@@ -38,11 +38,11 @@ class Gemm : public Operator {
  int M, N, K;
  // For k_pack, see bitblas/tl/mfma_macro_generator.py::k_pack.
  // It is only enabled under CDNA MFMA instructions.
  int kPack = 1;
  bool completed_ = false;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_GEMM_H_
\ No newline at end of file
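Reading the constructor hunk in gemm.cc together with this header, the gemm call arguments line up as sketched below. This mapping is inferred from the visible code rather than documented anywhere in the diff (args[0..2] do not appear in these hunks and are assumed to be the buffer handles), so verify against the source before relying on it.

// Inferred argument layout for the tl.gemm call, based on Gemm::Gemm above:
//   args[0..2] -> A, B, C buffers            (assumed, not shown in the diff)
//   args[3]    -> trans_A : Bool
//   args[4]    -> trans_B : Bool
//   args[5..7] -> M, N, K : IntImm
//   args[8]    -> policy  : GemmWarpPolicy (0 square, 1 full-row, 2 full-col)
//   args[9]    -> kPack   : optional, CDNA MFMA only (defaults to 1)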