"...composable_kernel.git" did not exist on "2f463a94067f96519a083539679a5d187ca0563f"
Unverified Commit ae9b7063 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Feature] Add ptx_cp_async_barrier_noinc intrinsic and related functionality (#809)

- Introduced a new intrinsic `ptx_cp_async_barrier_noinc` for handling the `cp.async.mbarrier.arrive.noinc` operation in TileLang.
- Updated the CUDA code generation to support the new barrier operation.
- Added a corresponding function in the TileLang Python API for ease of use.
- Enhanced the barrier handling in CUDA templates to include the new no-increment operation, improving synchronization capabilities in parallel execution contexts.
parent 5e529522
......@@ -90,6 +90,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Register the tl.ptx_cp_async_barrier_noinc intrinsic: the no-increment
// variant of the cp.async mbarrier arrive operation.
// num_inputs = 1 (the mbarrier operand); kOpaque prevents passes from
// reordering or eliminating this synchronization call.
TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -177,6 +177,14 @@ TVM_DLL const Op &ptx_ldmatrix();
*/
TVM_DLL const Op &ptx_stmatrix();
/*!
* \brief tvm intrinsic for ptx async copy barrier using
* cp.async.mbarrier.arrive.noinc
*
* This op is used to represent a ptx async copy barrier operation in tilelang.
*/
TVM_DLL const Op &ptx_cp_async_barrier_noinc();
/*!
* \brief Pack two b16 value into a b32 value
*
......
......@@ -1066,6 +1066,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
}
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
......
......@@ -113,6 +113,22 @@ TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) {
: "r"(smem_int_mbar));
}
// Signal completion of prior cp.async operations to an mbarrier WITHOUT
// incrementing its pending arrival count (PTX
// cp.async.mbarrier.arrive.noinc.shared::cta.b64).
// BarrierType may be the 64-bit barrier word itself (passed by reference)
// or a pointer to it; both forms are converted to a 32-bit shared-memory
// address for the inline asm operand.
template <typename BarrierType = uint64_t>
TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar) {
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
// Argument is already a pointer to the barrier word.
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
// Argument is the barrier word itself; take its address.
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
asm volatile("{\n\t"
"cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"
"}"
:
: "r"(smem_int_mbar));
// NOTE(review): presumably mirrors the synclog emission of the
// incrementing variant above — confirm the cutlass synclog hook is
// available in every TU that includes this header.
cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
}
// Issue a CTA-scoped proxy fence (PTX fence.proxy.async.shared::cta),
// ordering generic-proxy shared-memory accesses with respect to the
// async proxy.
TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
......
......@@ -2,17 +2,11 @@
* \file annotate_warp_group_reg_alloc.cc
* \brief Annotate warp group reg alloc for warp specialization
*/
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "warp_specialized_rewriter.h"
#include <unordered_set>
#include <utility>
#include <vector>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
......@@ -57,6 +51,11 @@ private:
class SetMaxNRegInjector : public StmtExprMutator {
public:
static PrimFunc Inject(PrimFunc f) {
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
if (warp_specialized) {
// Should handle set_max_nreg when using hand-written warp specialized
return f;
}
auto T = SetMaxNRegInjector();
T.nreg_ = SetMaxNRegCollector::Collect(f);
f.CopyOnWrite()->body = T(f->body);
......
......@@ -3,21 +3,7 @@
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h"
#include "./common/collector.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "warp_specialized_rewriter.h"
namespace tvm {
namespace tl {
......@@ -1284,73 +1270,6 @@ private:
bool disable_shuffle_elect_ = false;
};
// Scans a statement tree to decide whether automatic warp specialization
// (AWS) must be disabled: either the kernel already enables manual warp
// specialization, or it mixes TMA operations with explicit mbarrier
// operations (which AWS cannot safely rewrite around).
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
// return true means this aws will be disabled
// NOTE(review): `skip_thread_partition` is unused in this body — confirm
// whether it is reserved for future use or dead.
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
if (detector.has_warp_specialization_) {
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
"specialization is manually enabled";
return true;
}
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
"and mbarrier are both present";
return true;
}
return false;
}
// Explicitly reset flags (redundant with the in-class initializers below,
// kept byte-identical here).
WarpSpecializedDetector() {
has_tma_op_ = false;
has_mbarrier_op_ = false;
has_warp_specialization_ = false;
}
private:
// Flags any explicit mbarrier-related intrinsic evaluated as a statement.
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier())) {
has_mbarrier_op_ = true;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
// Flags TMA loads and the register-reallocation intrinsic.
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
op->op.same_as(set_max_nreg())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
// Detects the manual "warp_specialize" attribute and records the
// threadIdx.x iter var from thread_extent attributes.
void VisitStmt_(const AttrStmtNode *op) final {
// NOTE(review): as<IntImmNode>() is dereferenced unchecked — this
// crashes if the attribute value is not an IntImm; verify upstream
// guarantees or add a null check.
if (op->attr_key == "warp_specialize" &&
op->value.as<IntImmNode>()->value == 1) {
has_warp_specialization_ = true;
}
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
bool has_tma_op_{false};
IterVar thread_var_;
bool has_mbarrier_op_{false};
bool has_warp_specialization_{false};
};
using namespace tir::transform;
tvm::transform::Pass WarpSpecialized() {
......
/*!
* \file warp_specialized_rewriter.h
* \brief tools for warp-specialized-related analysis and transformation
*/
#pragma once
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h"
#include "./common/collector.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace runtime;
using arith::IRVisitorWithAnalyzer;
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
// return true means this aws will be disabled
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
if (detector.has_warp_specialization_) {
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
"specialization is manually enabled";
return true;
}
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
"and mbarrier are both present";
return true;
}
return false;
}
WarpSpecializedDetector() {
has_tma_op_ = false;
has_mbarrier_op_ = false;
has_warp_specialization_ = false;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier())) {
has_mbarrier_op_ = true;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
op->op.same_as(set_max_nreg())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "warp_specialize" &&
op->value.as<IntImmNode>()->value == 1) {
has_warp_specialization_ = true;
}
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
bool has_tma_op_{false};
IterVar thread_var_;
bool has_mbarrier_op_{false};
bool has_warp_specialization_{false};
};
} // namespace tl
} // namespace tvm
......@@ -350,3 +350,9 @@ def sync_grid():
"""Synchronize all threads in a grid.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.

    Emits the ``tl.ptx_cp_async_barrier_noinc`` intrinsic, which signals
    completion of prior ``cp.async`` operations to the given mbarrier
    without incrementing the barrier's arrival count.

    Parameters
    ----------
    barrier_id : Union[int, PrimExpr, tir.Call]
        The mbarrier to arrive on, given as a barrier index or handle
        expression.

    Returns
    -------
    tir.Call
        The handle-typed intrinsic call expression.
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment