Unverified Commit 1774a1aa authored by Zhengju Tang, committed by GitHub

[Feature] Add 1D TMA support (#761)



* [Feature] Add 1D TMA support
- Check the contiguity conditions for the 1D TMA copy
- Add a new interface and parameter order for `tma_load` and `tma_store` calls
- Add a 1D `tma_store` interface to the sm90 template
- Add an elementwise kernel as a 1D TMA example

* [Lint]

* [BugFix] Add conditions for 1D TMA copy on non-swizzled shared tensors

* [Lint]

* [BugFix] 1D TMA load

* [README] Update GDN README for clarity and add acknowledgements (#758)

- Improved formatting and clarity of the GDN kernel implementation description.
- Updated requirement section to list dependencies in a clearer format.
- Added an acknowledgements section to credit the developers and the Xiaomi LLM-Core Team for their contributions.

* cutlass v4.2.0 supporting CUDA 13 (#760)

* [Lint]

* [Lint]

* [MXFP4] Add test for bf16&mxfp4 gemm

* [BugFix]

* [Lint]

---------
Co-authored-by: Yu Cheng <54519279+chengyupku@users.noreply.github.com>
Co-authored-by: Johnny <johnnync13@gmail.com>
parent e05a20ab
@@ -2,6 +2,7 @@ import tilelang.testing
import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper
@tilelang.testing.requires_cuda
@@ -15,5 +16,11 @@ def test_example_dequant_gemm_fp4_hopper():
    example_dequant_gemm_fp4_hopper.main()

@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_mxfp4_hopper():
    example_dequant_gemm_bf16_mxfp4_hopper.main()
if __name__ == "__main__":
    tilelang.testing.main()
import argparse

import tilelang
import tilelang.language as T
import torch


def ref_program(x, y):
    return x + y


@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

    @T.prim_func
    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype),
                 C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            # With the default 128x128 sizes a single block covers the whole
            # tensor, so these copies are contiguous and can take the new 1D
            # TMA path
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=128)
    parser.add_argument("--n", type=int, default=128)
    args, _ = parser.parse_known_args()
    M, N = args.m, args.n

    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    # Default config
    config = {"block_M": 128, "block_N": 128, "threads": 128}
    kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
    print("All passed!")


if __name__ == "__main__":
    main()
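Since the example parses `--m`/`--n` via `parse_known_args`, it can be run directly, e.g. `python example_elementwise_add_tma_1d.py --m 256 --n 256` (the file name is inferred from the test import below). Note that the 1D TMA path targets SM90/Hopper hardware, matching the sm90 template changes later in this diff.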
import tilelang.testing

import example_elementwise_add
import example_elementwise_add_tma_1d


def test_example_elementwise_add():
    example_elementwise_add.main()


def test_example_elementwise_add_tma_1d():
    example_elementwise_add_tma_1d.main()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -18,7 +18,6 @@ except ImportError:
import torch
import torch.nn.functional as F
from utils import assert_similar
torch.random.manual_seed(0)
torch.set_printoptions(profile="full")
@@ -504,6 +503,7 @@ def run_test(
    dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
        dim=-1)
    from utils import assert_similar
    assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
    assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
    assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
...
@@ -772,6 +772,18 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
    stride *= s;
  }

  Array<PrimExpr> global_indices;
  for (auto r : global_range) {
    global_indices.push_back(r->min);
  }
  std::vector<PrimExpr> global_strides;
  PrimExpr global_stride = 1;
  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
    global_strides.insert(global_strides.begin(), global_stride);
    global_stride *= s;
  }

  ICHECK(strides.size() == indices.size())
      << "strides.size() != indices.size()" << strides.size() << " "
      << indices.size();
@@ -779,12 +791,114 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
  for (size_t i = 0; i < indices.size(); i++) {
    offset += indices[i] * strides[i];
  }
  PrimExpr global_offset = 0;
  for (size_t i = 0; i < global_indices.size(); i++) {
    global_offset += global_indices[i] * global_strides[i];
  }

  auto shared_tensor_before_remap = shared_tensor;
  Layout shared_layout;
  if (T.layout_map.count(shared_tensor)) {
    shared_layout = T.layout_map[shared_tensor];
    shared_tensor = T.buffer_remap[shared_tensor];
  }
  // Use a 1D TMA copy when both the global and shared accesses are contiguous
  {
    // Check whether shared_tensor->name appears in T.buffer_var_gemm
    // (Array<Var>) to avoid using the 1D TMA copy on swizzled layouts
    bool shared_is_contiguous = true;
    for (const auto &v : T.buffer_var_gemm) {
      if (v->name_hint == shared_tensor->name) {
        shared_is_contiguous = false;
        break;
      }
    }
    // Scan dimensions from innermost to outermost: each must be fully covered
    // until the first partial one, after which every remaining extent must be 1
    bool shared_not_full_dim_encounter = false;
    for (ssize_t i = shared_range.size() - 1; i >= 0; --i) {
      if (!shared_not_full_dim_encounter) {
        if (!analyzer->CanProve(shared_range[i]->extent ==
                                    shared_tensor_before_remap->shape[i] &&
                                shared_range[i]->min == 0)) {
          shared_not_full_dim_encounter = true;
        }
      } else {
        if (!analyzer->CanProve(shared_range[i]->extent == 1)) {
          shared_is_contiguous = false;
          break;
        }
      }
    }
    // The global tensor must carry no explicit strides: an empty strides
    // array denotes a compact (contiguous) row-major layout
    bool global_is_contiguous = global_tensor->strides.empty();
    bool global_not_full_dim_encounter = false;
    for (ssize_t i = global_range.size() - 1; i >= 0; --i) {
      if (!global_not_full_dim_encounter) {
        if (!analyzer->CanProve(global_range[i]->extent ==
                                    global_tensor->shape[i] &&
                                global_range[i]->min == 0)) {
          global_not_full_dim_encounter = true;
        }
      } else {
        if (!analyzer->CanProve(global_range[i]->extent == 1)) {
          global_is_contiguous = false;
          break;
        }
      }
    }
    // Ensure the element counts match and no access is out of bounds
    PrimExpr shared_elements = 1;
    for (size_t i = 0; i < shared_range.size(); i++) {
      shared_elements *= shared_range[i]->extent;
    }
    PrimExpr global_elements = 1;
    for (size_t i = 0; i < global_range.size(); i++) {
      global_elements *= global_range[i]->extent;
    }
    bool element_match =
        analyzer->CanProveEqual(shared_elements, global_elements);

    bool no_oob = true;
    for (size_t i = 0; i < shared_range.size(); i++) {
      if (!analyzer->CanProve(shared_range[i]->min + shared_range[i]->extent <=
                              shared_tensor_before_remap->shape[i])) {
        no_oob = false;
        break;
      }
    }
    for (size_t i = 0; i < global_range.size(); i++) {
      if (!analyzer->CanProve(global_range[i]->min + global_range[i]->extent <=
                              global_tensor->shape[i])) {
        no_oob = false;
        break;
      }
    }

    // Emit the 1D TMA copy only for loads for now
    if (shared_is_contiguous && global_is_contiguous && element_match &&
        no_oob && is_load) {
      PrimExpr elements = analyzer->Simplify(shared_elements);
      PrimExpr shared_addr = shared_tensor_before_remap.access_ptr(
          is_load ? 2 : 1, DataType::Handle(), 1, offset, elements);
      PrimExpr global_addr = global_tensor.access_ptr(
          is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
      Stmt tma_copy;
      if (is_load) {
        // The zero is a placeholder for the mbarrier id
        tma_copy =
            Evaluate(Call(DataType::Handle(), tma_load(),
                          {shared_addr, global_addr, 0,
                           elements * shared_tensor_before_remap->dtype.bytes(),
                           this->eviction_policy}));
      } else {
        tma_copy =
            Evaluate(Call(DataType::Handle(), tma_store(),
                          {global_addr, shared_addr,
                           elements * shared_tensor_before_remap->dtype.bytes(),
                           this->eviction_policy}));
      }
      // Only one thread per block issues the bulk copy
      tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
      return tma_copy;
    }
  }
  TMADesc desc;
  // Verify copy rank
  desc.rank = global_tensor->shape.size();
@@ -1221,10 +1335,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
//   eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
...
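The contiguity checks above follow a simple rule that may be easier to read in Python. The sketch below is a hand-written model of the C++ loops, not code from the patch (names are illustrative): a multi-dimensional slice is linear in memory iff, scanning dimensions from innermost to outermost, every dimension is fully covered until the first partial one, after which all remaining extents must be 1.

def is_contiguous_slice(mins, extents, shape):
    # Scan from innermost (last) dimension outward
    seen_partial = False
    for mn, ext, dim in zip(reversed(mins), reversed(extents), reversed(shape)):
        if not seen_partial:
            if not (mn == 0 and ext == dim):
                seen_partial = True
        elif ext != 1:
            return False
    return True

assert is_contiguous_slice((0, 0), (128, 128), (128, 128))     # full tile
assert is_contiguous_slice((2, 0), (2, 128), (128, 128))       # a run of full rows
assert not is_contiguous_slice((0, 0), (128, 64), (128, 128))  # column block: strided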
@@ -49,6 +49,7 @@ struct LowerArgs {
  AddWorkspaceCallback AddWorkspace;
  LayoutMap layout_map;
  Map<Buffer, Buffer> buffer_remap;
  Array<Var> buffer_var_gemm;
};
struct LayoutInferArgs {
...
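The new `buffer_var_gemm` field is how the list collected by `BufferGemmCollector` (further down, in the lower-tile-op pass) reaches `Copy::LowerBulkCopy`: the pass threads it through `LowerArgs`, and the copy lowering checks each candidate shared buffer against it before choosing the 1D path.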
@@ -171,6 +171,16 @@ tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar,
      : "memory");
}

template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // Descriptor-free 1D bulk store: copies `size` bytes from shared to global
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
               ".L2::cache_hint [%0], [%1], %2, %3;"
               :
               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
               : "memory");
}

template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0) {
...
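The new overload is the descriptor-free 1D bulk store. One detail worth keeping in mind: the PTX `cp.async.bulk` instruction requires the byte count to be a multiple of 16 (and 16-byte-aligned addresses). The patch itself does not add such a check, so the sketch below is an assumption about intended usage rather than code from the change:

def bulk_copy_size_ok(num_elements: int, dtype_bytes: int) -> bool:
    # cp.async.bulk transfers must be a multiple of 16 bytes (PTX requirement)
    return (num_elements * dtype_bytes) % 16 == 0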
@@ -62,10 +62,17 @@ public:
private:
  void VisitExpr_(const CallNode *call) final {
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
      auto arg0 = call->args[0].as<Call>();
      if (call->op.same_as(tma_load()) && arg0 &&
          !arg0.value()->op.same_as(create_tma_descriptor())) {
        // A 1D TMA load carries the tvm_access_ptr of the shared tensor in
        // its args[0] instead of a TMA descriptor; args[3] is already a byte
        // count, so no element-size scaling is needed
        bulk_copy_bytes += call->args[3] * loop_extents;
      } else {
        Call access_ptr = Downcast<Call>(call->args[2]);
        ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
        int type_bytes = access_ptr->args[0]->dtype.bytes();
        bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
      }
    }
    StmtExprVisitor::VisitExpr_(call);
  }
@@ -155,10 +162,15 @@ private:
  PrimExpr VisitExpr_(const CallNode *op) {
    if (op->op.same_as(tma_load())) {
      auto arg0 = op->args[0].as<Call>();
      bool is_1d_tma_load =
          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
          op->op.same_as(tma_load());
      visited_tma_load_ = true;
      Array<PrimExpr> new_args = op->args;
      new_args.Set(is_1d_tma_load ? 2 : 1,
                   Call(DataType::Handle(), get_mbarrier(),
                        {IntImm(DataType::Int(32), 0)}));
      return Call(op->dtype, op->op, new_args);
    }
    return IRMutatorWithAnalyzer::VisitExpr_(op);
@@ -443,7 +455,14 @@ private:
        << "tma_load must be in the tma_op_to_barrier_id_";
    auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
    auto new_args = op->args;
    auto arg0 = op->args[0].as<Call>();
    auto is_1d_tma_load =
        arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
    if (is_1d_tma_load) {
      new_args.Set(2, barrier_id);
    } else {
      new_args.Set(1, barrier_id);
    }
    return Call(op->dtype, op->op, new_args);
  } else if (op->op.same_as(mbarrier_expect_tx())) {
    ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
...
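The index juggling above exists because the two `tma_load` forms place the mbarrier differently. A sketch of the assumed argument layouts, inferred from this patch rather than an official signature:

# Descriptor-based (multi-dim) load: the mbarrier sits at args[1]
#   tma_load(descriptor, mbar, smem_access_ptr, coord0, coord1, ..., eviction_policy)
# Descriptor-free (1D) load: args[0] is the shared-memory access_ptr,
# so the mbarrier placeholder sits at args[2]
#   tma_load(smem_addr, gmem_addr, mbar, size_bytes, eviction_policy)
def mbarrier_arg_index(is_1d_tma_load: bool) -> int:
    return 2 if is_1d_tma_load else 1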
@@ -12,6 +12,7 @@
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "../op/gemm.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
@@ -71,6 +72,51 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout,
                      buffer->buffer_type);
}
class BufferGemmCollector : public StmtExprVisitor {
public:
  BufferGemmCollector() { Clear(); }
  void Clear() { buffer_var_gemm_.clear(); }
  void Collect(Stmt stmt) { VisitStmt(stmt); }
  Array<Var> GetBufferVarGemm() { return buffer_var_gemm_; }

private:
  void VisitStmt_(const EvaluateNode *op) final {
    // Evaluate nodes may wrap arbitrary expressions (e.g. Evaluate(0)), so
    // guard the cast instead of unconditionally downcasting to Call
    const CallNode *call = op->value.as<CallNode>();
    if (call && (call->op.same_as(Op::Get("tl.gemm")) ||
                 call->op.same_as(Op::Get("tl.gemm_sp")))) {
      // Both ops share the same leading argument layout: the first three
      // args are tvm_access_ptr calls for the A operand, the B operand, and
      // the accumulator; record their buffer vars
      for (size_t i = 0; i < 3; ++i) {
        auto access_ptr = Downcast<Call>(call->args[i]);
        ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
        buffer_var_gemm_.push_back(Downcast<Var>(access_ptr->args[1]));
      }
    }
  }

  Array<Var> buffer_var_gemm_;
};
/*!
 * \brief A class that rewrites buffer references in a statement based on a
 * given buffer remapping.
@@ -171,6 +217,11 @@ public:
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
    substituter.target_ = target.value();
    // For 1D TMA, collect the buffers that feed GEMM ops: these receive
    // swizzled layouts and must not take the 1D copy path
    BufferGemmCollector collector;
    collector.Collect(f->body);
    substituter.buffer_var_gemm_ = collector.GetBufferVarGemm();
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    fptr->body =
@@ -415,7 +466,9 @@ private:
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    const CallNode *call = op->value.as<CallNode>();
    // Do not analyze calls to global functions.
    if (call && call->op.as<GlobalVarNode>())
      return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
@@ -444,10 +497,10 @@ private:
      thread_bounds = Range::FromMinExtent(0, 1);
    }

    auto lowered = tile_op->Lower(
        LowerArgs{target_, thread_bounds, thread_var_->var, callback,
                  layout_map_, buffer_remap_, buffer_var_gemm_},
        analyzer_);
    return IRMutatorWithAnalyzer::VisitStmt(lowered);
  }
@@ -481,6 +534,7 @@ private:
  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
  Map<Var, Var> var_remap_;
  bool has_tma_{false};
  Array<Var> buffer_var_gemm_;
};

namespace transform {
...
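A short model of why the collector exists: buffers that feed `tl.gemm`/`tl.gemm_sp` usually receive swizzled shared-memory layouts, which break the single-linear-address assumption of the 1D bulk copy, so copy lowering consults the collected set before taking the 1D path. A minimal sketch, a simplified stand-in for the IR walk (names are illustrative):

def collect_gemm_buffer_names(evaluated_calls):
    # evaluated_calls: iterable of (op_name, arg_buffer_names) pairs, standing
    # in for visiting Evaluate(Call) nodes in the IR
    names = set()
    for op_name, arg_buffer_names in evaluated_calls:
        if op_name in ("tl.gemm", "tl.gemm_sp"):
            names.update(arg_buffer_names[:3])  # A, B, and the accumulator
    return names

def may_use_1d_tma(shared_buffer_name, gemm_buffer_names):
    # Swizzled (gemm-fed) buffers must keep the multi-dimensional TMA path
    return shared_buffer_name not in gemm_buffer_names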
@@ -321,9 +321,19 @@ private:
  PrimExpr VisitExpr_(const CallNode *op) final {
    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
      auto mbar = makeGetBarrier(producer_barrier_idx_);
      auto arg0 = call->args[0].as<Call>();
      // Check if this is a 1D TMA load
      auto is_1d_tma_load =
          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
          call->op.same_as(tma_load());
      if (is_1d_tma_load) {
        call.CopyOnWrite()->args.Set(2, mbar);
      } else {
        Call access_ptr = Downcast<Call>(call->args[2]);
        ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
        call.CopyOnWrite()->args.Set(1, mbar);
      }
    }
    return call;
  }
...