Commit 3852d58b authored by wangziyang

update cp_async & init inject_ds_read

parent 19cdf0ca
import tilelang
import tilelang.language as T
from tilelang import disable_cache

disable_cache()


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def main():
    kernel = matmul(32, 32, 32, 32, 32, 32)
    import torch

    a = torch.randn(32, 32).cuda().half()
    b = torch.randn(32, 32).cuda().half()
    c = kernel(a, b)
    ref_c = a @ b
    print("c:")
    print(c)
    print("ref_c:")
    print(ref_c)
    # torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All check passed.")
    # Get CUDA Source
    print("CUDA Source:")
    print(kernel.get_kernel_source())
    # benchmark
    profiler = kernel.get_profiler()
    latency = profiler.do_bench(backend="cupti")
    # latency = profiler.do_bench()
    print(f"tilelang Latency: {latency}ms")


if __name__ == "__main__":
    main()
@@ -104,6 +104,20 @@ Fragment makeGemmFragmentC16x16CDNA() {
  return Fragment({i, j}, {index}, forward_thread, rep);
}
// Tiled layout for DCU: each thread handles consecutive data in shared memory
// This layout is compatible with ds_read_m32x16_b16 which reads continuous memory
Fragment makeGemmFragmentC16x16CDNATiled() {
  IterVar i = make_itervar("i", 16);
  IterVar j = make_itervar("j", 16);
  IterVar rep = make_itervar("rep", 1);
  // Tiled layout: thread ID = i * 4 + (j / 4); each thread handles 4
  // consecutive columns.
  // forward_thread: threads are assigned in groups of 4 columns.
  // index: each thread holds 4 elements, at columns 0-3, 4-7, 8-11, or 12-15.
  PrimExpr forward_thread = i * 4 + FloorDiv(j->var, 4);
  PrimExpr index = FloorMod(j->var, 4);
  return Fragment({i, j}, {index}, forward_thread, rep);
}
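A quick way to sanity-check the mapping documented above is to enumerate it. A minimal plain-Python sketch (not part of the commit), assuming a 64-thread wavefront owning the 16x16 tile:

# thread = i*4 + j//4, register slot = j%4: each thread owns 4 consecutive columns.
def tiled_c16x16(i, j):
    thread = i * 4 + j // 4
    index = j % 4
    return thread, index

# Thread 0 holds row 0, columns 0..3 in consecutive register slots.
assert [tiled_c16x16(0, j) for j in range(4)] == [(0, 0), (0, 1), (0, 2), (0, 3)]
# All 16*16 elements map onto 64 threads x 4 slots with no collisions.
assert len({tiled_c16x16(i, j) for i in range(16) for j in range(16)}) == 256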
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
                               const int warp_m, const int warp_n) {
  ICHECK(block_m % warp_m == 0);
@@ -165,11 +179,20 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  // Use tiled layout for DCU: compatible with ds_read_m32x16_b16
  // auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
  auto base_layout = makeGemmFragmentC16x16CDNATiled()->Repeat({1, 1}, false);
  auto warp_layout =
      base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
  auto block_layout =
      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
  LOG(INFO) << "FragmentC warp_m: " << warp_m;
  LOG(INFO) << "FragmentC warp_n: " << warp_n;
  LOG(INFO) << "FragmentC block_m: " << block_m;
  LOG(INFO) << "FragmentC block_n: " << block_n;
  LOG(INFO) << "FragmentC base_layout: " << base_layout->DebugOutput();
  LOG(INFO) << "FragmentC warp_layout: " << warp_layout->DebugOutput();
  LOG(INFO) << "FragmentC block_layout: " << block_layout->DebugOutput();
  return block_layout;
}
@@ -265,6 +288,13 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
            ->Repeat({block_n / warp_n, 1}, true, false);
    auto block_layout =
        warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false);
    LOG(INFO) << "FragmentB warp_m: " << warp_m;
    LOG(INFO) << "FragmentB warp_n: " << warp_n;
    LOG(INFO) << "FragmentB block_m: " << block_m;
    LOG(INFO) << "FragmentB block_n: " << block_n;
    LOG(INFO) << "FragmentB base_layout: " << base_layout->DebugOutput();
    LOG(INFO) << "FragmentB warp_layout: " << warp_layout->DebugOutput();
    LOG(INFO) << "FragmentB block_layout: " << block_layout->DebugOutput();
    return block_layout;
  } else {
    auto base_layout =
@@ -273,8 +303,16 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
            ->Repeat({1, block_n / warp_n}, true);
    auto block_layout =
        warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
    LOG(INFO) << "FragmentB warp_m: " << warp_m;
    LOG(INFO) << "FragmentB warp_n: " << warp_n;
    LOG(INFO) << "FragmentB block_m: " << block_m;
    LOG(INFO) << "FragmentB block_n: " << block_n;
    LOG(INFO) << "FragmentB base_layout: " << base_layout->DebugOutput();
    LOG(INFO) << "FragmentB warp_layout: " << warp_layout->DebugOutput();
    LOG(INFO) << "FragmentB block_layout: " << block_layout->DebugOutput();
    return block_layout;
  }
}
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
@@ -314,6 +352,58 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
  }
}
Fragment makeGemmFragmentADCU(const int block_m, const int block_n,
                              const int block_k, const int warp_m,
                              const int warp_n, const int element_size,
                              const int k_pack, bool transposed) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  const int mfma_k = k_pack * (element_size == 16 ? 16 : 32);
  ICHECK(block_k % mfma_k == 0);
  ICHECK(element_size == 8 || element_size == 16)
      << "element bitwidth=" << element_size;
  if (transposed) {
    auto base_layout =
        element_size == 16
            ? makeGemmFragmentAB16x16CDNATransposed(k_pack)->Repeat(
                  {1, 1}, false, false)
            : makeGemmFragmentAB16x32CDNATransposed(k_pack)->Repeat(
                  {1, 1}, false, false);
    auto warp_layout =
        base_layout->Repeat({block_k / mfma_k, warp_m / 16}, false, true);
    auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
                            ->Replicate(block_n / warp_n);
    LOG(INFO) << "FragmentA warp_m: " << warp_m;
    LOG(INFO) << "FragmentA warp_n: " << warp_n;
    LOG(INFO) << "FragmentA block_m: " << block_m;
    LOG(INFO) << "FragmentA block_n: " << block_n;
    LOG(INFO) << "FragmentA base_layout: " << base_layout->DebugOutput();
    LOG(INFO) << "FragmentA warp_layout: " << warp_layout->DebugOutput();
    LOG(INFO) << "FragmentA block_layout: " << block_layout->DebugOutput();
    return block_layout;
  } else {
    auto base_layout =
        element_size == 16
            ? makeGemmFragmentAB16x16CDNA(k_pack)->Repeat({1, 1}, false, false)
            : makeGemmFragmentAB16x32CDNA(k_pack)->Repeat({1, 1}, false, false);
    auto warp_layout =
        base_layout->Repeat({warp_m / 16, block_k / mfma_k}, false, false);
    auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
                            ->Replicate(block_n / warp_n);
    LOG(INFO) << "FragmentA warp_m: " << warp_m;
    LOG(INFO) << "FragmentA warp_n: " << warp_n;
    LOG(INFO) << "FragmentA block_m: " << block_m;
    LOG(INFO) << "FragmentA block_n: " << block_n;
    LOG(INFO) << "FragmentA base_layout: " << base_layout->DebugOutput();
    LOG(INFO) << "FragmentA warp_layout: " << warp_layout->DebugOutput();
    LOG(INFO) << "FragmentA block_layout: " << block_layout->DebugOutput();
    return block_layout;
  }
}
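The mfma_k arithmetic above folds together two facts: 16-bit inputs use a K=16 MFMA tile, 8-bit inputs a K=32 tile, and k_pack stacks tiles along K. A plain-Python restatement (outside the commit), with an assumed fp16 example:

def mfma_k(k_pack, element_size_bits):
    # fp16 (16-bit) inputs use a K=16 mfma tile, 8-bit inputs a K=32 tile;
    # k_pack packs several tiles along K.
    return k_pack * (16 if element_size_bits == 16 else 32)

# Assumed example: fp16 with k_pack=2 means block_k must be a multiple of 32.
assert mfma_k(2, 16) == 32
assert mfma_k(1, 8) == 32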
Fragment makeGemmFragment32x32(int element_size) {
  IterVar i = make_itervar("i", 32);
  IterVar j = make_itervar("j", 32);
......
@@ -226,6 +226,11 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                               const int warp_n, const int element_size,
                               const int k_pack, bool transposed = false);

Fragment makeGemmFragmentADCU(const int block_m, const int block_n,
                              const int block_k, const int warp_m,
                              const int warp_n, const int element_size,
                              const int k_pack, bool transposed = false);

// Default Memory Layout
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
......
@@ -216,6 +216,12 @@ TIR_DEFINE_TL_BUILTIN(tma_store_wait)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ds_read_vector)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(set_max_nreg)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
......
@@ -335,6 +335,16 @@ TVM_DLL const Op &tma_store_arrive();
 */
TVM_DLL const Op &tma_store_wait();

/*!
 * \brief DS read from shared memory to registers
 *
 * ds_read_vector(dst, lds_base_ptr, m, n, offset)
 *
 * This is a tilelang intrinsic for the DCU ds_read hardware instructions.
 * Generated code calls tl::ds_read_vector.
 */
TVM_DLL const Op &ds_read_vector();

/*!
 * \brief Set reg hint for warp-specialized branched
 *
......
@@ -787,27 +787,51 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
    results.Set(c_, fragment->BindThreadRange(thread_range));
  }
  if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
    LOG(INFO) << "Using CDNA shared memory layout for A";
    int dim_A = a_->shape.size();
    if (TargetIsDCU(T.target)) {
      auto shared_layout = makeGemmLayoutLinear(
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]));
      results.Set(a_, shared_layout);
    } else {
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(a_->shape[dim_A - 2]),
          *as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
      results.Set(a_, shared_layout);
    }
  } else if (a_.scope() == "local.fragment") {
    LOG(INFO) << "Using CDNA local fragment layout for A";
    if (TargetIsDCU(T.target)) {
      auto fragment =
          makeGemmFragmentADCU(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                               a_->dtype.bits(), kPack_, transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    } else {
      auto fragment =
          makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
                                a_->dtype.bits(), kPack_, transA_);
      results.Set(a_, fragment->BindThreadRange(thread_range));
    }
  } else {
    ICHECK(0);
  }
  if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
    LOG(INFO) << "Using CDNA shared memory layout for B";
    int dim_B = b_->shape.size();
    if (TargetIsDCU(T.target)) {
      auto shared_layout = makeGemmLayoutLinear(
          *as_const_int(b_->shape[dim_B - 2]),
          *as_const_int(b_->shape[dim_B - 1]));
      results.Set(b_, shared_layout);
    } else {
      auto shared_layout = makeGemmABLayoutCDNA(
          *as_const_int(b_->shape[dim_B - 2]),
          *as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);
      results.Set(b_, shared_layout);
    }
  } else if (b_.scope() == "local.fragment") {
    LOG(INFO) << "Using CDNA local fragment layout for B";
    auto fragment =
        makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
    results.Set(b_, fragment->BindThreadRange(thread_range));
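As a plain-Python aside (not part of the commit), the reason the linear layout suits ds_read: makeGemmLayoutLinear places element (row, col) at row * continuous + col, so the columns a thread reads are contiguous LDS words, which is the access pattern ds_read_m32x16_b16 expects:

def linear_addr(row, col, continuous=32):
    # Linear (unswizzled) address: consecutive columns are consecutive elements.
    return row * continuous + col

assert [linear_addr(1, c) for c in range(4)] == [32, 33, 34, 35]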
......
@@ -759,6 +759,7 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
  auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
    printf("[DEBUG VisitExpr_] Branch: print_extern_call_stmt -> %s\n", name.c_str());
    this->PrintIndent();
    this->stream << name << "(";
    for (size_t i = offset; i < op->args.size(); i++) {
@@ -768,7 +769,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    }
    this->stream << ");\n";
  };
  if (op->op.same_as(builtin::ptx_cp_async())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_cp_async\n");
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
@@ -788,48 +791,75 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
                   << ", " << condition << ");\n";
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_commit_group\n");
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_wait_group\n");
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
  } else if (op->op.same_as(builtin::create_barriers())) {
    printf("[DEBUG VisitExpr_] Branch: create_barriers\n");
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
  } else if (op->op.same_as(tl::get_mbarrier())) {
    printf("[DEBUG VisitExpr_] Branch: get_mbarrier\n");
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_arrive_barrier\n");
    print_extern_call_stmt("tl::mbarrier_arrive");
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_init_barrier_thread_count\n");
    print_extern_call_stmt("tl::mbarrier_init");
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_arrive_barrier_expect_tx\n");
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_cp_async_barrier\n");
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
    printf("[DEBUG VisitExpr_] Branch: mbarrier_expect_tx\n");
    print_extern_call_stmt("tl::mbarrier_expect_tx");
  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
    printf("[DEBUG VisitExpr_] Branch: mbarrier_wait_parity\n");
    print_extern_call_stmt("tl::mbarrier_wait");
  } else if (op->op.same_as(tl::ptx_stmatrix())) {
    printf("[DEBUG VisitExpr_] Branch: ptx_stmatrix\n");
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
    if (trans == 1)
      func_name += "_trans";
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::ds_read_vector())) {
    // ds_read_b64 %1, %2 offset:%3
    // ds_read_m32x16_b16 %0, %1 offset:%2
    printf("[DEBUG VisitExpr_] Branch: ds_read_vector\n");
    std::string dst = this->PrintExpr(op->args[0]);
    std::string lds_base_ptr = this->PrintExpr(op->args[1]);
    std::string m = this->PrintExpr(op->args[2]);
    std::string n = this->PrintExpr(op->args[3]);
    std::string offset = this->PrintExpr(op->args[4]);
    this->PrintIndent();
    this->stream << "tl::ds_read_vector<" << m << ", " << n << ", " << offset
                 << ">(*reinterpret_cast<float4_*>(" << dst << "), "
                 << "reinterpret_cast<uintptr_t>(" << lds_base_ptr << "));\n";
  } else if (op->op.same_as(tl::wait_wgmma())) {
printf("[DEBUG VisitExpr_] Branch: wait_wgmma\n");
this->PrintIndent(); this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value; int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::pack_b16())) { } else if (op->op.same_as(tl::pack_b16())) {
printf("[DEBUG VisitExpr_] Branch: pack_b16\n");
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
<< this->PrintExpr(op->args[1]) << ")"; << this->PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(tl::__ldg())) { } else if (op->op.same_as(tl::__ldg())) {
printf("[DEBUG VisitExpr_] Branch: __ldg\n");
// HIP fallback: regular load // HIP fallback: regular load
const BufferLoadNode *bl = op->args[0].as<BufferLoadNode>(); const BufferLoadNode *bl = op->args[0].as<BufferLoadNode>();
ICHECK(bl) << "T.__ldg expects a BufferLoad as the first argument."; ICHECK(bl) << "T.__ldg expects a BufferLoad as the first argument.";
@@ -840,6 +870,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base);
    os << buffer_ref;
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_fill_fragment\n");
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
@@ -850,6 +881,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_load_matrix_sync\n");
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
@@ -862,6 +894,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_store_matrix_sync\n");
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
@@ -879,6 +912,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_mma_sync\n");
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
@@ -889,6 +923,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_bmma_sync\n");
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
@@ -899,6 +934,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(tl::tvm_mfma())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_mfma\n");
    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
@@ -964,6 +1000,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mfma_code);
  } else if (op->op.same_as(tl::tvm_mmac())) {
    printf("[DEBUG VisitExpr_] Branch: tvm_mmac\n");
    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
@@ -1029,8 +1066,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mmac_code);
  } else if (op->op.same_as(builtin::thread_return())) {
    printf("[DEBUG VisitExpr_] Branch: thread_return\n");
    os << "return";
  } else if (op->op.same_as(tl::tl_gemm())) {
    printf("[DEBUG VisitExpr_] Branch: tl_gemm\n");
    ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
                                    "A_ptr, B_ptr, C_ptr>, but got "
                                 << op->args.size();
@@ -1038,15 +1077,19 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
                          op_instance->value, op->args, true, os);
  } else if (op->op.same_as(tl::tl_gemm_sp())) {
    printf("[DEBUG VisitExpr_] Branch: tl_gemm_sp\n");
    LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
  } else if (op->op.same_as(tl::loop_break())) {
    printf("[DEBUG VisitExpr_] Branch: loop_break\n");
    this->PrintIndent();
    this->stream << "break;\n";
  } else if (op->op.same_as(tl::no_set_max_nreg())) {
    printf("[DEBUG VisitExpr_] Branch: no_set_max_nreg\n");
    // HIP doesn't need explicit register management like CUDA
    // This is a no-op for HIP
    return;
  } else {
    printf("[DEBUG VisitExpr_] Branch: CodeGenC::VisitExpr_ (fallback)\n");
    CodeGenC::VisitExpr_(op, os);
  }
}
......
@@ -96,6 +96,13 @@ bool TargetHasAsyncCopy(Target target) {
  if (TargetIsCuda(target)) {
    int arch = GetArchInt(target);
    return arch >= 80;
  } else if (TargetIsDCU(target)) {
    if (target->attrs.count("mcpu")) {
      std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
      return mcpu.find("gfx936") == 0;
    } else {
      return false;
    }
  } else if (TargetIsCDNA(target)) {
    if (target->attrs.count("mcpu")) {
      std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
......
@@ -137,4 +137,25 @@ template <typename T> TL_DEVICE void AtomicAddx4(T *ref, const T val[4]) {
  atomicAdd(&ref[1], val[1]);
  atomicAdd(&ref[2], val[2]);
  atomicAdd(&ref[3], val[3]);
}
\ No newline at end of file
typedef float float4_ __attribute__((ext_vector_type(4)));
typedef float float2_ __attribute__((ext_vector_type(2)));

struct half4 {
  __half x;
  __half y;
  __half z;
  __half w;
};

// Lets a 16-byte register block be viewed either as one float4_ or as the two
// float2_ halves written by the paired ds_read_b64 path.
union RegisterUnion {
  float4_ vector4;
  struct {
    float2_ vector_front;
    float2_ vector_rear;
  };
};
\ No newline at end of file
@@ -86,6 +86,36 @@ TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
  }
}
template <int M, int N, int offset>
TL_DEVICE void ds_read_vector(float4_ &dst, uint32_t lds_base_ptr) {
  if constexpr (M == 16 && N == 32) {
    // One ds_read_m32x16_b16 fills the whole 16-byte dst in a single LDS read.
    const int offset_in_bytes = offset * sizeof(half_t);
    asm volatile("ds_read_m32x16_b16 %0, %1 offset:%2\n\t"
                 : "+v"(dst)
                 : "v"(lds_base_ptr), "n"(offset_in_bytes)
                 : "memory");
  } else if constexpr (M == 32 && N == 16) {
    // Two ds_read_b64 loads of 8 bytes each; the second half sits 4096 bytes
    // further into LDS (layout-specific stride).
    const int offset_in_bytes0 = offset * sizeof(half_t);
    const int offset_in_bytes1 = offset_in_bytes0 + 4096;
    float2_ &front = *reinterpret_cast<float2_ *>(&dst);
    float2_ &rear = *(reinterpret_cast<float2_ *>(&dst) + 1);
    asm volatile("ds_read_b64 %1, %2 offset:%3\n\t"
                 "ds_read_b64 %0, %2 offset:%4\n\t"
                 : "+v"(rear), "+v"(front)
                 : "v"(lds_base_ptr), "n"(offset_in_bytes0),
                   "n"(offset_in_bytes1)
                 : "memory");
  }
}
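For orientation, a small plain-Python sketch (not part of the commit) of which instruction each supported (M, N) pair dispatches to and how many bytes per lane it moves; any other shape falls through the if-constexpr chain above and compiles to a no-op:

def ds_read_variant(m, n):
    # Mirrors tl::ds_read_vector's dispatch; 16 bytes fill one float4_ per lane.
    if (m, n) == (16, 32):
        return ("ds_read_m32x16_b16", 16)   # one 16-byte LDS read
    if (m, n) == (32, 16):
        return ("2 x ds_read_b64", 2 * 8)   # two 8-byte LDS reads
    return None                             # unsupported shape: silent no-op

assert ds_read_variant(16, 32)[1] == ds_read_variant(32, 16)[1] == 16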
template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
                                       void *global_base_ptr, bool cond) {
@@ -107,4 +137,6 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
  }
}
} // namespace tl
@@ -66,7 +66,8 @@ template <> struct MfmaTraits<fp8_e4_t> {
// ref to bitblas/tl/mfma_macro_generator.py::kPack
template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
          bool TransposeB, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type, typename AccDataType = float,
          bool use_swizzle = true>
class GemmTensorOp {
public:
  // static_assert(!clear_accum, "clear_accum=true is not supported yet");
@@ -147,13 +148,23 @@ public:
  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                      const int col) {
    // auto [n_row, n_col] =
    //     make_mfma_swizzle_layout<continuous, element_size>(row, col);
    // return n_row * continuous + n_col;
    if constexpr (use_swizzle) {
      auto [n_row, n_col] =
          make_mfma_swizzle_layout<continuous, element_size>(row, col);
      return n_row * continuous + n_col;
    } else {
      // No swizzle: use a plain linear (padded) layout directly.
      return make_layout_padded<continuous, element_size>(row, col).second +
             make_layout_padded<continuous, element_size>(row, col).first *
                 continuous;
    }
  }

  static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
                             C_type *C_local) {
    printf("Executing GemmTensorOp dcu_hip body\n");
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_m = warp_id / block_col_warps;
@@ -178,6 +189,10 @@ public:
    B_type B_local[warp_rows * kPack * local_size_b];
    A_type A_local[warp_cols * kPack * local_size_a];

    // Get base pointers as byte pointers for ds_read
    const char *B_shared_bytes = reinterpret_cast<const char *>(B_shared);
    const char *A_shared_bytes = reinterpret_cast<const char *>(A_shared);

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch B into register
      for (int i = 0; i < warp_rows; i++) {
@@ -257,6 +272,7 @@ public:
    B_type B_local[warp_rows * kPack * local_size_b];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch B into register
      for (int i = 0; i < warp_rows; i++) {
@@ -302,21 +318,21 @@ namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type, bool use_swizzle = false>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using Compute =
      GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                   clear_accum, kPack, A_type, B_type, C_type, float,
                   use_swizzle>;
  Compute::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type, bool use_swizzle = false>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using Compute =
      GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                   clear_accum, kPack, A_type, B_type, C_type, float,
                   use_swizzle>;
  Compute::body_rs(pA, pB, accum);
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Replace shared memory BufferLoad with ds_read hardware instructions
* \file inject_ds_read.cc
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Check whether the target is an AMD DCU (mcpu beginning with "gfx936")
 */
bool IsDCUTarget(const IRModule &module) {
  for (auto &p : module->functions) {
    if (auto *prim_func = p.second.as<PrimFuncNode>()) {
      if (auto opt_target = prim_func->GetAttr<Target>("target")) {
        Target target = opt_target.value();
        if (target->attrs.count("mcpu")) {
          std::string mcpu =
              Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
          // An mcpu that starts with "gfx936" identifies a DCU.
          return mcpu.find("gfx936") == 0;
        }
      }
    }
  }
  return false;
}
class DSReadInjector : public StmtMutator {
public:
  Stmt VisitStmt_(const BufferStoreNode *store) final {
    // Only rewrite stores into local registers (not shared memory).
    bool is_local = store->buffer.scope() == "local" ||
                    store->buffer.scope() == "local.fragment";
    if (!is_local) {
      return StmtMutator::VisitStmt_(store);
    }
    // The stored value must be a BufferLoad from shared memory.
    if (auto *load = store->value.as<BufferLoadNode>()) {
      bool is_shared_load = load->buffer.scope() == "shared" ||
                            load->buffer.scope() == "shared.dyn";
      if (!is_shared_load) {
        return StmtMutator::VisitStmt_(store);
      }
      // Skip vectorized indices (Ramp expressions): ds_read is emitted as a
      // scalar intrinsic call and cannot handle them.
      if (HasVectorizedIndices(store->indices) ||
          HasVectorizedIndices(load->indices)) {
        return StmtMutator::VisitStmt_(store);
      }
      // The destination must be large enough for ds_read_vector: both variants
      // write 16 bytes per thread, so skip buffers smaller than 16 bytes.
      if (store->buffer.defined()) {
        const auto &buffer_shape = store->buffer->shape;
        if (buffer_shape.size() == 1) {
          if (auto *int_shape = buffer_shape[0].as<IntImmNode>()) {
            int extent = int_shape->value;
            int dtype_bytes = load->dtype.bytes();
            if (extent * dtype_bytes < 16) {
              return StmtMutator::VisitStmt_(store);
            }
          }
        }
      }
      // Analyze the load pattern to decide which ds_read variant to use.
      return InjectDSRead(store, load);
    }
    return StmtMutator::VisitStmt_(store);
  }
private:
  // PrimExpr VisitExpr_(const CallNode *op) {
  //   Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
  //   if (call->op.same_as(builtin::tvm_access_ptr())) {
  //     return RewriteBufferAccess(call, {1});
  //   }
  //   return call;
  // }

  /*!
   * \brief Check whether any index expression is vectorized (contains a Ramp)
   */
  bool HasVectorizedIndices(const Array<PrimExpr> &indices) {
    for (const auto &idx : indices) {
      if (idx.as<RampNode>()) {
        return true;
      }
    }
    return false;
  }
  Stmt InjectDSRead(const BufferStoreNode *store, const BufferLoadNode *load) {
    const Buffer &shared_buf = load->buffer;
    const Buffer &local_buf = store->buffer;
    // PrimExpr offset = load->indices.size() > 0 ? load->indices[0]
    //                                            : make_zero(DataType::UInt(0));
    // Compute the destination buffer size in bytes.
    int buffer_bytes = 0;
    if (local_buf.defined() && local_buf->shape.size() == 1) {
      if (auto *int_shape = local_buf->shape[0].as<IntImmNode>()) {
        int num_elements = int_shape->value;
        int dtype_bytes = local_buf->dtype.bytes();
        buffer_bytes = num_elements * dtype_bytes;
      }
    }
    // Pick the ds_read variant by destination size, matching the shapes the
    // device-side tl::ds_read_vector supports: (16, 32) lowers to
    // ds_read_m32x16_b16, (32, 16) to a pair of ds_read_b64.
    int m, n;
    if (buffer_bytes >= 32) {
      m = 16;
      n = 32;
    } else {
      m = 32;
      n = 16;
    }
    int offset = 0;
    return EmitDSRead(local_buf, shared_buf, m, n, offset);
  }
  Stmt EmitDSRead(const Buffer &local_buf, const Buffer &shared_buf, int m,
                  int n, int offset) {
    // ds_read_vector takes: (dst, shared_ptr, m, n, offset)
    Array<PrimExpr> args = {
        local_buf->data,                                     // dst: local buffer data pointer
        shared_buf.access_ptr(0, DataType::Handle(), 1, 0),  // src: shared buffer pointer
        make_const(DataType::Int(32), m),
        make_const(DataType::Int(32), n),
        make_const(DataType::Int(32), offset)  // element offset into shared memory
    };
    return Evaluate(Call(DataType::Handle(), ds_read_vector(), args));
  }
};
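To make the size heuristic concrete, a hedged plain-Python mirror of the selection in InjectDSRead above (the pass has already filtered out destinations under 16 bytes before this point):

def select_ds_read_shape(buffer_bytes):
    # >= 32 bytes: ds_read_m32x16_b16 shape; otherwise: paired ds_read_b64 shape.
    return (16, 32) if buffer_bytes >= 32 else (32, 16)

# 8 half elements = 16 bytes -> paired ds_read_b64; 16 halves -> m32x16_b16.
assert select_ds_read_shape(8 * 2) == (32, 16)
assert select_ds_read_shape(16 * 2) == (16, 32)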
using namespace tir::transform;

tvm::transform::Pass InjectDSRead() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Only apply to DCU targets
    if (!IsDCUTarget(m)) {
      return f;
    }
    auto *n = f.CopyOnWrite();
    n->body = DSReadInjector()(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectDSRead", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectDSRead", InjectDSRead);
}

} // namespace tl
} // namespace tvm
@@ -1057,6 +1057,9 @@ private:
    int stage = static_cast<int>(pipeline_stages[i]->value);
    bool is_async =
        pipeline_async_stages.find(stage) != pipeline_async_stages.end();
    printf("Block %s assigned to stage %d with order %d%s\n",
           original_order[i]->name_hint.c_str(), stage,
           static_cast<int>(pipeline_orders[i]->value),
           is_async ? " (async)" : " sync");
    PipelineAnnotation stage_order{
        stage,
        /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
......
@@ -262,6 +262,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    # Inject PTX async copy must run behind the thread sync pass,
    # as ptx async copy won't be recognized as a valid buffer load
    mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
    # Inject ds_read for shared-to-register memory copies on DCU
    mod = tilelang.transform.InjectDSRead()(mod)
    if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
        mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
    mod = tilelang.transform.MakePackedAPI()(mod)
......
@@ -237,7 +237,8 @@ class Environment:
    # Kernel selection options
    # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
    # TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "1")
    TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0")

    # Auto-tuning settings
    TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
......
@@ -769,18 +769,14 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
            for i in T.serial(warp_rows):
                for local_id in T.vectorized(k_pack * local_size_a):
                    row, col = T.meta_var(reverse_index_map(tx, local_id))
                    l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
                    A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
        else:
            for i in T.serial(warp_rows):
                for local_id in T.vectorized(k_pack * local_size_a):
                    row, col = T.meta_var(reverse_index_map(tx, local_id))
                    l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
                    A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]

        return (
            _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
@@ -845,19 +841,19 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                for local_id in T.vectorized(k_pack * local_size_b):
                    row, col = T.meta_var(reverse_index_map(tx, local_id))
                    l, r = (
                        warp_n * warp_col_tiles + j * micro_size_y,
                        rk * chunk + ki * (k_pack * micro_size_k),
                    )
                    B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
        else:
            for j in T.serial(warp_cols):
                for local_id in T.vectorized(k_pack * local_size_b):
                    row, col = T.meta_var(reverse_index_map(tx, local_id))
                    l, r = (
                        rk * chunk + ki * (k_pack * micro_size_k),
                        warp_n * warp_col_tiles + j * micro_size_y,
                    )
                    B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]

        return (
            _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
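The index change above moves from a pre-shuffled 4-D shared buffer to flat 2-D global indexing. A small plain-Python check (tile sizes are assumed for illustration, not taken from the commit) that the new (l, r) arithmetic walks K in mfma-sized steps and M in micro-tile steps:

# Assumed example sizes: k_pack=1, micro_size_k=16, chunk=64 (K elements per
# block tile), micro_size_x=16, warp_row_tiles=32.
k_pack, micro_size_k, chunk = 1, 16, 64
micro_size_x, warp_row_tiles = 16, 32

# New-style l index over K: each ki step advances by one mfma K-tile.
l = [rk * chunk + ki * (k_pack * micro_size_k)
     for rk in range(2) for ki in range(chunk // (k_pack * micro_size_k))]
assert l == [0, 16, 32, 48, 64, 80, 96, 112]

# New-style r index over M: each warp tile i advances by one micro tile of rows.
r = [warp_m * warp_row_tiles + i * micro_size_x
     for warp_m in range(2) for i in range(2)]
assert r == [0, 16, 32, 48]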
......
@@ -89,6 +89,44 @@ def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = N
    raise TypeError("T.__ldg expects a BufferLoad or a Buffer.")


def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: int) -> Call:
    """Load from shared memory into registers using DCU ds_read instructions.

    Depending on (m, n), this lowers to ds_read_m32x16_b16 (a vectorized load
    of half values with hardware-managed bank-conflict avoidance) or to a pair
    of ds_read_b64 loads (8 bytes each); either way 16 bytes are written to
    dst. The generated device call has the signature:

        tl::ds_read_vector<M, N, offset>(float4_ &dst, uint32_t lds_base_ptr)

    Args:
        dst: Destination (register/local buffer data pointer).
        shared_ptr: Source pointer (shared-memory buffer data).
        m: First dimension of the tile to load.
        n: Second dimension of the tile to load.
        offset: Element offset into shared memory (converted to bytes in codegen).

    Returns:
        Call: A TIR call to the tl.ds_read_vector intrinsic.
    """
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.ds_read_vector"),
        dst,
        shared_ptr,
        m,
        n,
        offset,
    )
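A hedged usage sketch (the buffer names and access pattern are illustrative, not from the commit), assuming it is called while building a kernel body for a DCU target:

# Copy a 16x32 half tile from LDS into registers at element offset 0;
# A_local / A_shared are hypothetical buffers allocated elsewhere, and the
# destination must provide at least 16 bytes per thread.
T.evaluate(
    ds_read_vector(
        A_local.data,               # dst: local register buffer
        A_shared.access_ptr("r"),   # src: shared-memory base pointer
        16, 32,                     # (m, n): selects ds_read_m32x16_b16
        0,                          # element offset into shared memory
    )
)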
def get_mbarrier(*args):
    """Retrieve a memory barrier operation.
......
@@ -139,6 +139,7 @@ def gemm_v1(
    mbar: tir.Buffer | None = None,
):
    """GEMM v1: use op tl.gemm."""
    # print("Using GEMM v1")
    return _gemm_impl(
        "tl.tileop.gemm",
        A,
@@ -168,6 +169,7 @@ def gemm_v2(
    mbar: tir.Buffer | None = None,
):
    """GEMM v2: use op tl.gemm_py."""
    print("Using GEMM v2")
    return _gemm_impl(
        "tl.tileop.gemm_py",
        A,
......
@@ -15,15 +15,17 @@ from .gemm_mmac import GemmMMAC
from tilelang import _ffi_api
from tilelang.utils.target import target_is_volta

print("tileop gemm init...")


@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range):
    print("tileop gemm infer_layout")
    thread_nums = thread_bounds.extent
    return gemm_py.infer_layout(target, thread_nums)


@tvm_ffi.register_global_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, thread_var: tir.Var):
    print("tileop gemm lower")
    thread_nums = thread_bounds.extent
    stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
    return stmt
@@ -140,12 +142,14 @@ class GemmPy(Node, Scriptable):
    def infer_layout(self, target: Target, thread_nums: int):
        """Infer the layout for the GEMM operation based on target architecture."""
        print(f"GemmPy infer_layout Target: {target}, thread_nums: {thread_nums}")
        gemm_inst = self._select_gemm_instruction(thread_nums, target)
        impl_class = self._get_implementation_class(gemm_inst, target)
        return impl_class(self).infer_layout(target, thread_nums)

    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        """Lower the GEMM operation to TIR statements based on target architecture."""
        print(f"GemmPy lower Target: {target}, thread_nums: {thread_nums}")
        gemm_inst = self._select_gemm_instruction(thread_nums, target)
        impl_class = self._get_implementation_class(gemm_inst, target)
        return impl_class(self).lower(layout_map, target, thread_nums, thread_var)
@@ -181,6 +185,7 @@ class GemmPy(Node, Scriptable):
            NotImplementedError: If the instruction type is not supported
            ValueError: If the instruction type is unknown
        """
        print(f"_get_implementation_class Target: {target}")
        if gemm_inst.is_mma():
            if target_is_volta(target):
                return GemmMMASm70
......
@@ -31,24 +31,28 @@ class GemmMFMA(GemmBase):
        )

        if self.is_gemm_ss():
            print("gemm_ss")
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: make_swizzled_layout(self.B),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_sr():
            print("gemm_sr")
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_rs():
            print("gemm_rs")
            return {
                self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
                self.B: make_swizzled_layout(self.B),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_rr():
            print("gemm_rr")
            return {
                self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
                self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
@@ -101,6 +105,7 @@ class GemmMFMA(GemmBase):
        assert is_full_region(C_region), "Fragment output C must be a full region"

        if self.is_gemm_ss():
            print("lower is_gemm_ss")

            @T.prim_func
            def _gemm_ssr() -> None:
@@ -136,6 +141,7 @@ class GemmMFMA(GemmBase):
            return _Simplify(_gemm_ssr, inline_let=True)
        elif self.is_gemm_sr():
            assert is_full_region(B_region), "Fragment input B must be a full region"
            print("lower is_gemm_sr")

            @T.prim_func
            def _gemm_srr() -> None:
@@ -167,6 +173,7 @@ class GemmMFMA(GemmBase):
            return _Simplify(_gemm_srr, inline_let=True)
        elif self.is_gemm_rs():
            assert is_full_region(A_region), "Fragment input A must be a full region"
            print("lower is_gemm_rs")

            @T.prim_func
            def _gemm_rsr() -> None:
@@ -195,6 +202,7 @@ class GemmMFMA(GemmBase):
        elif self.is_gemm_rr():
            assert is_full_region(A_region), "Fragment input A must be a full region"
            assert is_full_region(B_region), "Fragment input B must be a full region"
            print("lower is_gemm_rr")

            @T.prim_func
            def _gemm_rsr() -> None:
......