Commit 0d873fcf authored by Yu Cheng, committed by GitHub


[Dev][Bugfix] Fix bug in ThreadTagChecker; Add WgmmaSync rewriter and add MHA WGMMA pipelined example (#128)

* [Dev] Add RetNet Linear Attention example

* [Dev] Add WgmmaSync rewriter for pipelined WGMMA operations and add MHA WGMMA pipelined example (FA3-like scheduling)

This commit introduces a new transformation pass `RewriteWgmmaSync` to optimize warp group matrix multiply accumulate (WGMMA) operations in the TileLang compiler:

- Implemented `WgmmaSyncRewriter` in `src/transform/wgmma_sync_rewriter.cc`
- Added pass registration for `RewriteWgmmaSync`
- Updated `tilelang/engine/phase.py` to include the new transformation pass
- Updated `tilelang/transform/__init__.py` to expose the new pass

The rewriter tracks read/write dependencies between each WGMMA operation and the statements that follow it, and re-inserts the synchronization (`cute::warpgroup_wait` plus the corresponding barrier arrive) at the latest point that is still safe, improving pipeline efficiency for complex matrix-multiplication kernels.
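
A minimal sketch of applying the pass on its own (in this commit it is only invoked from `OptimizeForTarget` in `tilelang/engine/phase.py`, right after `LowerOpaqueBlock`; `mod` stands for any IRModule at that point of the lowering flow):

    # sketch only -- mirrors the phase.py change in this commit
    import tilelang.transform

    mod = tilelang.transform.RewriteWgmmaSync()(mod)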

* [Bugfix] Fix bug in ThreadTagChecker for warp specialization

Improve thread tag validation in the warp-specialized rewriter to prevent unintended transformations (see the sketch below):
- Add more precise checks for threadIdx.y and threadIdx.z
- Validate the thread extent so that only single-extent thread bindings are allowed
- Prevent warp specialization for multi-extent thread bindings in the y and z dimensions
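
For illustration, a hypothetical TVMScript fragment of the kind the checker now refuses to warp-specialize: a thread binding on threadIdx.y whose extent is larger than one (the buffer name and shapes are made up):

    # hypothetical example, not part of this commit
    from tvm.script import tir as T

    @T.prim_func
    def already_uses_ty(A: T.Buffer((2, 128), "float16")):
        for ty in T.thread_binding(2, thread="threadIdx.y"):
            for tx in T.thread_binding(128, thread="threadIdx.x"):
                A[ty, tx] = T.float16(0)

A binding of extent 1 on threadIdx.y/threadIdx.z, or any threadIdx.x binding, remains eligible for warp specialization.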

* lint

* [CI] Add TMA descriptor attribute to transformed module in test case
parent 7b74bb01
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
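# Note: the kernel computes softmax with exp2 instead of exp, so the usual
# 1/sqrt(dim) scale is pre-multiplied by log2(e) above; exp2(score * scale)
# then equals exp(score / sqrt(dim)).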
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Buffer(shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Buffer(shape, dtype),
V_shared: T.Buffer([block_N, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax we would need to reset scores_max to 0 wherever it is -inf.
# This step is called Check_inf in the FlashAttention3 code, and it only needs
# to be done in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
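# Online-softmax bookkeeping: the previous running sum is rescaled by
# scores_scale = exp2((max_prev - max_new) * scale) before the current block's
# row sums are added, so logsum stays consistent with the latest running max.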
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
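# FA3-like manual schedule for the software pipeline. After the macros above are
# inlined the loop body has 14 statements; `group` partitions them into six
# groups (K copy | acc_s init + Q*K^T GEMM | softmax | acc_o rescale | V copy |
# P*V GEMM), `order` fixes the issue order of those groups within an iteration,
# and `stage` assigns their pipeline stage for multi-buffering. The -1 entries
# (the K/V copies) appear to leave those groups for the pipeline injector to
# place on its own. This ordering is what the RewriteWgmmaSync pass later
# exploits to defer warpgroup waits past independent work.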
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[3],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, seq_len, dim, is_causal = args.batch, args.heads, args.seq_len, args.dim, args.is_causal
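# FLOP count: each attention block performs two GEMMs (Q*K^T and P*V), each
# costing 2 * batch * heads * seq_len^2 * dim FLOPs; a causal mask keeps roughly
# half of the score matrix, hence the 0.5 factor below.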
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not args.tune:
program = flashattn(
batch, heads, seq_len, dim, is_causal, tune=args.tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
ref_program = partial(ref_program, is_causal=is_causal)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, _ = flashattn(
batch, heads, seq_len, dim, is_causal, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
@@ -878,9 +878,12 @@ public:
 private:
   void VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == tir::attr::thread_extent) {
-      auto iter_var = Downcast<IterVar>(op->node);
-      if (iter_var->thread_tag.length() > 0 &&
-          iter_var->thread_tag != "threadIdx.x") {
+      IterVar iter_var = Downcast<IterVar>(op->node);
+      String thread_tag = iter_var->thread_tag;
+      bool is_y_or_z =
+          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
+      if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) {
         is_valid_ = false;
       }
     }
@@ -891,10 +894,16 @@ private:
     if (op->kind == ForKind::kThreadBinding) {
       ICHECK(op->thread_binding.defined());
       String thread_tag = op->thread_binding.value()->thread_tag;
-      if (thread_tag.length() > 0 && thread_tag != "threadIdx.x") {
+      bool is_y_or_z =
+          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
+      if (!thread_tag.empty() && is_y_or_z) {
+        auto iter_var = Downcast<IterVar>(op->thread_binding);
+        if (iter_var.defined() && iter_var->dom.defined() &&
+            !is_one(iter_var->dom->extent)) {
           is_valid_ = false;
+        }
       }
     }
     StmtExprVisitor::VisitStmt_(op);
   }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file wgmma_sync_rewriter.cc
 * \brief Rewrite synchronization for pipelined WGMMA operations (CUDA sm90+)
 */
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
bool isGemm(Stmt stmt) {
bool is_gemm = false;
if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
if (call && call->op.same_as(Op::Get("tir.call_extern"))) {
if (call->args[0].as<StringImmNode>()) {
std::string name = Downcast<StringImm>(call->args[0])->value;
if (name.find("gemm") != std::string::npos) {
is_gemm = true;
}
}
}
}
return is_gemm;
}
bool isGemmSync(Stmt stmt) {
bool is_gemm_sync = false;
if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
if (call && call->op.same_as(Op::Get("tir.call_extern"))) {
if (call->args[0].as<StringImmNode>()) {
std::string name = Downcast<StringImm>(call->args[0])->value;
if (name.find("warpgroup_wait") != std::string::npos) {
is_gemm_sync = true;
}
}
}
}
return is_gemm_sync;
}
bool isArriveBarrier(Stmt stmt) {
bool is_arrive_barrier = false;
if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
if (call && call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))) {
is_arrive_barrier = true;
}
}
return is_arrive_barrier;
}
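// WgmmaSyncRewriter (summary of the behaviour implemented below): for every
// gemm call inside a loop annotated with "tl_pipeline_order", it records the
// buffers the gemm reads and writes plus the ptx_arrive_barrier that releases
// its result, rewrites the gemm call so it no longer synchronizes eagerly, and
// re-emits an explicit cute::warpgroup_wait (followed by the deferred barrier
// arrive) after the last subsequent statement that does not touch the gemm's
// buffers, letting independent work overlap with the WGMMA.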
class WgmmaSyncRewriter : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = WgmmaSyncRewriter();
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_)
T.buffer_data_to_buffer_.Set(buffer->data, buffer);
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
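// For every gemm in the pipelined SeqStmt, remember the statement itself, the
// first ptx_arrive_barrier that follows it (its "release"), and the sets of
// buffers it reads and writes; a gemm with no trailing barrier gets a no-op
// Evaluate(0) as its release.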
void CollectWgmmaInfo(const SeqStmtNode *op) {
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
auto stmt = op->seq[i];
if (isGemm(stmt)) {
gemm_stmts_.push_back(stmt);
gemm_stmt_ids_.push_back(i);
bool found_release = false;
for (int j = i + 1; j < static_cast<int>(op->seq.size()); j++) {
auto release_stmt = op->seq[j];
if (isArriveBarrier(release_stmt)) {
found_release = true;
gemm_release_stmts_.push_back(release_stmt);
break;
}
}
if (!found_release) {
gemm_release_stmts_.push_back(Evaluate(0));
}
// ICHECK(op->seq.size() > i + 1);
// auto release_stmt = op->seq[i + 1];
// auto next_call =
// Downcast<Evaluate>(release_stmt)->value.as<CallNode>();
// ICHECK(next_call);
// ICHECK(next_call->op.same_as(Op::Get("tir.ptx_arrive_barrier")));
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"",
/*body*/ op->seq[i]);
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
std::set<const BufferNode *> read_set, write_set;
for (auto region : access[0])
read_set.insert(region->buffer.get());
for (auto region : access[1])
write_set.insert(region->buffer.get());
gemm_read_buffers_.push_back(read_set);
gemm_write_buffers_.push_back(write_set);
}
}
}
Stmt VisitStmt_(const ForNode *op) final {
auto order_anno = op->annotations.Get("tl_pipeline_order");
if (!order_anno.defined()) {
return StmtExprMutator::VisitStmt_(op);
}
CollectWgmmaInfo(op->body.as<SeqStmtNode>());
auto stmt_node = (op->body).as<SeqStmtNode>();
ICHECK(stmt_node);
auto intersect_fn = [](const std::set<const BufferNode *> &lhs,
const std::set<const BufferNode *> &rhs) {
for (auto ptr : lhs)
if (rhs.count(ptr))
return true;
return false;
};
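// For each gemm, scan the statements after it and remember the last one with
// no RAW/WAR/WAW conflict against the gemm's buffers; the warpgroup wait can
// safely be deferred until just after that statement.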
for (int r = 0; r < static_cast<int>(gemm_stmts_.size()); r++) {
bool found = false;
auto last_stmt = Stmt();
for (int i = 0; i < static_cast<int>(stmt_node->seq.size()); i++) {
if (stmt_node->seq[i].same_as(gemm_stmts_[r])) {
found = true;
last_stmt = stmt_node->seq[i];
continue;
}
if (!found)
continue;
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"",
/*body*/ stmt_node->seq[i]);
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
std::set<const BufferNode *> read_set, write_set;
for (auto region : access[0])
read_set.insert(region->buffer.get());
for (auto region : access[1])
write_set.insert(region->buffer.get());
if (intersect_fn(read_set, gemm_write_buffers_[r]) ||
intersect_fn(write_set, gemm_read_buffers_[r]) ||
intersect_fn(write_set, gemm_write_buffers_[r])) {
break;
}
last_stmt = stmt_node->seq[i];
}
last_stmts_.push_back(last_stmt);
}
auto new_seq = Array<Stmt>();
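// Rebuild the sequence: drop the original barrier arrives recorded as gemm
// releases, rewrite each gemm's extern call name with an extra trailing "-1"
// template argument (presumably disabling the gemm's built-in wait), and after
// each gemm's last independent statement emit a cute::warpgroup_wait
// placeholder followed by the deferred barrier arrive (a barrier shared by
// several gemms is only re-emitted once).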
for (int i = 0; i < static_cast<int>(stmt_node->seq.size()); i++) {
bool remove_ = false;
for (int j = 0; j < static_cast<int>(gemm_stmts_.size()); j++) {
if (stmt_node->seq[i].same_as(gemm_release_stmts_[j])) {
remove_ = true;
continue;
}
}
if (remove_)
continue;
auto stmt = stmt_node->seq[i];
for (int j = 0; j < static_cast<int>(gemm_stmts_.size()); j++) {
if (stmt_node->seq[i].same_as(gemm_stmts_[j])) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(Op::Get("tir.call_extern")));
ICHECK(call->args[0].as<StringImmNode>());
std::string name = Downcast<StringImm>(call->args[0])->value;
std::string new_name = name.substr(0, name.size() - 1) + ", -1>";
auto new_args = Array<PrimExpr>();
new_args.push_back(StringImm(new_name));
for (int k = 1; k < static_cast<int>(call->args.size()); k++) {
new_args.push_back(call->args[k]);
}
stmt = Evaluate(
Call(DataType::Handle(), builtin::call_extern(), new_args));
break;
}
}
new_seq.push_back(stmt);
for (int j = 0; j < static_cast<int>(gemm_stmts_.size()); j++) {
if (stmt_node->seq[i].same_as(last_stmts_[j])) {
Array<PrimExpr> new_args;
new_args.push_back(StringImm("cute::warpgroup_wait<0>"));
new_args.push_back(Integer(j));
auto new_call =
Call(DataType::Handle(), builtin::call_extern(), new_args);
new_seq.push_back(Evaluate(new_call));
if (std::count(gemm_release_stmts_.begin(), gemm_release_stmts_.end(),
gemm_release_stmts_[j]) == 1) {
new_seq.push_back(gemm_release_stmts_[j]);
} else {
gemm_release_stmts_[j] = Evaluate(0);
}
}
}
}
int gemm_count = 0;
int max_sync_index = 0;
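// Fix up the placeholder waits. cute::warpgroup_wait<N> blocks until at most N
// WGMMA groups are still pending, so waiting for gemm #sync_index after
// gemm_count gemms have been issued means tolerating
// (gemm_count - sync_index - 1) outstanding ones. A wait whose gemm index is
// lower than one already waited on earlier in the sequence is redundant and is
// replaced with a no-op.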
for (int i = 0; i < static_cast<int>(new_seq.size()); i++) {
if (isGemm(new_seq[i])) {
gemm_count++;
} else if (isGemmSync(new_seq[i])) {
auto call = Downcast<Evaluate>(new_seq[i])->value.as<CallNode>();
auto sync_index = Downcast<IntImm>(call->args[1])->value;
auto wait_count = gemm_count - sync_index - 1;
if (sync_index > max_sync_index)
max_sync_index = sync_index;
if (sync_index < max_sync_index) {
// new_seq.erase(new_seq.begin() + i);
new_seq.Set(i, Evaluate(0));
} else {
Array<PrimExpr> new_args;
std::string call_str =
"cute::warpgroup_wait<" + std::to_string(wait_count) + ">";
new_args.push_back(StringImm(call_str));
new_seq.Set(i, Evaluate(Call(DataType::Handle(),
builtin::call_extern(), new_args)));
}
}
}
auto new_for =
For(op->loop_var, op->min, op->extent, op->kind,
new_seq.size() == 1 ? new_seq[0] : SeqStmt(std::move(new_seq)),
op->thread_binding, op->annotations);
return new_for;
}
WgmmaSyncRewriter() = default;
Map<Buffer, Optional<Stmt>> buffer_lca_;
Map<Var, Buffer> buffer_data_to_buffer_;
std::vector<std::set<const BufferNode *>> gemm_read_buffers_;
std::vector<std::set<const BufferNode *>> gemm_write_buffers_;
std::vector<Stmt> gemm_stmts_;
std::vector<Stmt> gemm_release_stmts_;
std::vector<Stmt> last_stmts_;
std::vector<int32_t> gemm_stmt_ids_;
friend class WgmmaReleaseCollector;
};
using namespace tir::transform;
tvm::transform::Pass RewriteWgmmaSync() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return WgmmaSyncRewriter::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.RewriteWgmmaSync")
.set_body_typed(RewriteWgmmaSync);
} // namespace tl
} // namespace tvm
@@ -19,6 +19,7 @@ def _check(original, transformed):
transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main"))
transformed = tvm.tir.transform.BindTarget(auto_target)(transformed)
transformed = tir.transform.LowerOpaqueBlock()(transformed)
transformed["main"] = transformed["main"].with_attr("tma_descriptor_args", {})
tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True)
......
@@ -36,6 +36,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tl.transform.WarpSpecialized()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tl.transform.RewriteWgmmaSync()(mod)
# mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.InjectFenceProxy()(mod)
else:
......
@@ -95,6 +95,17 @@ def WarpSpecializedPipeline():
return _ffi_api.WarpSpecializedPipeline() # type: ignore
def RewriteWgmmaSync():
"""RewriteWgmmaSync
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.RewriteWgmmaSync() # type: ignore
def ThreadSync(storage_scope: str):
"""Insert sync between parallel read/write of shared buffers.
......