Commit 9a640856 authored by qisan

[Feature] Add some testing files for Hygon DCU

parent 2c490782
......@@ -1383,7 +1383,6 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
  CodeGenC::PrintType(f->ret_type, stream);
  this->PrintExtraAttrs(f, stream);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
......
......@@ -165,7 +165,7 @@ public:
     auto tx = lane_id;
     auto alane_id = lane_id;
-    auto blane_id = (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;
+    auto blane_id = ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
......@@ -246,7 +246,7 @@ public:
     auto tx = lane_id;
     auto alane_id = lane_id;
-    auto blane_id = (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;
+    auto blane_id = ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
     constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
     constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
......
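The replacement expression is intended as a drop-in for the original: for non-negative lane ids in a 64-lane wavefront, `x / 4` (integer division), `x * 4`, and `(x / 16) * 16` are exactly `x >> 2`, `x << 2`, and `(x >> 4) << 4`. A minimal sketch checking the equivalence over all 64 lanes (plain Python, mirroring the C++ integer division with `//`; the helper names are illustrative, not part of the commit):

# Check that the shift-based blane_id matches the divide/multiply form for lane_id in [0, 64).
def blane_id_div(lane_id):
    # original form; C++ integer division is mirrored with // here
    return (lane_id & 15) // 4 + (lane_id & 3) * 4 + (lane_id // 16) * 16

def blane_id_shift(lane_id):
    # new shift-based form from the diff
    return ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4)

assert all(blane_id_div(t) == blane_id_shift(t) for t in range(64))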
import torch
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
# from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)
tilelang.disable_cache()


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
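make_swizzle_layout falls back to an identity layout unless one row of the shared buffer spans exactly 512 bits (64 bytes). A short illustration of that condition (the shapes below are illustrative, not taken from the tests):

# Illustration of the can_swizzle condition: 512 bits == 64 bytes per row.
from tvm import DataType

print(32 * DataType("float16").bits)  # 512  -> a (*, 32) float16 tile gets the swizzled layout
print(64 * DataType("float16").bits)  # 1024 -> a (*, 64) float16 tile keeps the identity layout
print(32 * DataType("int8").bits)     # 256  -> a (*, 32) int8 tile also keeps the identity layout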


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
):
    assert in_dtype in [
        "float16",
        "bfloat16",
        "int8",
    ], "Currently only float16, bfloat16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32
    chunk = 32 * k_pack
    shared_scope = "shared"
    cache_write_shared = False

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMAC wrapper to auto-generate code for the Matrix Core intrinsics
    mmac_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
    )

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):

                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
                if b_transposed:
                    T.copy(B[bx * block_N, ko * block_K], B_shared)
                else:
                    T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                    # Load A into fragment
                    mmac_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mmac_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mmac_emitter.mmac(A_local, B_local, C_local)

            # Perform STMatrix
            mmac_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    j // micro_size_y,
                    i // micro_size_x,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main
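For reference, with the defaults above (k_pack=1 and a 16-bit input type) the configuration works out to a 64x64 output tile per block, 256 threads, and 4-element per-thread fragments. A small sketch of that arithmetic (values recomputed here for illustration, not additional test code):

# Derived sizes for block_row_warps = block_col_warps = 2, warp tiles = 32,
# 16x16x16 micro tiles, k_pack = 1, warp_size = 64.
warp_size = 64
block_M = block_N = 2 * 32                   # 64
threads = warp_size * (2 * 2)                # 256
local_size_a = (1 * 16 * 16) // warp_size    # 4 A elements per thread per micro tile
local_size_c = (16 * 16) // warp_size        # 4 accumulator elements per thread
assert (block_M, threads, local_size_a, local_size_c) == (64, 256, 4, 4)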


def assert_tl_matmul_correctness(M,
                                 N,
                                 K,
                                 in_dtype,
                                 out_dtype,
                                 accum_dtype="float32",
                                 a_transposed=False,
                                 b_transposed=True,
                                 k_pack=1):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                       k_pack)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated HIP kernel source
    assert src_code is not None

    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)
    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()
    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
    # assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
    # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
    # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
    # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
    # assert_tl_matmul_correctness(
    #     128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)


if __name__ == "__main__":
    tilelang.testing.main()
......@@ -118,12 +118,10 @@ class MatrixCoreIntrinEmitter(object):
         in_dtype_abbrv = {
             "float16": "f16",
             "float32": "f32",
-            "int8": "i8"
+            "int8": "i8",
+            "bfloat16": "bf16"
         }[in_dtype]
 
-        if in_dtype_abbrv == "i8":
-            self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
-        else:
-            self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
+        self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
 
     def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
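The abbreviation table now also covers bfloat16. A hedged sketch of the suffix strings the generic f-string in this hunk would produce; the out_dtype abbreviations ("f32", "i32") and the k_dim values are assumptions, since they are defined outside this hunk:

# Sketch only: reproduce the suffix construction with assumed out_dtype abbreviations.
M_DIM = N_DIM = 16
in_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "bfloat16": "bf16"}

def mmac_suffix(in_dtype, out_abbrv, k_dim):
    return f"{out_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_abbrv[in_dtype]}"

print(mmac_suffix("float16", "f32", 16))   # f32_16x16x16f16
print(mmac_suffix("bfloat16", "f32", 16))  # f32_16x16x16bf16
print(mmac_suffix("int8", "i32", 32))      # i32_16x16x32i8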
......@@ -581,7 +579,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         if is_transposed:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
-                    row, col = T.meta_var(reverse_index_map(((tx & 15) / 4 + (tx & 3) * 4 + (tx / 16) * 16), local_id))
+                    row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                     l, r = (
                         warp_n * warp_cols + j,
                         rk * (chunk // micro_size_k) + ki,
......@@ -591,7 +589,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         else:
             for j in T.serial(warp_cols):
                 for local_id in T.vectorized(k_pack * local_size_b):
-                    row, col = T.meta_var(reverse_index_map(((tx & 15) / 4 + (tx & 3) * 4 + (tx / 16) * 16), local_id))
+                    row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                     l, r = (
                         rk * (chunk // micro_size_k) + ki,
                         warp_n * warp_cols + j,
......
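The same lane remapping change appears here in the Python emitter. One practical difference between the two forms: with ordinary Python integers, `/` is true division and yields a float, while the shift form stays integral; how TIR lowers `/` on integer expressions is a separate question, so the shift form is the unambiguous spelling for an index. A small plain-Python illustration (not TIR code):

# Plain-Python illustration of why the shift form is preferable for an index expression.
tx = 37
old = (tx & 15) / 4 + (tx & 3) * 4 + (tx / 16) * 16          # 42.25 with true division (a float)
new = ((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4)  # 37 (an int)
print(old, new)  # 42.25 37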