"examples/dynamic_shape/test_example_dynamic.py" did not exist on "c99b7056b489a5d222e9c7b41b954d243cff97da"
Commit 05f2fc6d (unverified) authored by Yu Cheng, committed by GitHub

[Enhancement] Enhance warp specialization logic (#680)



- Removed the now-unnecessary `pass_configs` overrides from the `@tilelang.jit` decorators in `example_grouped_gemm_fwd.py`, simplifying kernel compilation.
- Updated the `grouped_gemm` invocation to pass the batch sizes as a tuple rather than a list, matching what the JIT-compiled kernel factory expects (a minimal sketch follows the commit metadata below).
- Added logic in `warp_specialized_rewriter.cc` to track buffers referenced by TMA load calls inside `CallNode` expressions, improving how TMA load operations are handled during warp specialization.

This refactor aims to streamline the code and improve maintainability while ensuring better performance in grouped matrix multiplication operations.
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 042c60fb
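
Why the call site now wraps the batch sizes in `tuple(...)`: a plausible reading (an assumption, not stated in the commit) is that `@tilelang.jit` caches compiled kernel variants keyed on the factory's arguments, so those arguments must be hashable. A minimal sketch with a hypothetical `jit_like` decorator standing in for `@tilelang.jit`:

```python
import functools

def jit_like(factory):
    # Hedged stand-in for @tilelang.jit: the real decorator is assumed to cache
    # compiled kernels keyed on the factory's arguments, which must be hashable.
    @functools.lru_cache(maxsize=None)
    def cached(*args):
        return factory(*args)  # "compile" once per distinct argument set
    return cached

@jit_like
def build_kernel(batch_sizes, K, N):
    # Stand-in for the real TileLang kernel construction.
    return f"kernel<batch_sizes={batch_sizes}, K={K}, N={N}>"

batch_sizes_list = [128, 256, 64]
kernel = build_kernel(tuple(batch_sizes_list), 512, 1024)  # OK: a tuple is hashable
# build_kernel(batch_sizes_list, 512, 1024)                # TypeError: unhashable type: 'list'
```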
example_grouped_gemm_fwd.py

@@ -7,11 +7,6 @@ import math
 tilelang.disable_cache()
 
-@tilelang.jit(
-    out_idx=[2], pass_configs={
-        "tl.disable_tma_lower": True,
-        "tl.disable_warp_specialized": True
-    })
 def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     """
     Perform grouped matrix multiplication using PyTorch.
@@ -44,11 +39,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     return output
 
-@tilelang.jit(
-    out_idx=[2], pass_configs={
-        "tl.disable_tma_lower": True,
-        "tl.disable_warp_specialized": True
-    })
+@tilelang.jit(out_idx=[2])
 def grouped_gemm(batch_sizes_list,
                  K,
                  N,
@@ -150,7 +141,8 @@ def run_tilelang_grouped_gemm(batch_sizes_list,
                               profile=False):
     padding_M = block_M
     batch_sum = sum(batch_sizes_list)
-    kernel = grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, num_stages, threads)
+    kernel = grouped_gemm(
+        tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
     # print(kernel.get_kernel_source())
     device = torch.device("cuda")
warp_specialized_rewriter.cc

@@ -50,8 +50,6 @@ public:
     for (const auto &buffer : usage.buffer_use_count_) {
       used_in_producer_cond_.insert(buffer.first);
     }
-    for (const auto &buffer : used_in_producer_cond_) {
-    }
   }
 
   void VisitStmt_(const IfThenElseNode *op) final {
@@ -76,6 +74,16 @@ public:
     StmtExprVisitor::VisitStmt_(op);
   }
 
+  void VisitExpr_(const CallNode *op) final {
+    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
+      for (auto arg : op->args) {
+        if (auto buffer_load = arg.as<BufferLoadNode>()) {
+          used_in_producer_cond_.insert(buffer_load->buffer.get());
+        }
+      }
+    }
+  }
+
 private:
   std::unordered_set<const BufferNode *> used_in_producer_cond_;
};
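
The C++ hunk above adds a `CallNode` visitor: whenever a `tma_load` or `tma_load_im2col` call is encountered, every buffer referenced by a `BufferLoad` argument is recorded in `used_in_producer_cond_`, so the warp-specialization pass treats those buffers as used by the producer. A language-agnostic sketch of the same bookkeeping in Python (all class and function names here are illustrative stand-ins, not the TVM API):

```python
from dataclasses import dataclass, field

# Illustrative stand-ins for the IR nodes involved, not the real C++ classes.
@dataclass(frozen=True)
class Buffer:
    name: str

@dataclass
class BufferLoad:
    buffer: Buffer

@dataclass
class Call:
    op: str                        # e.g. "tma_load" or "tma_load_im2col"
    args: list = field(default_factory=list)

def visit_call(expr, used_in_producer_cond):
    """Mirror of the new VisitExpr_(const CallNode*): when the call is a TMA
    load, every BufferLoad argument marks its buffer as used by the producer."""
    if isinstance(expr, Call) and expr.op in ("tma_load", "tma_load_im2col"):
        for arg in expr.args:
            if isinstance(arg, BufferLoad):
                used_in_producer_cond.add(arg.buffer)

used = set()
visit_call(Call("tma_load", [BufferLoad(Buffer("A_shared")), 0, 64]), used)
print(used)  # {Buffer(name='A_shared')}
```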