Unverified commit 05f2fc6d authored by Yu Cheng, committed by GitHub

[Enhancement] Enhance warp specialization logic (#680)



- Removed the unnecessary pass configurations from the `@tilelang.jit` decorators in `example_grouped_gemm_fwd.py`, simplifying kernel compilation.
- Updated the `grouped_gemm` function to take the batch sizes as a tuple, making the argument compatible with the kernel invocation (see the sketch below).
- Added logic in `warp_specialized_rewriter.cc` to track buffer usage inside `CallNode` expressions, improving the detection of buffers used by TMA load operations.

This refactor aims to streamline the code and improve maintainability while ensuring better performance in grouped matrix multiplication operations.
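
For context, a minimal sketch of the updated call pattern follows. The rationale is my reading rather than something stated in the diff: `@tilelang.jit` compiles and caches a kernel per distinct set of arguments, so the batch sizes must be hashable, which a Python list is not but a tuple is. All concrete sizes are placeholders, and the import path is assumed from the example's filename.

```python
# Hedged sketch of the new invocation; every concrete size below is a placeholder.
from example_grouped_gemm_fwd import grouped_gemm  # module name assumed from the filename

batch_sizes_list = [64, 128, 256]       # per-group row counts (illustrative only)
K, M = 256, 256                         # problem dimensions (illustrative only)
block_M, block_N, block_K = 64, 64, 32  # tile sizes (illustrative only)
num_stages, threads = 2, 128

# A list is unhashable (hash([1, 2]) raises TypeError), so it cannot serve as a
# compilation-cache key; tuple() keeps the same values but makes the argument
# hashable, mirroring the updated call site in this commit.
kernel = grouped_gemm(
    tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
```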
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 042c60fb
example_grouped_gemm_fwd.py
@@ -7,11 +7,6 @@ import math
 tilelang.disable_cache()


-@tilelang.jit(
-    out_idx=[2], pass_configs={
-        "tl.disable_tma_lower": True,
-        "tl.disable_warp_specialized": True
-    })
 def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     """
     Perform grouped matrix multiplication using PyTorch.
@@ -44,11 +39,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     return output


-@tilelang.jit(
-    out_idx=[2], pass_configs={
-        "tl.disable_tma_lower": True,
-        "tl.disable_warp_specialized": True
-    })
+@tilelang.jit(out_idx=[2])
 def grouped_gemm(batch_sizes_list,
                  K,
                  N,
@@ -150,7 +141,8 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
                               profile=False):
     padding_M = block_M
     batch_sum = sum(batch_sizes_list)
-    kernel = grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, num_stages, threads)
+    kernel = grouped_gemm(
+        tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
     # print(kernel.get_kernel_source())
     device = torch.device("cuda")
warp_specialized_rewriter.cc
@@ -50,8 +50,6 @@ public:
     for (const auto &buffer : usage.buffer_use_count_) {
       used_in_producer_cond_.insert(buffer.first);
     }
-    for (const auto &buffer : used_in_producer_cond_) {
-    }
   }

   void VisitStmt_(const IfThenElseNode *op) final {
@@ -76,6 +74,16 @@ public:
     StmtExprVisitor::VisitStmt_(op);
   }
+
+  void VisitExpr_(const CallNode *op) final {
+    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
+      for (auto arg : op->args) {
+        if (auto buffer_load = arg.as<BufferLoadNode>()) {
+          used_in_producer_cond_.insert(buffer_load->buffer.get());
+        }
+      }
+    }
+  }

 private:
   std::unordered_set<const BufferNode *> used_in_producer_cond_;
 };