Commit 7fdcedd0 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support pass config `disable_warp_specialized` to disable auto specialization on hopper (#357)

* [Enhancement] Add warp specialization configuration option and update related functionality

* [Add] Introduced a new pass configuration option `kDisableWarpSpecialized` (`tl.disable_warp_specialized`) to control warp specialization behavior.
* [Refactor] Updated `WarpSpecializedRewriter` and `WSCodeEmitter` to honor the new option, emitting an mbarrier-only code path when specialization is disabled.
* [Update] Modified the optimization pipeline in `phase.py` to run pipeline planning after warp specialization, so async copy still gets pipelined when specialization is disabled.
* [Documentation] Documented the new option in the JIT compilation parameters (a usage sketch follows below).
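
A minimal usage sketch of the new option: the kernel below is a hypothetical shared-memory copy (its name, shapes, and `out_idx` are illustrative), while the `T.*` builders and the `pass_configs` key mirror the GEMM test added in this commit.

```python
import tilelang
import tilelang.language as T


@T.prim_func
def copy_kernel(
        A: T.Tensor((1024, 1024), "float16"),
        B: T.Tensor((1024, 1024), "float16"),
):
    with T.Kernel(T.ceildiv(1024, 128), T.ceildiv(1024, 128), threads=128) as (bx, by):
        # Stage a tile of A through shared memory, then write it back out to B.
        A_shared = T.alloc_shared((128, 128), "float16")
        T.copy(A[by * 128, bx * 128], A_shared)
        T.copy(A_shared, B[by * 128, bx * 128])


# Compile with automatic warp specialization disabled; the rewriter then keeps
# a single warp group and emits the mbarrier-only code path.
kernel = tilelang.compile(
    copy_kernel,
    out_idx=[1],  # hypothetical: treat B as the kernel output
    pass_configs={"tl.disable_warp_specialized": True})
```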

* lint fix

* [Add] Implement test for GEMM with warp specialization configuration

* Introduced a new test file `test_tilelang_pass_config_disable_warp_specialized.py` to validate the functionality of the warp specialization configuration option.
* Added a `run_gemm` function to execute matrix multiplication tests with and without warp specialization, ensuring correctness through profiling against reference results.
* Included a specific test case for GEMM with float16 data types, enhancing test coverage for the new configuration feature.

* [Refactor] Improve formatting in test_tilelang_pass_config_disable_warp_specialized.py

* Reformatted the `tilelang.compile` call in the `run_gemm` function for better readability by breaking it into multiple lines.
* Added a blank line for improved code structure and clarity in the `test_gemm_f16f16f16_nn` function.
parent a686f0f1
......@@ -17,6 +17,7 @@ namespace tvm {
namespace tl {
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
#define TIR_DEFINE_TL_BUILTIN(OpName) \
......
......@@ -14,7 +14,8 @@ namespace tvm {
namespace tl {
static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
static constexpr const char *kDisableWarpSpecialized =
"tl.disable_warp_specialized";
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
/*!
......
......@@ -556,24 +556,31 @@ class WSCodeEmitter : public StmtMutator {
public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker)
const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var) {}
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
Role role = marker_.GetRole(op);
if (role == Role::kBoth)
if (mbarrier_only_) {
if (role != Role::kProducer)
return StmtMutator::VisitStmt_(op);
}
if (role == Role::kBoth) {
return StmtMutator::VisitStmt_(op);
else if ((role == Role::kProducer) == is_emitting_producer_)
} else if ((role == Role::kProducer) == is_emitting_producer_) {
return GetRef<Stmt>(op);
else
} else {
return Evaluate(0);
}
}
// TODO: only need to add block for ops in the loop
Stmt VisitStmt_(const SeqStmtNode *op) final {
bool has_producer = false;
for (auto stmt : op->seq) {
if (marker_.GetRole(stmt) == Role::kProducer) {
......@@ -590,18 +597,20 @@ private:
op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq);
// std::cout << "Print ExtractSyncPattern" << std::endl;
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
// << map.release_after[i] << std::endl;
// }
// std::cout << "Print sync pattern" << std::endl;
// for (auto pattern : map.patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
// std::cout << "End of ExtractSyncPattern" << std::endl;
// pipeline_info_.PrintPipelineInfo();
/*
std::cout << "Print ExtractSyncPattern" << std::endl;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
<< map.release_after[i] << std::endl;
}
std::cout << "Print sync pattern" << std::endl;
for (auto pattern : map.patterns) {
std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
std::endl;
}
std::cout << "End of ExtractSyncPattern" << std::endl;
pipeline_info_.PrintPipelineInfo();
*/
Array<Stmt> new_body;
Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1));
......@@ -610,16 +619,19 @@ private:
ProducerTraitsCollector collector;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
continue;
if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
block_stmt.push_back(seq_transformed[i]);
new_body.push_back(MakeGroupBlock(
block_stmt.size() == 1 ? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
continue;
if (!mbarrier_only_) {
if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
continue;
if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
block_stmt.push_back(seq_transformed[i]);
new_body.push_back(MakeGroupBlock(
block_stmt.size() == 1 ? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
continue;
}
}
if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.acquire[i];
......@@ -987,6 +999,7 @@ private:
PrimExpr stage_ = 0;
int num_stages_ = 1;
Var thread_var_;
bool mbarrier_only_ = false;
PipelineInfo pipeline_info_;
friend class WarpSpecializedRewriter;
};
......@@ -1072,7 +1085,9 @@ private:
class WarpSpecializedRewriter : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
WarpSpecializedRewriter(bool disable_warp_specialized)
: disable_warp_specialized_(disable_warp_specialized) {}
static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized) {
// Check if function only uses threadIdx.x before proceeding
if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
LOG(WARNING) << "WarpSpecialize will be disabled because the program "
......@@ -1083,7 +1098,7 @@ public:
return f;
}
auto T = WarpSpecializedRewriter();
auto T = WarpSpecializedRewriter(disable_warp_specialized);
T.nreg_ = SetMaxNRegCollector::Collect(f);
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_)
......@@ -1154,11 +1169,27 @@ private:
return block_realize;
}
if (disable_warp_specialized_) {
WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_,
marker, true);
auto code = mbarrier_emitter(block->body);
int num_barriers = mbarrier_emitter.num_barriers_;
Array<PrimExpr> barrier_num_threads;
barrier_num_threads.reserve(num_barriers);
PrimExpr arrive_thread_count = thread_iv_->dom->extent;
for (int i = 0; i < num_barriers; i++) {
barrier_num_threads.push_back(arrive_thread_count);
}
Stmt init_barrier = Evaluate(Call(
DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body);
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case
......@@ -1166,7 +1197,6 @@ private:
producer_thread_extent = 128;
// TODO: estimate the correct reg usage.
int dec_reg = nreg_[0].as<IntImmNode>()->value;
int inc_reg = nreg_[1].as<IntImmNode>()->value;
......@@ -1222,6 +1252,7 @@ private:
IterVar thread_iv_;
Optional<PrimExpr> updated_thread_extent_;
bool need_update_thread_extent_ = false;
bool disable_warp_specialized_ = false;
Array<IntImm> nreg_;
};
......@@ -1229,7 +1260,9 @@ using namespace tir::transform;
tvm::transform::Pass WarpSpecialized() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return WarpSpecializedRewriter::Substitute(f);
bool disable_warp_specialized =
ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
};
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
......
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
disable_warp_specialized=False,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={"tl.disable_warp_specialized": disable_warp_specialized})
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
disable_warp_specialized=False,
)
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
disable_warp_specialized=True,
)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -106,4 +106,4 @@ def test_multi_version_buffer():
if __name__ == "__main__":
test_multi_version_buffer()
tilelang.testing.main()
......@@ -34,6 +34,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod)
# if tma is not enabled, we can also do pipeline planning
# to get better performance with async copy
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
......
......@@ -137,6 +137,7 @@ def compile(
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_warp_specialized": bool, default: False
"tl.config_index_bitwidth": int, default: None
"""
return cached(
......