Commit 7fdcedd0 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support pass config `disable_warp_specialized` to disable automatic warp specialization on Hopper (#357)

* [Enhancement] Add warp specialization configuration option and update related functionality

* [Add] Introduced a new pass configuration option `kDisableWarpSpecialized` (exposed as `tl.disable_warp_specialized`) to control warp specialization behavior; a usage sketch follows after this message.
* [Refactor] Updated `WarpSpecializedRewriter` and `WSCodeEmitter` to utilize the new configuration option, allowing for more flexible optimization strategies.
* [Update] Modified the optimization pipeline in `phase.py` to include pipeline planning when warp specialization is disabled, enhancing performance with async copy.
* [Documentation] Updated JIT compilation parameters to reflect the new configuration option for better clarity.

* lint fix

* [Add] Implement test for GEMM with warp specialization configuration

* Introduced a new test file `test_tilelang_pass_config_disable_warp_specialized.py` to validate the functionality of the warp specialization configuration option.
* Added a `run_gemm` function to execute matrix multiplication tests with and without warp specialization, ensuring correctness through profiling against reference results.
* Included a specific test case for GEMM with float16 data types, enhancing test coverage for the new configuration feature.

* [Refactor] Improve formatting in test_tilelang_pass_config_disable_warp_specialized.py

* Reformatted the `tilelang.compile` call in the `run_gemm` function for better readability by breaking it into multiple lines.
* Added a blank line for improved code structure and clarity in the `test_gemm_f16f16f16_nn` function.
parent a686f0f1
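
For reference, a minimal usage sketch of the new option (it mirrors the test added below; `program` stands for any TileLang kernel, e.g. the `matmul` defined in the new test file):

```python
import tilelang

# Compile with automatic warp specialization turned off. The pass then
# keeps a single warp group and emits mbarrier-based synchronization
# instead of splitting work into producer/consumer warp groups.
kernel = tilelang.compile(
    program,  # a TileLang PrimFunc, e.g. matmul(...) from the test below
    out_idx=[2],
    pass_configs={"tl.disable_warp_specialized": True},
)
```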
@@ -17,6 +17,7 @@ namespace tvm {
 namespace tl {
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
 #define TIR_DEFINE_TL_BUILTIN(OpName)                                         \
...
@@ -14,7 +14,8 @@ namespace tvm {
 namespace tl {
 static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
+static constexpr const char *kDisableWarpSpecialized =
+    "tl.disable_warp_specialized";
 static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
 /*!
...
@@ -556,24 +556,31 @@ class WSCodeEmitter : public StmtMutator {
 public:
   WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
                 Map<Var, Buffer> buffer_data_to_buffer,
-                const WarpSpecializedRoleMarker &marker)
+                const WarpSpecializedRoleMarker &marker,
+                bool mbarrier_only = false)
       : is_emitting_producer_(is_emitting_producer),
         buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
-        thread_var_(thread_iv->var) {}
+        thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
 private:
   template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
     Role role = marker_.GetRole(op);
-    if (role == Role::kBoth)
+    if (mbarrier_only_) {
+      if (role != Role::kProducer)
+        return StmtMutator::VisitStmt_(op);
+    }
+    if (role == Role::kBoth) {
       return StmtMutator::VisitStmt_(op);
-    else if ((role == Role::kProducer) == is_emitting_producer_)
+    } else if ((role == Role::kProducer) == is_emitting_producer_) {
       return GetRef<Stmt>(op);
-    else
+    } else {
       return Evaluate(0);
+    }
   }
   // TODO: only need to add block for ops in the loop
   Stmt VisitStmt_(const SeqStmtNode *op) final {
     bool has_producer = false;
     for (auto stmt : op->seq) {
       if (marker_.GetRole(stmt) == Role::kProducer) {
@@ -590,18 +597,20 @@ private:
         op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
     auto map = ExtractSyncPattern(op->seq);
-    // std::cout << "Print ExtractSyncPattern" << std::endl;
-    // for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
-    //   std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
-    //             << map.release_after[i] << std::endl;
-    // }
-    // std::cout << "Print sync pattern" << std::endl;
-    // for (auto pattern : map.patterns) {
-    //   std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
-    //   std::endl;
-    // }
-    // std::cout << "End of ExtractSyncPattern" << std::endl;
-    // pipeline_info_.PrintPipelineInfo();
+    /*
+    std::cout << "Print ExtractSyncPattern" << std::endl;
+    for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
+      std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
+                << map.release_after[i] << std::endl;
+    }
+    std::cout << "Print sync pattern" << std::endl;
+    for (auto pattern : map.patterns) {
+      std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
+      std::endl;
+    }
+    std::cout << "End of ExtractSyncPattern" << std::endl;
+    pipeline_info_.PrintPipelineInfo();
+    */
     Array<Stmt> new_body;
     Map<String, ObjectRef> annotations;
     annotations.Set(String("stmt_group"), Integer(1));
@@ -610,16 +619,19 @@ private:
     ProducerTraitsCollector collector;
     for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
       Array<Stmt> block_stmt = {};
-      if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
-        continue;
-      if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
-        block_stmt.push_back(seq_transformed[i]);
-        new_body.push_back(MakeGroupBlock(
-            block_stmt.size() == 1 ? block_stmt[0]
-                                   : SeqStmt(std::move(block_stmt)),
-            annotations));
-        continue;
+      if (!mbarrier_only_) {
+        if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
+          continue;
+        if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
+          block_stmt.push_back(seq_transformed[i]);
+          new_body.push_back(MakeGroupBlock(
+              block_stmt.size() == 1 ? block_stmt[0]
                                     : SeqStmt(std::move(block_stmt)),
+              annotations));
+          continue;
+        }
       }
       if (map.acquire[i] != -1) {
         PrimExpr acquire_barrier_id =
             stage_ + num_barriers_ + num_stages_ * map.acquire[i];
@@ -987,6 +999,7 @@ private:
   PrimExpr stage_ = 0;
   int num_stages_ = 1;
   Var thread_var_;
+  bool mbarrier_only_ = false;
   PipelineInfo pipeline_info_;
   friend class WarpSpecializedRewriter;
 };
@@ -1072,7 +1085,9 @@ private:
 class WarpSpecializedRewriter : public StmtExprMutator {
 public:
-  static PrimFunc Substitute(PrimFunc f) {
+  WarpSpecializedRewriter(bool disable_warp_specialized)
+      : disable_warp_specialized_(disable_warp_specialized) {}
+  static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized) {
     // Check if function only uses threadIdx.x before proceeding
     if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
       LOG(WARNING) << "WarpSpecialize will be disabled because the program "
@@ -1083,7 +1098,7 @@ public:
       return f;
     }
-    auto T = WarpSpecializedRewriter();
+    auto T = WarpSpecializedRewriter(disable_warp_specialized);
     T.nreg_ = SetMaxNRegCollector::Collect(f);
     T.buffer_lca_ = DetectBufferAccessLCA(f);
     for (auto [buffer, _] : T.buffer_lca_)
@@ -1154,11 +1169,27 @@ private:
       return block_realize;
     }
+    if (disable_warp_specialized_) {
+      WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
+                                     marker, true);
+      auto code = mbarrier_emitter(block->body);
+      int num_barriers = mbarrier_emitter.num_barriers_;
+      Array<PrimExpr> barrier_num_threads;
+      barrier_num_threads.reserve(num_barriers);
+      PrimExpr arrive_thread_count = thread_iv_->dom->extent;
+      for (int i = 0; i < num_barriers; i++) {
+        barrier_num_threads.push_back(arrive_thread_count);
+      }
+      Stmt init_barrier = Evaluate(Call(
+          DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
+      block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
+      block_realize.CopyOnWrite()->block = block;
+      return block_realize;
+    }
     WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
     WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
     Stmt producer_code = producer(block->body);
     Stmt consumer_code = consumer(block->body);
     PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
     PrimExpr producer_thread_extent = thread_iv_->dom->extent;
     // Need one warp-group for bulk-copy only case
@@ -1166,7 +1197,6 @@ private:
       producer_thread_extent = 128;
     // TODO: estimate the correct reg usage.
     int dec_reg = nreg_[0].as<IntImmNode>()->value;
     int inc_reg = nreg_[1].as<IntImmNode>()->value;
@@ -1222,6 +1252,7 @@ private:
   IterVar thread_iv_;
   Optional<PrimExpr> updated_thread_extent_;
   bool need_update_thread_extent_ = false;
+  bool disable_warp_specialized_ = false;
   Array<IntImm> nreg_;
 };
@@ -1229,7 +1260,9 @@ using namespace tir::transform;
 tvm::transform::Pass WarpSpecialized() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return WarpSpecializedRewriter::Substitute(f);
+    bool disable_warp_specialized =
+        ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
+    return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
 }
...
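
Since the pass reads `kDisableWarpSpecialized` from the `PassContext` (via the `ctx->GetConfig<Bool>(...)` call above), the flag can also be set through an explicit pass context rather than `tilelang.compile`. A minimal sketch, assuming `mod` is an IRModule that has already gone through the earlier TileLang lowering phases:

```python
from tilelang import tvm as tvm
import tilelang

# Any enclosing PassContext overrides the Bool(false) default that the
# C++ pass falls back to when the option is unset.
with tvm.transform.PassContext(config={"tl.disable_warp_specialized": True}):
    mod = tilelang.transform.WarpSpecialized()(mod)
```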
from tilelang import tvm as tvm
import tilelang
import tilelang.testing


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
    disable_warp_specialized=False,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={"tl.disable_warp_specialized": disable_warp_specialized})
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_f16f16f16_nn():
    run_gemm(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        2,
        disable_warp_specialized=False,
    )
    run_gemm(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        disable_warp_specialized=True,
    )


if __name__ == "__main__":
    tilelang.testing.main()
@@ -106,4 +106,4 @@ def test_multi_version_buffer():
 if __name__ == "__main__":
-    test_multi_version_buffer()
+    tilelang.testing.main()
@@ -34,6 +34,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.IfStmtBinding()(mod)
     mod = tilelang.transform.MultiVersionBuffer()(mod)
     mod = tilelang.transform.WarpSpecialized()(mod)
+    # if tma is not enabled, we can also do pipeline planning
+    # to get better performance with async copy
+    mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
...
@@ -137,6 +137,7 @@ def compile(
         Available options:
             "tir.disable_vectorize": bool, default: False
             "tl.disable_tma_lower": bool, default: False
+            "tl.disable_warp_specialized": bool, default: False
            "tl.config_index_bitwidth": int, default: None
     """
     return cached(
...
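
The new option composes with the other pass configs listed above; a small sketch with hypothetical values (`program` as in the test file above):

```python
kernel = tilelang.compile(
    program,
    out_idx=[2],
    pass_configs={
        "tl.disable_tma_lower": False,
        "tl.disable_warp_specialized": True,
        "tl.config_index_bitwidth": 32,
    },
)
```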