Unverified Commit f4a828f6 authored by Wenhao Xie's avatar Wenhao Xie Committed by GitHub
Browse files

[Enhancement][Bugfix] Fix bug in warp specialized pass and add gemm_sr...


[Enhancement][Bugfix] Fix bug in warp specialized pass and add gemm_sr fallback support for Hopper (#712)

* bug fix and support gemm_sr fallback for hopper

* Update gemm.cc

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 1b308baf
......@@ -241,6 +241,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
}
bool Gemm::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
......@@ -443,7 +447,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
B->dtype.bits(), trans_B ? 2 : 1);
results.Set(B, ABLayout);
} else {
ICHECK(0) << "WGMMA only support B in shared.";
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
}
} else if (TargetIsCDNA(T.target)) {
auto fragment =
......@@ -490,4 +496,4 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
template <int num_mma> TL_DEVICE void wait_wgmma() {
cute::warpgroup_wait<num_mma>();
}
......
......@@ -572,12 +572,11 @@ public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false)
bool mbarrier_only = false, bool only_has_wgmma = false)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
bool onlyHasWgMMA() const { return only_has_wgmma_; }
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
only_has_wgmma_(only_has_wgmma) {}
bool hasSimtCopy() const { return has_simt_copy_; }
......@@ -617,8 +616,6 @@ private:
auto map = ExtractSyncPattern(op->seq);
only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq));
/*
std::cout << "Print ExtractSyncPattern" << std::endl;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
......@@ -1212,11 +1209,12 @@ private:
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
false, only_has_wgmma_);
Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body);
bool only_has_wgmma = consumer.onlyHasWgMMA();
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case
......@@ -1259,8 +1257,8 @@ private:
PrimExpr arrive_thread_count =
producer.released_barrier_.count(i)
? (producer.hasSimtCopy() ? producer_thread_extent : 1)
: (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
: (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
barrier_num_threads.push_back(arrive_thread_count);
}
......@@ -1289,6 +1287,7 @@ private:
bool disable_warp_specialized_ = false;
bool disable_shuffle_elect_ = false;
Array<IntImm> nreg_;
bool only_has_wgmma_ = false;
};
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment