"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "ecf337bab5c23708d80a4c537c6b49dbda6e23b2"
Commit 2c490782 authored by Lukinon, committed by qisan

[Feature] Add support for Hygon DCU

parent 7d389a43
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang import disable_cache
disable_cache()
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
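    # int8 MMAC is 16x16x32 (vs. 16x16x16 for fp16), so the K micro tile doubles for int32 accumulation.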
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
# chunk = 32 if in_dtype == "float16" else 64
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
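    # Matrix-core wavefronts on CDNA/DCU are 64 threads wide.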
warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMAC Wrapper to Auto Generate Code for MMAC
mmac_emitter = MatrixCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def gemm_intrinsics(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mmac_emitter.ldmatrix_a(A_local, A_shared, ki)
# Load B into fragment
mmac_emitter.ldmatrix_b(B_local, B_shared, ki)
# Perform Matrix Multiplication
mmac_emitter.mmac(A_local, B_local, C_local)
# Perform STMatrix
mmac_emitter.stmatrix(C_local, C_shared)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
j // micro_size_y,
i // micro_size_x,
i % micro_size_x,
j % micro_size_y,
]
return gemm_intrinsics
def ref_program(A, B):
return A @ B.T
def main():
M, N, K = 16384, 16384, 16384
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated HIP source
assert src_code is not None
profiler = kernel.get_profiler()
latency = profiler.do_bench(profiler.func, warmup=25)
print(latency)
print(kernel.get_kernel_source())
# Ensure that the latency is not None
assert latency is not None
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
main()
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>
#include <hip/hip_runtime.h>
__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[block_count++] = idx;
}
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int N_HEADS,
int N_ROWS,
int BLOCK_SIZE_M,
int BLOCK_SIZE_N,
int NNZ_V,
int NNZ_S
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int seqlen = seqlens[batch_idx];
int block_idx_m = group_idx * blockDim.x + threadIdx.x;
int start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= seqlen) {
return;
}
int end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
int tmp_col_cnt = 0, tmp_blk_cnt = 0;
int s = 0, v = 0;
int v_idx = vertical_indexes[v++];
int s_idx = slash_indexes[s++];
while (s_idx >= end_m) {
s_idx = slash_indexes[s++];
}
s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
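  // Walk the sorted slash and vertical index lists for this row block:
  // vertical indexes not already covered by the current slash block range are
  // emitted as sparse columns, while overlapping or adjacent slash ranges are
  // merged and flushed to block_offset via save_blocks().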
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
v_idx = end_m + BLOCK_SIZE_M;
}
} else {
if (s < NNZ_S) {
s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
break;
}
if (s_idx > range_end + BLOCK_SIZE_M) {
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int BATCH_SIZE,
int N_HEADS,
int N_ROWS,
int NNZ_V,
int NNZ_S
) {
const int BLOCK_SIZE_M = 64;
const int BLOCK_SIZE_N = 64;
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0,
seqlens, vertical_indexes, slash_indexes,
block_count, block_offset, column_count, column_index,
N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
);
}
std::vector<at::Tensor> convert_vertical_slash_indexes(
torch::Tensor seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int context_size,
int block_size_M,
int block_size_N
) {
assert(block_size_M == 64);
assert(block_size_N == 64);
hipSetDevice(seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());
convert_vertical_slash_indexes_64x64(
seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
nnz_vertical,
nnz_slash
);
return { block_count, block_offset, column_count, column_index };
}
@@ -156,6 +156,23 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
return block_layout;
}
Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64)
LOG(FATAL) << "Not supported";
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
auto warp_layout =
base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true);
return block_layout;
}
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
@@ -730,6 +747,7 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
if (!k_inner && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 8) == 0)
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
...
@@ -150,6 +150,9 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
...
@@ -4,7 +4,7 @@
*/
#include "gemm.h"
#include <fstream>
#include "builtin.h" #include "builtin.h"
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -828,9 +828,16 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, ...@@ -828,9 +828,16 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
ICHECK(C.scope() == "local.fragment") ICHECK(C.scope() == "local.fragment")
<< "CDNA gemm (FMMA) only supports C in local.fragment scope, got " << "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<< C.scope(); << C.scope();
if (TargetIsDCU(T.target))
{
auto fragment =
makeGemmFragmentCDCU(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
} else {
auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
}
if (A.scope() == "shared" || A.scope() == "shared.dyn") { if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size(); int dim_A = A->shape.size();
......
@@ -137,6 +137,7 @@ void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
std::string CodeGenTileLangHIP::Finish() {
// hip must need a header file.
decl_stream << "#define HIP_ENABLE_WARP_SYNC_BUILTINS\n";
decl_stream << "#include <hip/hip_runtime.h>\n"; decl_stream << "#include <hip/hip_runtime.h>\n";
if (need_mma_h_) { if (need_mma_h_) {
decl_stream << "#include <mma.h>\n"; decl_stream << "#include <mma.h>\n";
...@@ -146,12 +147,12 @@ std::string CodeGenTileLangHIP::Finish() { ...@@ -146,12 +147,12 @@ std::string CodeGenTileLangHIP::Finish() {
decl_stream << "#include <tl_templates/hip/hip_fp8.h>\n"; decl_stream << "#include <tl_templates/hip/hip_fp8.h>\n";
} }
decl_stream << "#include <tl_templates/hip/gemm.h>\n"; decl_stream << "#include <tl_templates/dcu_hip/gemm.h>\n";
decl_stream << "#include <tl_templates/hip/copy.h>\n"; decl_stream << "#include <tl_templates/dcu_hip/copy.h>\n";
decl_stream << "#include <tl_templates/hip/reduce.h>\n"; decl_stream << "#include <tl_templates/dcu_hip/reduce.h>\n";
decl_stream << "#include <tl_templates/hip/ldsm.h>\n"; decl_stream << "#include <tl_templates/dcu_hip/ldsm.h>\n";
decl_stream << "#include <tl_templates/hip/threadblock_swizzle.h>\n"; decl_stream << "#include <tl_templates/dcu_hip/threadblock_swizzle.h>\n";
decl_stream << "#include <tl_templates/hip/debug.h>\n"; decl_stream << "#include <tl_templates/dcu_hip/debug.h>\n";
decl_stream << "\n"; decl_stream << "\n";
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
...@@ -952,6 +953,71 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -952,6 +953,71 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer.register_rule("{c_ref}", c_ref); replacer.register_rule("{c_ref}", c_ref);
replacer.register_rule("{c_bias}", c_bias); replacer.register_rule("{c_bias}", c_bias);
os << replacer.rewrite(call_mfma_code); os << replacer.rewrite(call_mfma_code);
} else if (op->op.same_as(tl::tvm_mmac())) {
// arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: float16, float32, ...
// arg 4: B precision: float16, float32, ...
// arg 5: C precision: float32, float64, ...
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
ICHECK(op->args.size() == 12U)
<< "Invalid number of arguments for tvm_mfma";
std::string prefix = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
ICHECK(A_layout == "row" || B_layout == "row")
<< "Matrix core only support row major";
// map for dtype -> float32x4 -> float4
std::unordered_map<std::string, std::string> dtype_map = {
{"int8", "char"},
{"int32", "int"},
{"int8x4", "int32_t"},
{"int8x8", "int64_t"},
{"int32x4", "int32x4"},
{"float16", "half"},
{"float32", "float"},
{"float64", "double"},
{"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"},
{"float32x4", "float32x4"},
{"float8_e4m3fnuzx4", "fp8_e4_4_t"},
{"float8_e4m3fnuzx8", "long"},
{"float32x16", "float32x16"}};
std::string call_mmac_code = R"({
*((({C_dtype}*){c_ref}) + {c_bias}) = {mmac_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
*((({B_dtype}*){b_ref}) + {b_bias}),
*((({C_dtype}*){c_ref}) + {c_bias}));
})";
std::string mmac_buildin = "__builtin_amdgcn_mmac_" + prefix;
Replacer replacer;
replacer.register_rule("{mmac_buildin}", mmac_buildin);
replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
replacer.register_rule("{a_ref}", a_ref);
replacer.register_rule("{a_bias}", a_bias);
replacer.register_rule("{b_ref}", b_ref);
replacer.register_rule("{b_bias}", b_bias);
replacer.register_rule("{c_ref}", c_ref);
replacer.register_rule("{c_bias}", c_bias);
os << replacer.rewrite(call_mmac_code);
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else if (op->op.same_as(tl::tl_gemm())) {
...
@@ -240,6 +240,16 @@ TVM_REGISTER_OP("tir.fmod")
DispatchPureExtern<HIPMath>);
// Register low-level builtin ops.
TVM_REGISTER_OP("tir.hip.__shfl")
.set_num_inputs(3)
.add_argument("var", "Expr", "Value to shuffle")
.add_argument("lane", "Expr", "Source lane")
.add_argument("width", "Expr", "Warp width")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_REGISTER_OP("tir.hip.__shfl_sync") TVM_REGISTER_OP("tir.hip.__shfl_sync")
.set_num_inputs(4) .set_num_inputs(4)
.add_argument("mask", "Expr", "The thread mask.") .add_argument("mask", "Expr", "The thread mask.")
...@@ -286,4 +296,4 @@ TVM_REGISTER_OP("tir.hip.__activemask") ...@@ -286,4 +296,4 @@ TVM_REGISTER_OP("tir.hip.__activemask")
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
\ No newline at end of file
@@ -5,6 +5,7 @@
#include "utils.h"
namespace tvm {
namespace tl {
@@ -78,6 +79,16 @@ bool TargetIsCDNA(Target target) {
return false;
}
bool TargetIsDCU(Target target) {
if (!TargetIsRocm(target))
return false;
if (target->attrs.count("mcpu")) {
// if mcpu starts with "gfx936", the target is a Hygon DCU
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
return mcpu.find("gfx936") == 0;
}
return false;
}
bool TargetHasAsyncCopy(Target target) {
if (TargetIsCuda(target)) {
int arch = GetArchInt(target);
...
@@ -22,6 +22,7 @@ bool TargetIsHopper(Target target);
bool TargetIsSm100(Target target);
bool TargetIsSM120(Target target);
bool TargetIsCDNA(Target target);
bool TargetIsDCU(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
...
#pragma once
#include "core.hpp"
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
// #include <rocwmma/rocwmma.hpp>
#define HIPRT_INF_F __int_as_float(0x7f800000)
#define HIPRT_NEGINF_F __int_as_float(0xff800000)
#define HIPRT_NAN_F __int_as_float(0x7fffffff)
#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001)
#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff)
#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000)
#define HIPRT_ZERO_F 0.0f
#define HIPRT_ONE_F 1.0f
/* double precision constants */
#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000)
#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000)
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__
#define TILELANG_CHECK(stmt) \
do { \
hipError_t __err = (stmt); \
if (__err != hipSuccess) { \
snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, \
__LINE__, hipGetErrorName(__err), hipGetErrorString(__err)); \
return -1; \
} \
} while (0)
#define TILELANG_CHECK_LAST_ERROR(kernel_name) \
do { \
hipError_t __err = hipGetLastError(); \
if (__err != hipSuccess) { \
snprintf(error_buf, ERROR_BUF_SIZE, "kernel_name: %s - %s", \
hipGetErrorName(__err), hipGetErrorString(__err)); \
return -1; \
} \
} while (0)
#define half _Float16
#define __float2half_rn(x) half(x)
#define hpow __ocml_pown_f16
#define hsqrt __ocml_sqrt_f16
using float16_t = _Float16;
using float16x2 =
__attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;
using float16x4 =
__attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;
using float16x8 =
__attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;
using float16x16 =
__attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;
using half_t = float16_t;
using bfloat16_t = __hip_bfloat16;
struct bfloat16x2 {
bfloat16_t x, y;
};
struct bfloat16x4 {
bfloat16_t data[4];
};
struct bfloat16x8 {
bfloat16_t data[8];
};
struct bfloat16x16 {
bfloat16_t data[16];
};
typedef
__attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec;
using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t;
// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// Pack two bfloat16_t values.
TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
template <typename T>
struct is_half_type : std::false_type {};
template <>
struct is_half_type<__half> : std::true_type {};
template <>
struct is_half_type<half_t> : std::true_type {};
template <typename T>
inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1* address, T2 val) {
if constexpr (is_half_v<T1>) {
__half* addr = reinterpret_cast<__half*>(address);
__half hval = __float2half(static_cast<float>(val));
atomicAdd(addr, hval);
} else {
atomicAdd(address, static_cast<T1>(val));
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val) {
AtomicAdd(&ref, val);
}
template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
return atomicAdd(&ref, static_cast<T1>(val));
}
template <typename T>
TL_DEVICE void AtomicAddx4(T* ref, const T val[4]) {
atomicAdd(&ref[0], val[0]);
atomicAdd(&ref[1], val[1]);
atomicAdd(&ref[2], val[2]);
atomicAdd(&ref[3], val[3]);
}
\ No newline at end of file
#pragma once
#include "common.h"
using f32 = float;
// using f16 = _Float16;
using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;
using index_t = u32;
using ck_tile::int32x4_t;
struct __attribute__((packed)) buffer_resource {
const void *ptr;
uint32_t range;
uint32_t config;
};
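// Build a 128-bit buffer resource descriptor (base pointer, range, config dword)
// and broadcast each dword with readfirstlane so it is uniform across the wave.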
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
uint32_t size = 0xffffffff) {
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}
__device__ void init_m0(uint32_t m0_value) {
asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory");
}
__device__ void inc_m0(uint32_t m0_inc) {
asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory");
}
namespace tl {
// AMDGPU commits async copies automatically; no explicit commit instruction is needed
TL_DEVICE void cp_async_commit() {}
// Global Memory only fence
__device__ void async_gld_fence(index_t cnt) {
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// Global Memory and Shared Memory fence
__device__ void async_gld_sld_fence(index_t cnt) {
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }
template <int N = 0> TL_DEVICE void cp_async_wait() {
async_gld_fence(N);
// or
// async_gld_sld_fence(N);
}
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
index_t voffset) {
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(voffset), "s"(rsrc)
: "memory");
}
template <int N>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
if constexpr (N == 16) {
*(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
} else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
} else if constexpr (N == 4) {
async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
}
}
template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
void *global_base_ptr, bool cond) {
if constexpr (N == 16) {
*(uint4 *)lds_base_ptr =
cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
} else if constexpr (N == 8) {
*(uint2 *)lds_base_ptr =
cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
} else {
if (cond) {
async_buffer_load_dword_v(
lds_base_ptr,
make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
threadIdx.x * N /*assume 4 bytes*/);
} else {
*(uint4 *)lds_base_ptr = make_uint4(0, 0, 0, 0);
}
}
}
} // namespace tl
#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
defined(__gfx9__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#else
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
namespace ck_tile{
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
return x;
}
template <typename T>
CK_TILE_HOST constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <>
CK_TILE_DEVICE float max(float x, float y)
{
return __builtin_fmaxf(x, y); // can result in v_max3_f32
}
template <>
CK_TILE_DEVICE double max(double x, double y)
{
return __builtin_fmax(x, y); // maybe still v_max3_f32
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough arguments");
return max(x, max(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
return x;
}
template <typename T>
CK_TILE_HOST constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <>
CK_TILE_DEVICE float min(float x, float y)
{
return __builtin_fminf(x, y);
}
template <>
CK_TILE_DEVICE double min(double x, double y)
{
return __builtin_fmin(x, y);
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough arguments");
return min(x, min(ys...));
}
}
#pragma once
#include <hip/hip_runtime.h>
// Base template declaration
template <typename T> __device__ void debug_print_var(const char *msg, T var);
// Specialization for signed char type
template <>
__device__ void debug_print_var<signed char>(const char *msg, signed char var) {
const char *safe_msg = msg;
int value = static_cast<int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed "
"char value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Specialization for unsigned char type
template <>
__device__ void debug_print_var<unsigned char>(const char *msg,
unsigned char var) {
const char *safe_msg = msg;
unsigned int value = static_cast<unsigned int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned char value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Specialization for int type
template <> __device__ void debug_print_var<int>(const char *msg, int var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
"value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for unsigned int type
template <>
__device__ void debug_print_var<unsigned int>(const char *msg,
unsigned int var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned int value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for float type
template <> __device__ void debug_print_var<float>(const char *msg, float var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
"value=%f\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for double type
template <>
__device__ void debug_print_var<double>(const char *msg, double var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
"value=%lf\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, var);
}
// Specialization for bool type
template <> __device__ void debug_print_var<bool>(const char *msg, bool var) {
const char *safe_msg = msg;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
"value=%s\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z,
var ? "true" : "false");
}
// Specialization for short type
template <> __device__ void debug_print_var<short>(const char *msg, short var) {
const char *safe_msg = msg;
int value = static_cast<int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=short "
"value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Specialization for unsigned short type
template <>
__device__ void debug_print_var<unsigned short>(const char *msg,
unsigned short var) {
const char *safe_msg = msg;
unsigned int value = static_cast<unsigned int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned short value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, value);
}
// Template declaration for device-side debug printing (buffer only)
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
int index, T var);
// Specialization for signed char type
template <>
__device__ void
debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
int index, signed char var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
int value = static_cast<int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=signed char value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, value);
}
// Specialization for unsigned char type
template <>
__device__ void
debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
int index, unsigned char var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
unsigned int value = static_cast<unsigned int>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=unsigned char value=%u\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, value);
}
// Specialization for integer type
template <>
__device__ void debug_print_buffer_value<int>(const char *msg,
const char *buf_name, int index,
int var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int value=%d\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
const char *buf_name, int index,
float var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=float value=%f\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
// Specialization for half_t type
template <>
__device__ void debug_print_buffer_value<half_t>(const char *msg,
const char *buf_name,
int index, half_t var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
float value = static_cast<float>(var);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=half_t value=%f\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, value);
}
// Specialization for double type
template <>
__device__ void debug_print_buffer_value<double>(const char *msg,
const char *buf_name,
int index, double var) {
const char *safe_msg = msg;
const char *safe_buf_name = buf_name;
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=double value=%lf\n",
safe_msg, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
#pragma once
#include "common.h"
#include <type_traits>
namespace tl {
// Trait to determine the MFMA instruction to use based on data type
template <typename T> struct MfmaTraits;
// Specialization for int8
template <> struct MfmaTraits<int8_t> {
template <typename AccType>
static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
*c = __builtin_amdgcn_mmac_i32_16x16x32i8(*b_packed, *a_packed, *c);
}
};
// Specialization for half/float16
template <> struct MfmaTraits<half> {
template <typename AccType>
static TL_DEVICE void mfma_op(const half *b, const half *a, AccType *c) {
*c = __builtin_amdgcn_mmac_f32_16x16x16f16(*((float16x4 *)b),
*((float16x4 *)a), *c);
}
};
// Specialization for bfloat16_t
template <> struct MfmaTraits<bfloat16_t> {
template <typename AccType>
static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a,
AccType *c) {
bfloat16x4_vec b_vec, a_vec;
// Reinterpret the pointers
short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
// Copy the data
for (int i = 0; i < 4; ++i) {
b_vec[i] = b_short[i];
a_vec[i] = a_short[i];
}
// Call the intrinsic and store the result directly to c
*c = __builtin_amdgcn_mmac_f32_16x16x16bf16(b_vec, a_vec, *c);
}
};
#if defined(HIP_FP8_ENABLED)
// Specialization for fp8_e4_t
template <> struct MfmaTraits<fp8_e4_t> {
template <typename AccType>
static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
AccType *c) {
int64_t a_val = *reinterpret_cast<const int64_t *>(a);
int64_t b_val = *reinterpret_cast<const int64_t *>(b);
*c = __builtin_amdgcn_mmac_f32_16x16x32_fp8_fp8(b_val, a_val, *c);
}
};
#endif
// ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
bool TransposeB, bool clear_accum, int kPack, typename A_type,
typename B_type, typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
//static_assert(!clear_accum, "clear_accum=true is not supported yet");
static constexpr int micro_size_x = 16;
static constexpr int micro_size_y = 16;
static constexpr int micro_size_k = 32 / sizeof(A_type);
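  // 32 / sizeof(A_type): the K micro tile is 16 for 2-byte fp16/bf16 and 32 for 1-byte int8/fp8.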
static constexpr int vec_size = 8 / sizeof(A_type);
// This part comes from the Codegen
static constexpr int M_Tile = N;
static constexpr int N_Tile = M;
static constexpr int K_Tile = K;
static constexpr int block_row_warps = num_warp_m;
static constexpr int block_col_warps = num_warp_n;
static constexpr int inner_k = K_Tile / (micro_size_k * kPack);
static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
// part.
static constexpr bool kPadA = true;
static constexpr bool kPadB = true;
static constexpr bool kPadC = true;
static constexpr int BANK_SIZE_BYTES = 128;
static constexpr int warp_size = 64;
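  // Map (lane id, local element id) back to (row, col) inside one 16x16 fragment:
  // lanes 0-15 select the row, and each group of 16 lanes covers
  // vec_size * kPack consecutive K elements.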
TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
int local_id) {
return std::make_pair(thread_id % 16,
(thread_id / 16) * (vec_size * kPack) + local_id);
}
TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
int local_id) {
return std::make_pair((thread_id / 16) * (vec_size * kPack) + local_id,
thread_id % 16);
}
/*
* Detailed Implementation please
* checkout bitblas/tl/utils.py:get_swizzle_layout
*/
template <int continuous = 32, int element_size = 2>
TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) {
const auto dtype_bits = element_size * 8;
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
const int vecSize = vec_size * kPack;
const int innerDimLength = continuous;
const int typeWidthInBit = dtype_bits;
const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
const int maxPhase =
std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
const int phase = (row / perPhase) % maxPhase;
const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
const int colOffOrdered = col % vecSize;
const int colOff = colOffSwizzled + colOffOrdered;
return std::make_pair(row, colOff);
}
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_layout_padded(const int row,
const int col) {
return std::make_pair(row, col);
}
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
const int col) {
auto [n_row, n_col] =
make_mfma_swizzle_layout<continuous, element_size>(row, col);
return n_row * continuous + n_col;
}
static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
C_type *C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
auto warp_m = warp_id % block_row_warps;
auto warp_row_tiles = warp_rows * micro_size_x;
auto warp_col_tiles = warp_cols * micro_size_y;
auto lane_id = tid % warp_size;
auto tx = lane_id;
auto alane_id = lane_id;
auto blane_id = (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;
constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
constexpr auto last_dim_b = TransposeB ? K_Tile : M_Tile;
constexpr auto last_dim_a = TransposeA ? N_Tile : K_Tile;
B_type B_local[warp_rows * kPack * local_size_b];
A_type A_local[warp_cols * kPack * local_size_a];
for (int ki = 0; ki < inner_k; ki++) {
// Fetch B into register
for (int i = 0; i < warp_rows; i++) {
const auto l = warp_m * warp_row_tiles + i * micro_size_x;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
if constexpr (TransposeB) {
auto [row, col] = reverse_index_map(blane_id, local_id);
B_local[i * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
} else {
auto [row, col] = reverse_index_map_transposed(blane_id, local_id);
B_local[i * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
r + row, l + col)];
}
}
}
// Fetch A into register
for (int j = 0; j < warp_cols; j++) {
const auto l = warp_n * warp_col_tiles + j * micro_size_y;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
if constexpr (TransposeA) {
auto [row, col] = reverse_index_map_transposed(alane_id, local_id);
A_local[j * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
r + row, l + col)];
} else {
auto [row, col] = reverse_index_map(alane_id, local_id);
A_local[j * kPack * local_size_a + local_id] =
A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
l + row, r + col)];
}
}
}
// Compute
for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
auto a_ptr = ((A_type *)A_local) + (j * kPack + kp) * vec_size;
auto b_ptr = ((B_type *)B_local) + (i * kPack + kp) * vec_size;
// Use the trait to select the correct MFMA instruction, either fp8,
// fp16 or bf16 currently
MfmaTraits<A_type>::mfma_op(a_ptr, b_ptr, acc_ptr);
}
}
}
}
}
static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
C_type *C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
auto warp_m = warp_id % block_row_warps;
auto warp_row_tiles = warp_rows * micro_size_x;
auto warp_col_tiles = warp_cols * micro_size_y;
auto lane_id = tid % warp_size;
auto tx = lane_id;
auto alane_id = lane_id;
auto blane_id = (lane_id & 15) / 4 + (lane_id & 3) * 4 + (lane_id / 16) * 16;
constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;
constexpr auto last_dim_b = TransposeB ? K_Tile : M_Tile;
constexpr auto last_dim_a = TransposeA ? N_Tile : K_Tile;
B_type B_local[warp_rows * kPack * local_size_b];
for (int ki = 0; ki < inner_k; ki++) {
// Fetch B into register
for (int i = 0; i < warp_rows; i++) {
const auto l = warp_m * warp_row_tiles + i * micro_size_x;
const auto r = ki * (kPack * micro_size_k);
for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
if constexpr (TransposeB) {
auto [row, col] = reverse_index_map(blane_id, local_id);
B_local[i * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
l + row, r + col)];
} else {
auto [row, col] = reverse_index_map_transposed(blane_id, local_id);
B_local[i * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
r + row, l + col)];
}
}
}
// Compute
for (int kp = 0; kp < kPack; kp++) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
auto b_ptr = ((B_type *)B_local) + (i * kPack + kp) * vec_size;
auto a_ptr = ((A_type *)A_local) +
(ki * warp_cols * kPack + j * kPack + kp) * vec_size;
// Use the trait to select the correct MFMA instruction, either fp8,
// fp16 or bf16 currently
MfmaTraits<A_type>::mfma_op(a_ptr, b_ptr, acc_ptr);
}
}
}
}
}
};
} // namespace tl
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int kPack, typename A_type,
typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using Compute =
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
clear_accum, kPack, A_type, B_type, C_type>;
Compute::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int kPack, typename A_type,
typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using Compute =
GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
clear_accum, kPack, A_type, B_type, C_type>;
Compute::body_rs(pA, pB, accum);
}
} // namespace tl
#include <hip/amd_detail/amd_hip_fp8.h>
#define HIP_FP8_ENABLED 1
using fp8_e4_t = __hip_fp8_e4m3_fnuz;
using fp8_e4_2_t = __hip_fp8x2_e4m3_fnuz;
// Simple wrapper that provides member access for generated code
struct fp8_e4_4_t {
union {
__hip_fp8x4_e4m3_fnuz data;
struct {
fp8_e4_t x, y, z, w;
};
};
// Default constructor
__device__ fp8_e4_4_t() = default;
// Constructor from __hip_fp8x4_e4m3_fnuz
__device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {}
// Constructor from float4
__device__ fp8_e4_4_t(const float4 &val) : data(val) {}
// Conversion operator to __hip_fp8x4_e4m3_fnuz
__device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; }
// Assignment operator
__device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) {
data = val;
return *this;
}
};
struct __align__(8) fp8_e4_8_t {
fp8_e4_4_t x;
fp8_e4_4_t y;
};
struct __align__(16) fp8_e4_16_t {
fp8_e4_8_t x;
fp8_e4_8_t y;
};
__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
fp8_e4_t w) {
// reinterpret the 4 fp8_e4_t values to signed char value and shift
signed char x_char = *reinterpret_cast<signed char *>(&x);
signed char y_char = *reinterpret_cast<signed char *>(&y);
signed char z_char = *reinterpret_cast<signed char *>(&z);
signed char w_char = *reinterpret_cast<signed char *>(&w);
int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
return *reinterpret_cast<fp8_e4_4_t *>(&res);
}
__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
fp8_e4_t w, fp8_e4_t v, fp8_e4_t u,
fp8_e4_t t, fp8_e4_t s) {
signed char x_char = *reinterpret_cast<signed char *>(&x);
signed char y_char = *reinterpret_cast<signed char *>(&y);
signed char z_char = *reinterpret_cast<signed char *>(&z);
signed char w_char = *reinterpret_cast<signed char *>(&w);
signed char v_char = *reinterpret_cast<signed char *>(&v);
signed char u_char = *reinterpret_cast<signed char *>(&u);
signed char t_char = *reinterpret_cast<signed char *>(&t);
signed char s_char = *reinterpret_cast<signed char *>(&s);
int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char;
fp8_e4_8_t res;
res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
return res;
}
#pragma once
#include "common.h"
#pragma once
#include "common.h"
namespace tl {
struct SumOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x + y;
}
};
struct MaxOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return ck_tile::max(x, y);
}
};
struct MinOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return ck_tile::min(x, y);
}
};
// Detect half types
template <typename T>
struct is_half_type : std::false_type {};
template <>
struct is_half_type<__half> : std::true_type {};
template <>
struct is_half_type<_Float16> : std::true_type {};
template <typename T>
inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;
template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 ||
threads == 128 || threads == 64 || threads == 32 ||
threads == 16 || threads == 8 || threads == 4 || threads == 2);
static_assert(threads % scale == 0);
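  // Butterfly all-reduce: offsets at or above the 64-wide wavefront go through
  // shared memory (red_buf); smaller offsets use __shfl_xor within the wave;
  // recursion halves the offset until it reaches `scale`.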
template <typename T> static __device__ T run(T x, T *red_buf = nullptr) {
constexpr int offset = threads / 2;
constexpr int warpSize = 64;
if constexpr (offset >= warpSize) {
__syncthreads();
red_buf[threadIdx.x] = x;
__syncthreads();
x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
} else {
if constexpr (is_half_v<T>) {
unsigned short x_raw;
if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
x_raw = __half_as_ushort(x);
} else { // _Float16
union { _Float16 f; unsigned short s; } u;
u.f = x;
x_raw = u.s;
}
unsigned short shuffled_raw = __shfl_xor(x_raw, offset);
T shuffled_x;
if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
shuffled_x = __ushort_as_half(shuffled_raw);
} else { // _Float16
union { unsigned short s; _Float16 f; } u;
u.s = shuffled_raw;
shuffled_x = u.f;
}
x = Reducer()(x, shuffled_x);
} else {
x = Reducer()(x, __shfl_xor(x, offset));
}
}
if constexpr (offset == scale) {
return x;
} else {
return AllReduce<Reducer, offset, scale, thread_offset>::run(x, red_buf);
}
}
};
template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32);
template <typename T, int SEG = 32>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst, int H,
int W) {
constexpr int TILE_H = threads / SEG;
constexpr uint64_t MASK = 0xffffffffffffffffULL;
const int num_blocks = (H + TILE_H - 1) / TILE_H;
const int tid = threadIdx.x;
const int lane = tid % 64;
const int row = tid / 64;
for (int b = 0; b < num_blocks; ++b) {
const int gRow = b * TILE_H + row;
if (gRow >= H)
return;
T carry = (T)0;
if (reverse) {
// Start from the last segment for reverse mode
for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) {
const int col = seg * SEG + lane;
const int real_row = Axis == 1 ? gRow : col;
const int real_col = Axis == 1 ? col : gRow;
T val = (col < W) ? src[real_row * W + real_col] : (T)0;
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}
val += carry;
if (real_col < W)
dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, (T)0);
if (lane == 0)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0);
}
} else {
for (int seg = 0; seg * SEG < W; ++seg) {
const int col = seg * SEG + lane;
const int real_row = Axis == 1 ? gRow : col;
const int real_col = Axis == 1 ? col : gRow;
T val = (col < W) ? src[real_row * W + real_col] : (T)0;
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
if (lane >= off)
val += n;
}
val += carry;
if (real_col < W)
dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
}
}
}
}
};
} // namespace tl
#pragma once
#include "common.h"
namespace tl {
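// Serpentine block rasterization: remap the linear block id into panels of
// `panel_width` so that consecutively scheduled blocks touch neighbouring
// tiles and reuse L2.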
template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.x;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel
? panel_width
: (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx = (panel_idx & 1)
? gridDim.x - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.y;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel
? panel_width
: (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx = (panel_idx & 1)
? gridDim.y - 1 - panel_offset / stride
: panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
} // namespace tl
@@ -61,7 +61,8 @@ def compile_hip(code,
file_target = path_target if path_target else temp_target
cmd = ["hipcc"]
- cmd += ["-O3", '-c']
+ cmd += ["-O1", '-c']
cmd += ["-Wno-invalid-constexpr"]
if isinstance(arch, str):
cmd += [f"--offload-arch={arch}"]
if target_format == "hsaco":
...
@@ -227,7 +227,7 @@ def have_matrixcore(compute_version=None):
@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
- def get_rocm_arch(rocm_path="/opt/rocm"):
+ def get_rocm_arch(rocm_path="/opt/dtk"):
"""Utility function to get the AMD GPU architecture
Parameters
...