Unverified Commit 5ccac4fa authored by Zhiwen Mo, committed by GitHub

[Bugfix] Fix tensor memory copy layout (#933)

* Implements tcgen05.ld instruction support for copying from shared.tmem
  to local.fragment on SM100/Blackwell architecture. Adds layout inference
  and lowering logic for tensor memory operations with proper physical
  coordinate range analysis and warpgroup alignment checks.

  Changes:
  - Add kTMemLoad and kTMemStore to CopyInst enumeration
  - Implement CheckTMemLoad() and CheckTMemStore() validation functions
  - Add LowerTmemCopy() to generate tcgen05.ld/st/cp PTX intrinsics
  - Add tmem layout inference in InferLayout() using expandTcgen05Layout
  - Support multiple instruction variants (32dp32b/64b/128b/256b)
  - Add physical layout bounds analysis for tmem coordinates
  - Change clear_accum from bool to PrimExpr in GEMM operations
  - Fix std::optional access checks in layout_inference.cc
  - Add tmem_allocate/deallocate PTX intrinsic support
  - Fix cooperative_groups grid.sync() code generation

* fix

* pipeline fix

* bug fix

* bool fix
parent fc4bd452
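
For reference, the example below exercises the new tcgen05.ld path end to end. The following condensed sketch of the kernel body is taken from that example; the surrounding @T.prim_func / kernel-launch scaffolding and the full T.gemm argument list are not shown in the diff, so they are elided here as well (the first three T.gemm arguments are assumed to be A_shared, B_shared, C_tmem).

# Condensed sketch of the SM100 tensor-memory accumulation pattern (kernel
# scaffolding omitted; see the full example in the first diff hunk below).
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)       # tensor memory accumulator
mbar = T.alloc_barrier(1)                                    # 1 = expect-arrive-count
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)  # per-thread registers
C_shared = T.alloc_shared((block_M, block_N), out_dtype)

for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[bx * block_N, k * block_K], B_shared)
    # clear_accum is now a PrimExpr, so "clear only on the first K-step" is
    # written directly as k == 0; remaining arguments elided as in the diff.
    T.gemm(A_shared, B_shared, C_tmem, ..., clear_accum=k == 0)
    T.mbarrier_wait_parity(mbar, k % 2)

# Only the first warpgroup (128 threads) issues the tensor-memory load.
if T.get_thread_binding() < 128:
    T.copy(C_tmem, C_local)    # shared.tmem -> local.fragment, lowered to tcgen05.ld
    T.copy(C_local, C_shared)  # registers -> shared memory
T.copy(C_shared, C[by * block_M, bx * block_N])
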
......@@ -35,11 +35,11 @@ def matmul(
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
mbar = T.alloc_barrier(1)  # the 1 here is the expect-arrive-count
mbar = T.alloc_barrier(1)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
......@@ -53,9 +53,8 @@ def matmul(
clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
if T.get_thread_binding() < 128:
T.copy(C_tmem, C_local)
T.copy(C_local, C_shared)
T.copy(C_tmem, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
......@@ -66,7 +65,7 @@
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 0
num_stages = 2
threads = 256
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
......
......@@ -10,6 +10,7 @@
*/
#include "copy.h"
#include "../layout/tcgen05_layout.h"
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
......@@ -404,6 +405,71 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, T.analyzer, T.buffer_oob);
// Handle tensor memory (tmem) layout inference
if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
// Tensor memory copy
// TODO (mzw) Add support for tcgen05.st/cp (in conj. with LowerTmemCopy)
ICHECK(copy_inst == CopyInst::kTMemLoad)
<< "Only support tensor memory copy from shared.tmem to local.fragment "
"currently";
LayoutMap results;
if (!T.layout_map.count(dst) && T.layout_map.count(src)) {
// Use the default layout (32dp32b) if not specified
// NOTE (mzw) We will check the layout in LowerTmemCopy(), so don't
// worry about tmem-incompatible layouts
Layout src_layout = T.layout_map[src];
Array<IterVar> logical_coords = MakeIterVars();
Array<PrimExpr> logical_coords_var = {logical_coords[0]->var,
logical_coords[1]->var};
Array<PrimExpr> phy_indices = src_layout->Forward(logical_coords_var);
// Tmem physical coord range analysis
auto analyzer = std::make_shared<arith::Analyzer>();
for (const auto &iv : logical_coords)
analyzer->Bind(iv->var, iv->dom);
arith::ConstIntBound phy_row_bounds =
analyzer->const_int_bound(phy_indices[0]);
arith::ConstIntBound phy_col_bounds =
analyzer->const_int_bound(phy_indices[1]);
Range row_dom = Range((int)(phy_row_bounds->min_value),
(int)(phy_row_bounds->max_value + 1));
Range col_dom = Range((int)(phy_col_bounds->min_value),
(int)(phy_col_bounds->max_value + 1));
constexpr int WARP_SIZE = 32; // Set to 32 since only sm100 is supported
constexpr int WARPGROUP_SIZE = 4 * WARP_SIZE;
ICHECK(is_const_int(T.thread_bounds->extent))
<< "Tensor memory copy requires thread_bounds->extent (num_threads) "
"to be constant integers";
int num_threads = *as_const_int(T.thread_bounds->extent);
ICHECK(num_threads % WARPGROUP_SIZE == 0)
<< "Tensor memory copy requires thread bounds to be aligned to "
"warpgroups, but found "
<< "thread range = " << T.thread_bounds;
for (int num_useful_wgs = num_threads / WARPGROUP_SIZE;
num_useful_wgs >= 1; --num_useful_wgs) {
int num_useful_threads = num_useful_wgs * WARPGROUP_SIZE;
Tcgen05Meta meta = getTcgen05Meta_32dp32b();
auto [is_success, tmem_coord2frag, num_chunks_each_wg] =
expandTcgen05Layout(
meta, phy_col_bounds->max_value - phy_col_bounds->min_value + 1,
num_useful_threads, row_dom, col_dom);
if (!is_success) {
continue;
}
Fragment logical_coord2frag =
Fragment(logical_coords, tmem_coord2frag->Forward(phy_indices),
tmem_coord2frag->ForwardThread(phy_indices, std::nullopt),
make_itervar("rep", 1));
results.Set(dst, logical_coord2frag->BindThreadRange(T.thread_bounds));
break;
}
}
return results;
}
if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {
// if can apply swizzling, we skip layout inference
// for bulk load/store, we can directly apply the layout of normal copy
......@@ -631,15 +697,46 @@ bool CopyNode::CheckSTSMCopy(Target target) const {
(dst.scope() == "shared.dyn" || dst.scope() == "shared");
}
/**
* @brief Determine whether this copy can use tensor memory load (tcgen05.ld).
*
* Returns true when the target supports tensor memory and the source buffer is
* in `shared.tmem` scope while the destination buffer is in `local.fragment`.
*
* @param target The compilation target to query for tensor memory support.
* @return true if the copy may be lowered to a tcgen05.ld instruction; false
* otherwise.
*/
bool CopyNode::CheckTMemLoad(Target target) const {
return TargetHasTmem(target) && src.scope() == "shared.tmem" &&
dst.scope() == "local.fragment";
}
/**
* @brief Determine whether this copy can use tensor memory store (tcgen05.st).
*
* Returns true when the target supports tensor memory and the source buffer is
* in `local.fragment` scope while the destination buffer is in `shared.tmem`.
*
* @param target The compilation target to query for tensor memory support.
* @return true if the copy may be lowered to a tcgen05.st instruction; false
* otherwise.
*/
bool CopyNode::CheckTMemStore(Target target) const {
return TargetHasTmem(target) && src.scope() == "local.fragment" &&
dst.scope() == "shared.tmem";
}
/**
* @brief Selects the most specific copy instruction supported for the given
* target and buffers.
*
* Determines which specialized copy lowering to use (TMA bulk load/store, LDSM,
* STSM) based on target capabilities and the memory scopes of the
* source/destination buffers. If TMA lowering is disabled via the flag,
* BulkLoad/BulkStore are not selected. The selection priority is: BulkLoad,
* BulkStore, LDSM, STSM, then Normal (fallback).
* STSM, TMem load/store) based on target capabilities and the memory scopes of
* the source/destination buffers. If TMA lowering is disabled via the flag,
* BulkLoad/BulkStore are not selected. The selection priority is: BulkLoad1D,
* BulkStore1D, BulkLoad, BulkStore, LDSM, STSM, TMemLoad, TMemStore, then
* Normal (fallback).
*
* @param target The compilation target used to query hardware capabilities.
* @param disable_tma_lower If true, prevents selecting TMA-based bulk
......@@ -654,6 +751,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
// when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
// we will not use tma for bulk load/store
// Tensor memory (tcgen05) copies for SM100/Blackwell are matched below,
// after the LDSM/STSM checks
// 1D TMA access cannot handle out-of-bounds accesses
if (!disable_tma_lower && !buffer_oob &&
CheckBulkLoad1D(target, layout_map, analyzer)) {
......@@ -669,6 +767,10 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
return CopyInst::kLDSM;
} else if (CheckSTSMCopy(target)) {
return CopyInst::kSTSM;
} else if (CheckTMemLoad(target)) {
return CopyInst::kTMemLoad;
} else if (CheckTMemStore(target)) {
return CopyInst::kTMemStore;
} else {
return CopyInst::kNormal;
}
......@@ -688,14 +790,19 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
*/
Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, analyzer);
if (copy_inst == CopyInst::kBulkLoad1D ||
copy_inst == CopyInst::kBulkStore1D) {
if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
auto tmem_copy = LowerTmemCopy(T, analyzer);
ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy";
return tmem_copy;
} else if (copy_inst == CopyInst::kBulkLoad1D ||
copy_inst == CopyInst::kBulkStore1D) {
auto bulk_copy = LowerBulkCopy1D(T, analyzer, copy_inst);
ICHECK(bulk_copy.defined()) << "Failed to lower bulk load 1d";
return bulk_copy;
......@@ -975,6 +1082,206 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
return for_node;
}
/**
* @brief Lower tensor memory copy operations (tcgen05.ld/st/cp).
*
* Handles copy operations involving shared.tmem buffers (tensor memory on
* SM100/Blackwell). Supports three types of tensor memory copies:
* - tcgen05.ld: tensor memory -> register (local.fragment)
* - tcgen05.st: register (local.fragment) -> tensor memory
* - tcgen05.cp: shared memory -> tensor memory
*
* The function validates buffer scopes, extracts 2D loop structure, performs
* layout compatibility checks, selects an appropriate TCGEN05 instruction
* variant based on data width and thread count, and emits the corresponding PTX
* intrinsic call.
*
* Currently only tcgen05.ld is fully supported; st/cp will trigger an ICHECK
* failure.
*
* @param T Lowering context (target, thread bounds, layout maps, buffer
* remaps).
* @param analyzer Arithmetic analyzer for proving bounds and simplifying
* expressions.
* @return Stmt The lowered tensor memory copy statement, or an empty Stmt if
* this copy does not involve tensor memory.
*/
Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
if (src.scope() != "shared.tmem" && dst.scope() != "shared.tmem") {
return Stmt();
}
ICHECK(TargetHasTmem(T.target)) << "Target " << T.target->ToDebugString()
<< " does not support tensor memory copy";
// Decide copy type
bool is_ld = false; // tcgen05.ld (tensor memory -> register)
bool is_st = false; // tcgen05.st (register -> tensor memory)
bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
is_ld = true;
} else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
is_st = true;
} else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
is_cp = true;
} else {
ICHECK(0) << "Unsupported tensor memory copy: "
<< "src scope = " << src.scope()
<< ", dst scope = " << dst.scope();
}
// Currently tcgen05.cp is not supported
// TODO (mzw) Support tcgen05.cp
ICHECK(!is_cp)
<< "Copy from shared memory to tensor memory is not supported yet";
// Currently tcgen05.st is not supported
// TODO (mzw) Support tcgen05.st
ICHECK(!is_st) << "Copy from register to tensor memory is not supported yet";
// Extract loop variables and ranges
Array<IterVar> loop_vars = MakeIterVars();
ICHECK(loop_vars.size() == 2) << "Only support 2D tensor memory copy, got "
<< loop_vars.size() << " dimensions";
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
ICHECK(!src_predicate.defined() && !dst_predicate.defined())
<< "Tensor memory copy does not support predicates, got " << src_predicate
<< " and " << dst_predicate;
ICHECK(is_const_int(loop_vars[0]->dom->min) &&
is_const_int(loop_vars[0]->dom->extent) &&
is_const_int(loop_vars[1]->dom->min) &&
is_const_int(loop_vars[1]->dom->extent))
<< "Tensor memory copy requires loop bounds to be constant integers";
int64_t logical_row_min = *as_const_int(loop_vars[0]->dom->min);
int64_t logical_row_extent = *as_const_int(loop_vars[0]->dom->extent);
int64_t logical_col_min = *as_const_int(loop_vars[1]->dom->min);
int64_t logical_col_extent = *as_const_int(loop_vars[1]->dom->extent);
// Extract thread bounds
constexpr int WARP_SIZE = 32; // Set to 32 since only sm100 is supported
constexpr int WARPGROUP_SIZE = 4 * WARP_SIZE;
ICHECK(is_const_int(T.thread_bounds->extent))
<< "Tensor memory copy requires thread_bounds->extent (num_threads) to "
"be constant integers";
int num_threads = *as_const_int(T.thread_bounds->extent);
ICHECK(analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, WARPGROUP_SIZE),
0) &&
num_threads % WARPGROUP_SIZE == 0)
<< "Tensor memory copy requires thread bounds to be aligned to "
"warpgroups, but found "
<< "thread range = " << T.thread_bounds;
// TODO (mzw) Buffer remap for shared.dyn when is_cp is true?
// Retrieve layout
ICHECK(T.layout_map.count(src))
<< "Source buffer " << src->name << " does not have a layout specified";
ICHECK(T.layout_map.count(dst)) << "Destination buffer " << dst->name
<< " does not have a layout specified";
Layout src_layout = T.layout_map[src];
Fragment dst_layout = Downcast<Fragment>(T.layout_map[dst]);
// Check layout
Array<PrimExpr> logical_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> phy_indices =
src_layout->Forward(logical_indices); // "phy" for "physical"
// Analyse the range of tmem_phy_row and tmem_phy_col
arith::ConstIntBound phy_row_bounds =
analyzer->const_int_bound(phy_indices[0]);
arith::ConstIntBound phy_col_bounds =
analyzer->const_int_bound(phy_indices[1]);
int tmem_phy_row_min = phy_row_bounds->min_value;
int tmem_phy_row_max = phy_row_bounds->max_value;
int tmem_phy_col_min = phy_col_bounds->min_value;
int tmem_phy_col_max = phy_col_bounds->max_value;
int tmem_phy_row_extent = tmem_phy_row_max - tmem_phy_row_min + 1;
int tmem_phy_col_extent = tmem_phy_col_max - tmem_phy_col_min + 1;
Range row_dom = Range(tmem_phy_row_min, tmem_phy_row_max + 1);
Range col_dom = Range(tmem_phy_col_min, tmem_phy_col_max + 1);
bool have_succeeded = false;
Stmt body;
auto try_tcgen05_instruction = [&](Tcgen05Meta meta) {
if (have_succeeded) {
return;
}
if (tmem_phy_row_min != 0 || tmem_phy_row_max != 127) {
return;
}
if (tmem_phy_col_min % meta.width != 0 ||
(tmem_phy_col_max + 1) % meta.width != 0) {
return;
}
for (int num_useful_wgs = num_threads / WARPGROUP_SIZE; num_useful_wgs >= 1;
num_useful_wgs--) {
int num_useful_threads = num_useful_wgs * WARPGROUP_SIZE;
auto [is_success, target_frag, num_chunks_each_wg] = expandTcgen05Layout(
meta, tmem_phy_col_extent, num_useful_threads, row_dom, col_dom);
if (!is_success) {
continue;
}
PrimExpr target_thread =
target_frag->ForwardThread(phy_indices, std::nullopt);
PrimExpr dst_thread =
dst_layout->ForwardThread(logical_indices, std::nullopt);
if (!analyzer->CanProveEqual(target_thread, dst_thread)) {
continue;
}
PrimExpr target_reg = target_frag->Forward(phy_indices)[0];
PrimExpr dst_reg = dst_layout->Forward(logical_indices)[0];
if (!analyzer->CanProveEqual(target_reg, dst_reg)) {
continue;
}
// All checks passed, we can use this instruction
PrimExpr relative_wg_idx =
FloorDiv(Sub(T.thread_var, T.thread_bounds->min), WARPGROUP_SIZE);
PrimExpr col_offset =
num_useful_threads == WARPGROUP_SIZE
? PrimExpr(0)
: relative_wg_idx * (num_chunks_each_wg * meta.width);
have_succeeded = true;
Array<PrimExpr> args;
args.push_back(StringImm(meta.intrinsics_name + "<" +
std::to_string(num_chunks_each_wg) + ">"));
args.push_back(
BufferLoad(src, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated later
// in lower_shared_tmem pass
args.push_back(col_offset);
args.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0,
PrimExpr(tmem_phy_col_extent)));
Stmt call =
Evaluate(Call(DataType::Handle(), builtin::call_extern(), args));
if (num_useful_threads != num_threads) {
body =
IfThenElse(T.thread_var < T.thread_bounds->min + num_useful_threads,
call, // No-op for unused threads
Stmt());
} else {
body = call;
}
break;
}
};
try_tcgen05_instruction(getTcgen05Meta_32dp32b());
try_tcgen05_instruction(getTcgen05Meta_32dp64b());
try_tcgen05_instruction(getTcgen05Meta_32dp128b());
try_tcgen05_instruction(getTcgen05Meta_32dp256b());
ICHECK(have_succeeded) << "Failed to find a suitable instruction for "
"tcgen05.ld. Check your layout.";
return body;
}
/**
* @brief Lower a Copy operator to a bulk TMA (Tensor Memory Accelerator)
* transfer.
......
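
As a rough standalone illustration of the physical-coordinate range analysis used above in InferLayout() and LowerTmemCopy(), the same bounds can be reproduced with TVM's Python arithmetic API. The identity logical-to-physical mapping and the 128x256 extents below are assumptions borrowed from the example kernel; the real pass forwards the logical coordinates through the source buffer's Layout first.

import tvm
from tvm import arith, tir

i = tir.Var("i", "int32")  # logical row iterator
j = tir.Var("j", "int32")  # logical column iterator

analyzer = arith.Analyzer()
analyzer.bind(i, tvm.ir.Range(0, 128))  # block_M rows
analyzer.bind(j, tvm.ir.Range(0, 256))  # block_N columns

# Identity layout for the sketch; the pass uses src_layout->Forward() instead.
phy_row, phy_col = i, j
row_bounds = analyzer.const_int_bound(phy_row)
col_bounds = analyzer.const_int_bound(phy_col)

# The lowering requires full 128-lane row coverage and a column extent that is
# a multiple of the width of the chosen tcgen05 variant (32dp32b/64b/128b/256b).
assert row_bounds.min_value == 0 and row_bounds.max_value == 127
col_extent = col_bounds.max_value - col_bounds.min_value + 1
print("tmem physical column extent:", col_extent)  # 256 in this sketch
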
......@@ -24,6 +24,8 @@ enum class CopyInst : uint8_t {
// as they have different memory access patterns
kBulkLoad1D = 5, // utilize tma load 1d
kBulkStore1D = 6, // utilize tma store 1d
kTMemLoad = 7, // tcgen05.ld (tensor memory -> register)
kTMemStore = 8, // tcgen05.st (register -> tensor memory)
};
/// Descriptor for Tensor Memory Access (TMA) copy operations
......@@ -187,6 +189,16 @@ public:
*/
bool CheckSTSMCopy(Target target) const;
/*!
* \brief Check if tensor memory load is supported.
*/
bool CheckTMemLoad(Target target) const;
/*!
* \brief Check if tensor memory store is supported.
*/
bool CheckTMemStore(Target target) const;
/*!
* \brief Get the copy instruction type.
*/
......@@ -214,6 +226,11 @@ protected:
Stmt LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const;
/*!
* \brief Generate lowering for tensor memory copy (tcgen05.ld/st/cp).
*/
Stmt LowerTmemCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
/*!
* \brief Generate lowering for normal copy.
*/
......
......@@ -128,7 +128,7 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->clear_accum = args[9].as<PrimExpr>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
......@@ -588,7 +588,10 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
auto clear_accum_bool = clear_accum.as<Bool>();
ICHECK(clear_accum_bool.has_value())
<< "clear_accum must be a constant Bool type, got " << clear_accum;
ss << ", " << bool(clear_accum_bool.value());
if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
ss << ", " << stride_A << ", " << stride_B;
ss << ", " << offset_A << ", " << offset_B;
......@@ -651,7 +654,6 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
if (TargetIsVolta(T.target)) {
ICHECK(C.scope() == "local.fragment")
<< "Volta gemm only supports C in local.fragment scope, got "
......
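
Although clear_accum is now a PrimExpr, the template-string lowering above still folds it into the C++ template arguments, so for that path it must simplify to a constant Bool (which the new ICHECK enforces); a loop-dependent flag such as k == 0 only makes sense for code paths that receive it at run time, like the tcgen5mma_gemm_ss device function further below, which takes bool clear_accum as a parameter. A small TVM-level sketch of the distinction (variable names are illustrative):

from tvm import arith, tir

analyzer = arith.Analyzer()
k = tir.Var("k", "int32")

always_clear = tir.IntImm("bool", 1)                  # constant Bool: folds into the template args
first_iter_clear = tir.EQ(k, tir.IntImm("int32", 0))  # boolean PrimExpr: stays symbolic

print(analyzer.simplify(always_clear))      # a constant true
print(analyzer.simplify(first_iter_clear))  # k == 0, would trip the constant-Bool check
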
......@@ -107,7 +107,7 @@ public:
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
bool clear_accum = false;
PrimExpr clear_accum = const_false();
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
......
......@@ -62,7 +62,7 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->clear_accum = args[9].as<PrimExpr>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
......
......@@ -26,7 +26,7 @@ public:
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
bool clear_accum = false;
PrimExpr clear_accum = const_false();
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
......
......@@ -7,7 +7,6 @@
#include "operator.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
......
......@@ -11,6 +11,7 @@
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
......
......@@ -1303,6 +1303,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto phase = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
} else if (op->op.same_as(tl::ptx_init_tensor_memory())) {
print_extern_call_stmt("tl::tmem_allocate");
} else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) {
print_extern_call_stmt("tl::tmem_deallocate");
} else if (op->op.same_as(tl::no_set_max_nreg())) {
return;
} else if (op->op.same_as(tl::tma_load())) {
......@@ -1387,7 +1391,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_grid())) {
this->need_cooperative_groups_ = true;
this->PrintIndent();
this->stream << "cooperative_groups::this_grid().sync();\n";
this->stream << "cooperative_groups::grid_group grid = "
"cooperative_groups::this_grid();\n";
this->PrintIndent();
this->stream << "grid.sync();\n";
} else if (op->op.same_as(tl::loop_break())) {
this->PrintIndent();
this->stream << "break;\n";
......
......@@ -370,13 +370,15 @@ using tl_mma::gemm_ss;
// }
template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
bool trans_B, typename C_type, typename A_type, typename B_type>
bool trans_B, typename C_type, typename A_type, typename B_type,
typename Barrier_type>
TL_DEVICE void tcgen5mma_gemm_ss(A_type *pA, B_type *pB, uint32_t accum,
uint64_t *umma_bar_ptr, bool clear_accum) {
Barrier_type *umma_bar_ptr, bool clear_accum) {
using MMA =
cute::tl_tcgen5mma::GemmTensorOp<M, N, K, AtomM, AtomN, AtomK, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_ss(pA, pB, accum, umma_bar_ptr, clear_accum);
MMA::body_ss(pA, pB, accum, reinterpret_cast<uint64_t *>(umma_bar_ptr),
clear_accum);
}
} // namespace tl
......@@ -126,8 +126,18 @@ public:
// Actually this test has been done in ParallelOp::InferLayout
// already. Just do it again to avoid missing implementations in other
// `TileOperator`s.
auto dst_layout = layout.as<Fragment>().value();
auto src_layout = layout_map[buffer].as<Fragment>().value();
auto dst_layout_opt = layout.as<Fragment>();
ICHECK(dst_layout_opt.has_value())
<< "Failed to cast layout to Fragment for buffer " << buffer
<< ", layout type is " << layout->GetTypeKey();
auto dst_layout = dst_layout_opt.value();
auto src_layout_opt = layout_map[buffer].as<Fragment>();
ICHECK(src_layout_opt.has_value())
<< "Failed to cast layout_map[buffer] to Fragment for buffer "
<< buffer << ", layout type is "
<< layout_map[buffer]->GetTypeKey();
auto src_layout = src_layout_opt.value();
ICHECK(dst_layout->InputDim() == src_layout->InputDim());
Array<PrimExpr> indices;
indices.reserve(dst_layout->InputDim());
......@@ -382,7 +392,13 @@ private:
return std::nullopt;
}
if (call->op.same_as(builtin::tvm_access_ptr())) {
auto var = call->args[1].as<Var>().value();
auto var_opt = call->args[1].as<Var>();
if (!var_opt.has_value()) {
DLOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: "
<< call->args[1]->GetTypeKey();
return std::nullopt;
}
auto var = var_opt.value();
return buffer_data_to_buffer_[var];
} else if (call->op.same_as(RegionOp::Get())) {
return call->args[0].as<BufferLoadNode>()->buffer;
......
......@@ -119,6 +119,8 @@ private:
{BufferLoad(new_buffer, {0}), PrimExpr(count)});
init_mbarrier_calls_.push_back(Evaluate(call));
}
if (init_mbarrier_calls_.empty())
return block;
Array<Stmt> new_body;
PrimExpr condition;
......@@ -127,8 +129,11 @@ private:
} else {
condition = EQ(thread_var_->var, 0);
}
new_body.push_back(
IfThenElse(condition, SeqStmt(init_mbarrier_calls_), Stmt()));
new_body.push_back(IfThenElse(condition,
init_mbarrier_calls_.size() == 1
? init_mbarrier_calls_.back()
: SeqStmt(init_mbarrier_calls_),
Stmt()));
new_body.push_back(
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
{StringImm("shared")})));
......
......@@ -90,8 +90,9 @@ private:
std::string func_name = le_pos == std::string::npos
? func_name_with_template
: func_name_with_template.substr(0, le_pos);
if (func_name == "tl::utcmma_gemm_ts" ||
func_name == "tl::utcmma_gemm_ss") {
// TODO(lei): refactor to use identical ops.
if (func_name == "tl::tcgen5mma_gemm_ts" ||
func_name == "tl::tcgen5mma_gemm_ss") {
// TCGEN5MMA
auto get_buf_from_access_ptr_call =
[&](const PrimExpr &expr) -> Buffer {
......