Unverified Commit f4a828f6 authored by Wenhao Xie's avatar Wenhao Xie Committed by GitHub
Browse files

[Enhancement][Bugfix] Fix bug in warp specialized pass and add gemm_sr...


[Enhancement][Bugfix] Fix bug in warp specialized pass and add gemm_sr fallback support for Hopper (#712)

* bug fix and support gemm_sr fallback for hopper

* Update gemm.cc

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 1b308baf
...@@ -241,6 +241,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size, ...@@ -241,6 +241,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
} }
bool Gemm::CheckWGMMA() const { bool Gemm::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}
if (C->dtype == DataType::Float(16)) { if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0; return K % 16 == 0;
...@@ -443,7 +447,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -443,7 +447,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
B->dtype.bits(), trans_B ? 2 : 1); B->dtype.bits(), trans_B ? 2 : 1);
results.Set(B, ABLayout); results.Set(B, ABLayout);
} else { } else {
ICHECK(0) << "WGMMA only support B in shared."; auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
} }
} else if (TargetIsCDNA(T.target)) { } else if (TargetIsCDNA(T.target)) {
auto fragment = auto fragment =
......
...@@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { ...@@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
} }
} }
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
template <int num_mma> TL_DEVICE void wait_wgmma() { template <int num_mma> TL_DEVICE void wait_wgmma() {
cute::warpgroup_wait<num_mma>(); cute::warpgroup_wait<num_mma>();
} }
......
...@@ -572,12 +572,11 @@ public: ...@@ -572,12 +572,11 @@ public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer, Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker, const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false) bool mbarrier_only = false, bool only_has_wgmma = false)
: is_emitting_producer_(is_emitting_producer), : is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker), buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {} thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
only_has_wgmma_(only_has_wgmma) {}
bool onlyHasWgMMA() const { return only_has_wgmma_; }
bool hasSimtCopy() const { return has_simt_copy_; } bool hasSimtCopy() const { return has_simt_copy_; }
...@@ -617,8 +616,6 @@ private: ...@@ -617,8 +616,6 @@ private:
auto map = ExtractSyncPattern(op->seq); auto map = ExtractSyncPattern(op->seq);
only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq));
/* /*
std::cout << "Print ExtractSyncPattern" << std::endl; std::cout << "Print ExtractSyncPattern" << std::endl;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) { for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
...@@ -1212,11 +1209,12 @@ private: ...@@ -1212,11 +1209,12 @@ private:
block_realize.CopyOnWrite()->block = block; block_realize.CopyOnWrite()->block = block;
return block_realize; return block_realize;
} }
only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker); WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
false, only_has_wgmma_);
Stmt producer_code = producer(block->body); Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body); Stmt consumer_code = consumer(block->body);
bool only_has_wgmma = consumer.onlyHasWgMMA();
PrimExpr consumer_thread_extent = thread_iv_->dom->extent; PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent; PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case // Need one warp-group for bulk-copy only case
...@@ -1259,7 +1257,7 @@ private: ...@@ -1259,7 +1257,7 @@ private:
PrimExpr arrive_thread_count = PrimExpr arrive_thread_count =
producer.released_barrier_.count(i) producer.released_barrier_.count(i)
? (producer.hasSimtCopy() ? producer_thread_extent : 1) ? (producer.hasSimtCopy() ? producer_thread_extent : 1)
: (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128) : (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent); : consumer_thread_extent);
barrier_num_threads.push_back(arrive_thread_count); barrier_num_threads.push_back(arrive_thread_count);
} }
...@@ -1289,6 +1287,7 @@ private: ...@@ -1289,6 +1287,7 @@ private:
bool disable_warp_specialized_ = false; bool disable_warp_specialized_ = false;
bool disable_shuffle_elect_ = false; bool disable_shuffle_elect_ = false;
Array<IntImm> nreg_; Array<IntImm> nreg_;
bool only_has_wgmma_ = false;
}; };
class WarpSpecializedDetector : public IRVisitorWithAnalyzer { class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment