"src/vscode:/vscode.git/clone" did not exist on "2c9b8c2432ffe2eceba32d07ce8b0e467dd4538e"
Unverified Commit 6664d170 authored by Wenhao Xie, committed by GitHub

[Enhancement] Add eviction policy support for TMA operations, enhance CUDA codegen, and introduce new pass config (#690)

* Enhance TMA and barrier handling in CUDA code generation

- Updated `CodeGenTileLangCUDA` to support eviction policies for TMA operations, giving kernels explicit control over L2 cache residency.
- Introduced a new `CacheHintSm90` enum to define eviction strategies in `copy_sm90.h`.
- Modified TMA load/store functions to accept eviction policies, so cache behavior can be tuned on architectures that support the hints (SM90 and later).
- Enhanced `TmaBarrierCollector` and `TmaBarrierRewriter` to account for SIMT copies, ensuring correct barrier insertion.
- Refactored thread synchronization logic to utilize barrier IDs, improving the efficiency of partial thread synchronization.
- Updated Python interface for `copy` and `c2d_im2col` to include optional eviction policy parameters, enhancing usability (see the usage sketch below).
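
A minimal sketch of the new Python-side knob (the kernel scaffolding, shapes, and buffer annotations here are illustrative, not taken from this PR; omitting `eviction_policy` keeps the default `"evict_normal"`):

```python
import tilelang.language as T

# Illustrative fragment: stream tiles of A through L2 ("evict_first") while
# hinting that tiles of B should stay resident ("evict_last").
@T.prim_func
def staged_copy(A: T.Buffer((1024, 1024), "float16"),
                B: T.Buffer((1024, 1024), "float16")):
    with T.Kernel(8, 8, threads=128) as (bx, by):
        A_shared = T.alloc_shared((128, 128), "float16")
        B_shared = T.alloc_shared((128, 128), "float16")
        T.copy(A[bx * 128, by * 128], A_shared, eviction_policy="evict_first")
        T.copy(B[bx * 128, by * 128], B_shared, eviction_policy="evict_last")
```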

* update shuffle and elect optimization

* fix bug

* fix bug

* fix potential bug

* lint fix

* lint fix

* update shuffle_elect template

* fix bug

* fix bug

* fix template

* lint and fix

* fix typo
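
The new `tl.disable_shuffle_elect` option is a standard TVM pass config (read by `LowerHopperIntrin` and `WarpSpecialized` via `ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false))`), so it can be toggled from Python through the usual `PassContext`; a sketch, where `build_fn` stands in for whatever compile entry point your setup exposes:

```python
import tvm

# Opt out of the shuffle-elect optimization: lowering performed inside this
# context falls back to plain `threadIdx.x`-equality guards instead of
# `tl::tl_shuffle_elect<...>()`.
with tvm.transform.PassContext(config={"tl.disable_shuffle_elect": True}):
    mod = build_fn(program)  # hypothetical entry point, not part of this PR
```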
parent fe70549f
......@@ -27,6 +27,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
......@@ -88,7 +89,7 @@ TIR_DEFINE_TL_BUILTIN(ptx_stmatirx)
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(sync_thread_partial)
.set_num_inputs(1)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......@@ -140,5 +141,10 @@ TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
} // namespace tl
} // namespace tvm
......@@ -32,7 +32,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
"tl.ptxas_register_usage_level";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
* \brief Whether to disable dynamic tail split
*
......@@ -294,6 +294,13 @@ TVM_DLL const Op &tl_gemm();
*/
TVM_DLL const Op &tl_gemm_sp();
/*!
* \brief tilelang intrinsic for shuffle elect.
*
* This op is used to represent a shuffle elect operation in tilelang.
*/
TVM_DLL const Op &tl_shuffle_elect();
} // namespace tl
} // namespace tvm
......
......@@ -297,7 +297,7 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());
Array<PrimExpr> args;
args.reserve(desc.rank + 3);
args.reserve(desc.rank + 4);
args.push_back(create_descriptor);
if (is_load)
args.push_back(0); // mbarrier id placeholder
......@@ -319,6 +319,7 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords)
args.push_back(coord);
args.push_back(this->eviction_policy);
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
} else {
......@@ -327,6 +328,7 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
args.push_back(shared_addr);
for (auto coord : global_coords)
args.push_back(coord);
args.push_back(this->eviction_policy);
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
......@@ -368,6 +370,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
stride = args[5].as<IntImm>().value()->value;
dilation = args[6].as<IntImm>().value()->value;
padding = args[7].as<IntImm>().value()->value;
eviction_policy = args[8].as<IntImm>().value()->value;
}
Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
......@@ -477,7 +480,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 1);
args.reserve(desc.rank * 2 + 2);
args.push_back(create_desc);
args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
......@@ -487,7 +490,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
args.push_back(coord);
for (auto offset : image_offset)
args.push_back(offset);
args.push_back(this->eviction_policy);
Stmt tma_copy =
IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
......@@ -522,7 +525,7 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
}
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(8)
.set_num_inputs(9)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -57,7 +57,7 @@ public:
private:
Buffer src, dst;
int stride, padding, dilation, kernel;
int stride, padding, dilation, kernel, eviction_policy;
PrimExpr nhw_step, c_step;
};
......
......@@ -45,6 +45,9 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
auto disable_tma = Downcast<Bool>(args[3]);
this->disable_tma = disable_tma;
}
if (args.size() >= 5) {
this->eviction_policy = args[4].as<IntImmNode>()->value;
}
}
Array<IterVar> Copy::MakeIterVars() const {
......@@ -477,7 +480,7 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(3)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -58,6 +58,8 @@ protected:
Bool disable_tma = Bool(false);
std::unique_ptr<ParallelOp> par_op_;
int eviction_policy;
};
class Fill : public Operator {
......
......@@ -926,13 +926,14 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
}
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
size_t end = 0) {
// Cache context into a private ss, otherwise the let node may generate
// within the function call arguments.
std::ostringstream ss;
for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset)
for (size_t i = start; i < op->args.size() - end; i++) {
if (i > start)
ss << ", ";
ss << this->PrintExpr(op->args[i]);
}
......@@ -990,13 +991,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
print_extern_call_stmt("tl::mbarrier_wait");
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial");
print_extern_call_stmt("cutlass::arch::NamedBarrier::sync");
} else if (op->op.same_as(tl::no_set_max_nreg())) {
return;
} else if (op->op.same_as(tl::tma_load())) {
std::ostringstream ss;
ICHECK_GE(op->args.size(), 2);
ss << "tl::tma_load(";
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
auto desc = op->args[0];
ss << this->PrintExpr(desc) << ", ";
if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
......@@ -1004,7 +1008,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else {
ss << this->PrintExpr(op->args[1]) << ", ";
}
for (size_t i = 2; i < op->args.size(); i++) {
for (size_t i = 2; i < op->args.size() - 1; i++) {
if (i > 2)
ss << ", ";
ss << this->PrintExpr(op->args[i]);
......@@ -1013,9 +1017,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->PrintIndent();
this->stream << ss.str();
} else if (op->op.same_as(tl::tma_load_im2col())) {
print_extern_call_stmt("tl::tma_load_im2col");
std::stringstream ss;
ss << "tl::tma_load_im2col<tl::CacheHintSm90::"
<< this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value]
<< ">";
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::tma_store())) {
print_extern_call_stmt("tl::tma_store");
std::stringstream ss;
ss << "tl::tma_store<tl::CacheHintSm90::"
<< this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value]
<< ">";
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::ptx_ldmatirx())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
......@@ -1537,6 +1551,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else {
CodeGenC::VisitExpr_(op, os);
}
......
......@@ -126,6 +126,8 @@ private:
int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
int32_t size);
std::vector<std::string> eviction_policy_names_ = {
"EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};
std::unordered_set<std::string> bf16_supported_ops_ = {
"bf1622float2", "bf1622int16", "float22bf162", "bf162bf162"};
};
......
......@@ -241,4 +241,13 @@ TL_DEVICE void __sync_thread_partial() {
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count));
}
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
if constexpr (thread_extent == 0) {
return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
}
return __shfl_sync(0xffffffff, (threadIdx.x / 32) % (thread_extent / 32),
0) == 0 &&
cute::elect_one_sync();
}
} // namespace tl
......@@ -122,8 +122,11 @@ private:
Stmt VisitStmt_(const IfThenElseNode *op) {
// Check if this is the TMA block
const EQNode *eq = op->condition.as<EQNode>();
if (eq != nullptr) {
bool flag = false;
if (op->condition.as<CallNode>()) {
flag = op->condition.as<CallNode>()->op.same_as(tl_shuffle_elect());
}
if (op->condition.as<EQNode>() || flag) {
Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op);
if (visited_tma_load_) {
......@@ -164,6 +167,9 @@ private:
class TmaBarrierCollector : public IRVisitorWithAnalyzer {
public:
TmaBarrierCollector(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id() {
return tma_op_to_barrier_id_;
}
......@@ -222,7 +228,128 @@ private:
std::vector<Call> pending_tma_ops_;
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
Map<PrimExpr, IntImm> barrier_id_to_range_;
Map<Var, Buffer> buffer_data_to_buffer_;
};
class TmaSequenceCollector : public IRVisitorWithAnalyzer {
public:
TmaSequenceCollector(Map<ObjectRef, PrimExpr> tma_op_to_barrier_id)
: tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)) {}
std::vector<bool> GetSequence() {
std::vector<bool> clear_zero_list(expect_tx_count_, false);
int zero_idx = -1;
int zero_count = 0;
for (auto v : sequence) {
if (v == 0) {
zero_count += 1;
zero_idx += 1;
} else {
if (zero_count == 1) {
clear_zero_list[zero_idx] = expect_[zero_idx] && !has_simt_copy_;
if (clear_zero_list[zero_idx] == false) {
int begin = int_sets_[zero_idx].min().as<IntImmNode>()->value;
int end = int_sets_[zero_idx].max().as<IntImmNode>()->value;
for (int i = begin; i <= end; ++i) {
restore_barrier_ids_.push_back(i);
}
}
} else {
for (int i{zero_idx}; i > zero_idx - zero_count; --i) {
int begin = int_sets_[i].min().as<IntImmNode>()->value;
int end = int_sets_[i].max().as<IntImmNode>()->value;
for (int i = begin; i <= end; ++i) {
restore_barrier_ids_.push_back(i);
}
}
}
zero_count = 0;
}
}
return clear_zero_list;
}
std::vector<int> GetRestoreBarrierIds() { return restore_barrier_ids_; }
void VisitStmt_(const ForNode *op) final {
var_int_set_.Set(op->loop_var,
arith::IntSet::FromMinExtent(op->min, op->extent));
IRVisitorWithAnalyzer::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e =
tma_op_to_barrier_id_[GetRef<Call>(op)].as<CallNode>()->args[0];
auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1);
sequence.push_back(0);
int_sets_.push_back(int_set);
expect_tx_count_ += 1;
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
sequence.push_back(1);
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
has_simt_copy_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const IfThenElseNode *op) final {
if_depth_ += 1;
IRVisitorWithAnalyzer::VisitStmt(op->then_case);
if (op->else_case) {
IRVisitorWithAnalyzer::VisitStmt(op->else_case.value());
}
if_depth_ -= 1;
}
std::vector<int> sequence;
int expect_tx_count_{0};
std::vector<bool> expect_;
bool has_simt_copy_{false};
std::vector<int> restore_barrier_ids_;
int if_depth_{0};
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
arith::Analyzer *analyzer_;
Map<Var, arith::IntSet> var_int_set_;
std::vector<arith::IntSet> int_sets_;
};
class BarrierCreationRewriter : public StmtExprMutator {
public:
BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
PrimExpr producer_thread_extent)
: restore_barrier_ids_(std::move(restore_barrier_ids)),
producer_thread_extent_(producer_thread_extent) {}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(create_list_of_mbarrier())) {
std::vector<bool> tmp_(op->args.size(), false);
Array<PrimExpr> new_args;
for (auto &id : restore_barrier_ids_) {
tmp_[id] = true;
}
for (size_t i{0}; i < op->args.size(); ++i) {
if (tmp_[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(op->args[i]);
}
}
return Call(op->dtype, op->op, new_args);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
std::vector<int> restore_barrier_ids_;
PrimExpr producer_thread_extent_;
};
// we trust mbarrier_wait_parity to be correct
class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
public:
......@@ -236,8 +363,12 @@ public:
has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {}
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
auto buffer_lca = DetectBufferAccessLCA(f);
Map<Var, Buffer> buffer_data_to_buffer_;
for (auto [buffer, _] : buffer_lca)
buffer_data_to_buffer_.Set(buffer->data, buffer);
f = TmaExpectTxRewriter::Rewrite(f, analyzer);
TmaBarrierCollector collector;
TmaBarrierCollector collector(buffer_data_to_buffer_);
collector(f->body);
bool has_create_list_of_mbarrier = false;
PostOrderVisit(f->body, [&](const ObjectRef &node) {
......@@ -253,6 +384,9 @@ public:
collector.barrier_id_to_range(),
has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body);
auto barrier_creation_rewriter = BarrierCreationRewriter(
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_);
f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
return f;
}
......@@ -266,6 +400,42 @@ private:
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode *op) {
if (first_if) {
if (op->condition.as<GENode>()) {
producer_thread_extent_ =
thread_var_->dom->extent - op->condition.as<GENode>()->b;
}
TmaSequenceCollector collector(tma_op_to_barrier_id_);
collector(op->then_case);
clear_expect_list_ = collector.GetSequence();
restore_barrier_ids_ = collector.GetRestoreBarrierIds();
first_if = false;
is_producer_ = true;
auto then_case = StmtExprMutator::VisitStmt(op->then_case);
is_producer_ = false;
Stmt else_case;
if (op->else_case.defined())
else_case = StmtExprMutator::VisitStmt(op->else_case.value());
return IfThenElse(op->condition, then_case, else_case);
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "kWarpSpecializationScope") {
has_warp_specialization_ = true;
first_if = true;
} else if (op->attr_key == tir::attr::thread_extent &&
Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
thread_var_ = Downcast<IterVar>(op->node);
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
// check this must be in the tma_op_to_barrier_id_
......@@ -281,6 +451,22 @@ private:
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(0, barrier_id);
if (!has_warp_specialization_)
clear_arrive_ = false;
else
clear_arrive_ = clear_expect_list_[cur_expect_idx_++];
if (clear_arrive_) {
return Call(op->dtype, builtin::ptx_arrive_barrier_expect_tx(),
new_args);
}
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
if (clear_arrive_) {
clear_arrive_ = false;
return 0;
}
// by default, all threads must wait.
auto new_args = op->args;
return Call(op->dtype, op->op, new_args);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
......@@ -288,6 +474,13 @@ private:
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
Map<PrimExpr, IntImm> barrier_id_to_range_;
bool has_create_list_of_mbarrier_;
bool clear_arrive_{false};
bool first_if{false}, has_warp_specialization_{false}, is_producer_{false};
IterVar thread_var_;
int tma_expect_tx_{0}, cur_expect_idx_{0};
std::vector<bool> clear_expect_list_;
std::vector<int> restore_barrier_ids_;
PrimExpr producer_thread_extent_;
};
tvm::transform::Pass InjectTmaBarrier() {
......
......@@ -6,6 +6,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
......@@ -21,9 +22,9 @@ using namespace tir;
#if (CUDA_MAJOR_VERSION >= 12)
class LowerHopperIntrin : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc &f) {
static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) {
PrimFuncNode *fptr = f.CopyOnWrite();
LowerHopperIntrin substituter;
LowerHopperIntrin substituter(disable_shuffle_elect);
fptr->body = substituter.VisitStmt(f->body);
Map<String, Array<PrimExpr>> init_desc_arg_map;
for (auto [call, var] : substituter.desc_map_) {
......@@ -73,10 +74,15 @@ public:
auto stmts = prefetch_calls_;
stmts.insert(stmts.end(), init_mbarrier_calls_.begin(),
init_mbarrier_calls_.end());
auto init_stmt =
IfThenElse(EQ(iv->var, IntImm(iv->var->dtype, 0)),
stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmt_seq.push_back(init_stmt);
PrimExpr condition;
if (!disable_shuffle_elect_) {
condition = Call(DataType::Bool(), tl_shuffle_elect(), {0});
} else {
condition = EQ(iv->var, 0);
}
auto stmt_ = IfThenElse(condition,
stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmt_seq.push_back(stmt_);
if (!init_mbarrier_calls_.empty()) {
Stmt mem_sync =
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
......@@ -121,14 +127,6 @@ public:
{mbarrier, call->args[i]})));
}
return 0;
} else if (call->op.same_as(sync_thread_partial())) {
int barrier_id = init_mbarrier_calls_.size();
PrimExpr mbarrier =
Call(DataType::Handle(), get_mbarrier(), {barrier_id});
init_mbarrier_calls_.push_back(Evaluate(
Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[0]})));
return Call(DataType::Handle(), sync_thread_partial(), {mbarrier});
} else {
return StmtExprMutator::VisitExpr_(call);
}
......@@ -138,14 +136,18 @@ private:
Array<Stmt> prefetch_calls_;
Array<Stmt> init_mbarrier_calls_;
std::unordered_map<Call, Var, StructuralHash, ExprDeepEqual> desc_map_;
LowerHopperIntrin() = default;
LowerHopperIntrin(bool disable_shuffle_elect)
: disable_shuffle_elect_(disable_shuffle_elect) {}
bool disable_shuffle_elect_;
};
using namespace tir::transform;
tvm::transform::Pass LowerHopperIntrin() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerHopperIntrin::Substitute(f);
bool disable_shuffle_elect =
ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
return LowerHopperIntrin::Substitute(f, disable_shuffle_elect);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
}
......
......@@ -44,11 +44,13 @@ private:
continue;
} else {
if (!current_if_bodies.empty()) {
new_seq.push_back(IfThenElse(current_condition,
current_if_bodies.size() == 1
? current_if_bodies[0]
: SeqStmt(current_if_bodies),
Stmt()));
auto if_stmt =
IfThenElse(current_condition,
current_if_bodies.size() == 1
? current_if_bodies[0]
: this->VisitStmt(SeqStmt(current_if_bodies)),
Stmt());
new_seq.push_back(if_stmt);
current_if_bodies.clear();
}
......@@ -60,11 +62,13 @@ private:
}
if (!current_if_bodies.empty()) {
new_seq.push_back(IfThenElse(current_condition,
current_if_bodies.size() == 1
? current_if_bodies[0]
: SeqStmt(current_if_bodies),
Stmt()));
auto if_stmt =
IfThenElse(current_condition,
current_if_bodies.size() == 1
? current_if_bodies[0]
: this->VisitStmt(SeqStmt(current_if_bodies)),
Stmt());
new_seq.push_back(if_stmt);
current_condition = PrimExpr();
current_if_bodies.clear();
}
......@@ -73,11 +77,13 @@ private:
}
if (!current_if_bodies.empty()) {
new_seq.push_back(IfThenElse(current_condition,
current_if_bodies.size() == 1
? current_if_bodies[0]
: SeqStmt(current_if_bodies),
Stmt()));
auto if_stmt =
IfThenElse(current_condition,
current_if_bodies.size() == 1
? current_if_bodies[0]
: this->VisitStmt(SeqStmt(current_if_bodies)),
Stmt());
new_seq.push_back(if_stmt);
}
return new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
......
......@@ -29,7 +29,8 @@ public:
// The syncs inserted before each statement
std::unordered_set<const Object *> syncs_inserted_;
std::unordered_map<const Object *, int> partial_syncs_inserted_;
std::unordered_map<const Object *, std::tuple<int, int>>
partial_syncs_inserted_;
protected:
bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
......@@ -257,17 +258,24 @@ private:
scope_.push_back(std::vector<StmtEntry>());
num_partial_threads_ = partitions[0];
barrier_id_ += 1;
this->VisitStmt(body->then_case);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (!has_sync_)
barrier_id_ -= 1;
has_sync_ = false;
num_partial_threads_ = partitions[1];
scope_.push_back(std::vector<StmtEntry>());
barrier_id_ += 1;
VisitStmt(body->else_case.value());
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (!has_sync_)
barrier_id_ -= 1;
has_sync_ = false;
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = std::nullopt;
......@@ -281,10 +289,12 @@ private:
// condition";
if (syncs_inserted_.count(obj))
return;
if (num_partial_threads_.defined()) {
if (num_partial_threads_.defined() && barrier_id_ >= 0 &&
barrier_id_ < 16) {
syncs_inserted_.insert(obj);
partial_syncs_inserted_[obj] =
static_cast<int>(num_partial_threads_.value()->value);
partial_syncs_inserted_[obj] = std::make_tuple(
static_cast<int>(num_partial_threads_.value()->value), barrier_id_);
has_sync_ = true;
} else {
syncs_inserted_.insert(obj);
}
......@@ -294,6 +304,8 @@ private:
Optional<IntImm> num_partial_threads_;
// synchronization scope
StorageScope sync_scope_;
int barrier_id_{-1};
bool has_sync_{false};
};
// There are cases where necessary syncthreads is not inserted by
......@@ -318,7 +330,7 @@ class ThreadPartialSyncInserter : public StmtExprMutator {
public:
ThreadPartialSyncInserter(
StorageScope sync_scope, const std::unordered_set<const Object *> &syncs,
std::unordered_map<const Object *, int> partial_syncs)
std::unordered_map<const Object *, std::tuple<int, int>> partial_syncs)
: sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}
Stmt VisitStmt(const Stmt &stmt) final {
......@@ -329,8 +341,10 @@ public:
if (partial_syncs_.count(stmt.get())) {
auto iter = partial_syncs_.find(stmt.get());
ICHECK(sync_scope_.rank == StorageRank::kShared);
barrier = Evaluate(
Call(DataType::Int(32), tl::sync_thread_partial(), {iter->second}));
int num_threads, barrier_id;
std::tie(num_threads, barrier_id) = iter->second;
barrier = Evaluate(Call(DataType::Int(32), tl::sync_thread_partial(),
{num_threads, barrier_id}));
} else {
return StmtExprMutator::VisitStmt(stmt);
}
......@@ -347,7 +361,8 @@ private:
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object *> &syncs_;
const std::unordered_map<const Object *, int> &partial_syncs_;
const std::unordered_map<const Object *, std::tuple<int, int>>
&partial_syncs_;
};
Stmt TileLangThreadPartialSync(Stmt stmt, std::string storage_scope) {
......
......@@ -242,10 +242,15 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
{makeGetBarrier(barrier_id)});
return Evaluate(call);
static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
PrimExpr pred = 1) {
Array<PrimExpr> args = {makeGetBarrier(barrier_id)};
if (cta_id != -1) {
args.push_back(cta_id);
args.push_back(pred);
}
return Evaluate(
Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args));
}
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
......@@ -318,14 +323,18 @@ private:
class ThreadIdxRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
auto rewriter = ThreadIdxRewriter(thread_var, replaced);
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced,
PrimExpr thread_extent, bool do_shuffle = false) {
auto rewriter =
ThreadIdxRewriter(thread_var, replaced, thread_extent, do_shuffle);
return rewriter(stmt);
}
private:
ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
: thread_var_(thread_var), replaced_(replaced) {}
ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent,
bool do_shuffle)
: thread_var_(thread_var), replaced_(replaced),
thread_extent_(thread_extent), do_shuffle_(do_shuffle) {}
PrimExpr VisitExpr_(const VarNode *var) final {
if (var == thread_var_.get()) {
......@@ -335,8 +344,34 @@ private:
}
}
Stmt VisitStmt_(const IfThenElseNode *op) final {
auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
return parameter == thread_var_.get();
};
maybe_thread_opt_ = false;
if (!op->else_case.defined() && op->condition.as<EQNode>() &&
UsesVar(op->condition, f_uses_thread_index) &&
!(UsesVar(op->then_case, f_uses_thread_index))) {
auto eq_op = Downcast<EQ>(op->condition);
if (eq_op->a.as<VarNode>() == thread_var_.get() ||
eq_op->b.as<VarNode>() == thread_var_.get()) {
maybe_thread_opt_ = true;
}
maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_;
}
if (maybe_thread_opt_)
return IfThenElse(
Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
StmtExprMutator::VisitStmt(op->then_case), std::nullopt);
else
return StmtExprMutator::VisitStmt_(op);
}
Var thread_var_;
PrimExpr replaced_;
PrimExpr thread_extent_;
bool maybe_thread_opt_ = false;
bool do_shuffle_;
};
Block MakeGroupBlock(const Stmt &stmt,
......@@ -497,6 +532,41 @@ private:
PipelineInfo pipeline_info_;
};
class WgMMACollector : public StmtExprVisitor {
public:
WgMMACollector() = default;
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tl_gemm()) || op->op.same_as(tl_gemm_sp())) {
auto op_name = std::string(op->args[0].as<StringImmNode>()->value);
if (has_wgmma_) {
has_wgmma_ =
op_name.find("false") == std::string::npos && !in_if_scope_;
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const IfThenElseNode *op) final {
in_if_scope_ = true;
StmtExprVisitor::VisitStmt(op->then_case);
if (op->else_case.defined()) {
StmtExprVisitor::VisitStmt(op->else_case.value());
}
in_if_scope_ = false;
}
static bool HasWgMMA(Stmt stmt) {
auto collector = WgMMACollector();
collector(stmt);
return collector.has_wgmma_;
}
bool has_wgmma_{true};
bool in_if_scope_{false};
};
class WSCodeEmitter : public StmtMutator {
public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
......@@ -507,6 +577,10 @@ public:
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
bool onlyHasWgMMA() const { return only_has_wgmma_; }
bool hasSimtCopy() const { return has_simt_copy_; }
private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
Role role = marker_.GetRole(op);
......@@ -542,6 +616,9 @@ private:
op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq);
only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq));
/*
std::cout << "Print ExtractSyncPattern" << std::endl;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
......@@ -594,8 +671,9 @@ private:
MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
collector.Collect(stmt);
block_stmt.push_back(stmt);
if (collector.HasSimtCopy() > 0) {
if (collector.HasSimtCopy()) {
block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
has_simt_copy_ = true;
}
if (map.release_after[i][j]) {
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
......@@ -630,7 +708,11 @@ private:
int pattern_idx = map.release[i][j];
PrimExpr release_barrier_id =
stage_ + num_barriers_ + num_stages_ * pattern_idx;
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
if (only_has_wgmma_)
block_stmt.push_back(makeArriveBarrier(
release_barrier_id, 0, EQ(FloorMod(thread_var_, 128), 0)));
else
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int s = 0; s < num_stages_; s++) {
released_barrier_.insert(s + num_barriers_ +
num_stages_ * pattern_idx);
......@@ -982,6 +1064,8 @@ private:
bool mbarrier_only_ = false;
PipelineInfo pipeline_info_;
friend class WarpSpecializedRewriter;
bool only_has_wgmma_ = false;
bool has_simt_copy_ = false;
};
class SetMaxNRegCollector : public StmtExprVisitor {
......@@ -1022,9 +1106,12 @@ private:
class WarpSpecializedRewriter : public StmtExprMutator {
public:
WarpSpecializedRewriter(bool disable_warp_specialized)
: disable_warp_specialized_(disable_warp_specialized) {}
static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized) {
WarpSpecializedRewriter(bool disable_warp_specialized,
bool disable_shuffle_elect)
: disable_warp_specialized_(disable_warp_specialized),
disable_shuffle_elect_(disable_shuffle_elect) {}
static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized,
bool disable_shuffle_elect) {
// Check if function only uses threadIdx.x before proceeding
if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
LOG(WARNING) << "WarpSpecialize will be disabled because the program "
......@@ -1035,7 +1122,8 @@ public:
return f;
}
auto T = WarpSpecializedRewriter(disable_warp_specialized);
auto T = WarpSpecializedRewriter(disable_warp_specialized,
disable_shuffle_elect);
T.nreg_ = SetMaxNRegCollector::Collect(f);
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_)
......@@ -1085,7 +1173,7 @@ private:
ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
Var thread_iv = Downcast<Var>(for_node->loop_var);
Stmt new_body =
ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_, 0);
return new_body;
}
return for_node;
......@@ -1128,6 +1216,7 @@ private:
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body);
bool only_has_wgmma = consumer.onlyHasWgMMA();
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case
......@@ -1150,10 +1239,15 @@ private:
producer_code = SeqStmt({dec_reg_stmt, producer_code});
consumer_code = SeqStmt({inc_reg_stmt, consumer_code});
producer_code =
ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
thread_iv_->var - consumer_thread_extent);
updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
producer_code = ThreadIdxRewriter::Rewrite(
producer_code, thread_iv_->var,
thread_iv_->var - consumer_thread_extent, producer_thread_extent,
!disable_shuffle_elect_);
consumer_code = ThreadIdxRewriter::Rewrite(
consumer_code, thread_iv_->var, thread_iv_->var, consumer_thread_extent,
!disable_shuffle_elect_);
need_update_thread_extent_ = true;
ICHECK(producer.num_barriers_ == consumer.num_barriers_)
......@@ -1162,9 +1256,11 @@ private:
Array<PrimExpr> barrier_num_threads;
barrier_num_threads.reserve(num_barriers);
for (int i = 0; i < num_barriers; i++) {
PrimExpr arrive_thread_count = producer.released_barrier_.count(i)
? producer_thread_extent
: consumer_thread_extent;
PrimExpr arrive_thread_count =
producer.released_barrier_.count(i)
? (producer.hasSimtCopy() ? producer_thread_extent : 1)
: (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
barrier_num_threads.push_back(arrive_thread_count);
}
......@@ -1191,6 +1287,7 @@ private:
Optional<PrimExpr> updated_thread_extent_;
bool need_update_thread_extent_ = false;
bool disable_warp_specialized_ = false;
bool disable_shuffle_elect_ = false;
Array<IntImm> nreg_;
};
......@@ -1257,10 +1354,13 @@ tvm::transform::Pass WarpSpecialized() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
bool disable_warp_specialized =
ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
bool disable_shuffle_elect =
ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
if (!warp_specialized) {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
disable_shuffle_elect);
}
return f;
};
......
"""The language interface for tl programs."""
from typing import Union, List, Optional
from typing import Union, List, Optional, Literal
from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load
from tvm import ir, tir
......@@ -81,12 +81,11 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def copy(
src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
disable_tma: bool = False,
):
def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
disable_tma: bool = False,
eviction_policy: Optional[Literal["evict_normal", "evict_first", "evict_last"]] = None):
"""Copy data between memory regions.
Args:
......@@ -145,20 +144,24 @@ def copy(
if coalesced_width is None:
coalesced_width = -1 # PrimExpr can not be None
if eviction_policy is None:
eviction_policy = 0
else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width,
disable_tma)
def c2d_im2col(
img: tir.Buffer,
col: tir.Buffer,
nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr,
kernel: int,
stride: int,
dilation: int,
pad: int,
):
disable_tma, eviction_policy)
def c2d_im2col(img: tir.Buffer,
col: tir.Buffer,
nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr,
kernel: int,
stride: int,
dilation: int,
pad: int,
eviction_policy: Optional[Literal["evict_normal", "evict_first",
"evict_last"]] = None):
"""Perform im2col transformation for 2D convolution.
Args:
......@@ -174,15 +177,10 @@ def c2d_im2col(
Returns:
tir.Call: A handle to the im2col operation
"""
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.c2d_im2col"),
img.access_ptr("r"),
col.access_ptr("w"),
nhw_step,
c_step,
kernel,
stride,
dilation,
pad,
)
if eviction_policy is None:
eviction_policy = 0
else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img.access_ptr("r"),
col.access_ptr("w"), nhw_step, c_step, kernel, stride, dilation, pad,
eviction_policy)
......@@ -43,6 +43,9 @@ class PassConfigKey(str, Enum):
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge"
"""Enable aggressive merge of shared memory allocations. Default: False"""
TL_DISABLE_SHUFFLE_ELECT = "tl.disable_shuffle_elect"
"""Disable shuffle election optimization. Default: False"""
# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
......