Unverified commit 599264ca, authored by Tong WU and committed by GitHub

[Bugfix] Fix CopyNode Lower method to include disable_tma flag in GetCopyInst (#888)

* Fix CopyNode Lower method to include disable_tma flag in GetCopyInst call

* Refactor flash attention implementation to disable TMA for the specific copies that need it while allowing TMA for other operations

* Attempt to fix lint
parent f58bcd43
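
For context, a minimal sketch (not part of the commit) of the two ways TMA lowering can be disabled after this change, using only the APIs visible in the diff below: the kernel-wide TL_DISABLE_TMA_LOWER pass config and the new per-copy disable_tma argument to T.copy. The kernel, buffer names, shapes, and out_idx value here are illustrative, and the T.Tensor/T.Kernel/T.alloc_shared signatures are assumed from current tilelang examples.

import tilelang
import tilelang.language as T


# Per-copy opt-out of TMA lowering (what this change enables). The previous
# workaround was a kernel-wide switch on the decorator instead:
#   @tilelang.jit(out_idx=[1],
#                 pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True})
@tilelang.jit(out_idx=[1])
def tile_copy(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Force the SIMT fallback for this copy only, e.g. when a padded
            # extent trips the TMA barrier as described in the NOTE below.
            T.copy(
                A[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                A_shared,
                disable_tma=True)
            # Other copies in the same kernel can still be lowered to TMA.
            T.copy(
                A_shared,
                B[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N])

    return main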
@@ -8,7 +8,7 @@ from functools import partial
 num_split = 4
 
 
-@tilelang.jit(out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True})
+@tilelang.jit(out_idx=[5])
 def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape_q = [batch, seqlen_q, heads, dim]
@@ -124,7 +124,9 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
             bid = by // heads
             sid = bz
-            T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared)
+            # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
+            # disable relevant tma copy and use SIMT as fallback for now
+            T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
@@ -147,7 +149,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
             T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
             T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :])
+            T.copy(
+                O_shared,
+                Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
+                disable_tma=True)
 
     @T.macro
     def combine(
@@ -188,7 +193,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
             for i in T.Parallel(block_M):
                 lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
             for k in T.Pipelined(num_split, num_stages=2):
-                T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared)
+                T.copy(
+                    Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
+                    po_shared,
+                    disable_tma=True)
                 T.copy(po_shared, po_local)
                 for i in T.Parallel(block_M):
                     lse_local_split[i] = lse_local[k, i]
@@ -197,7 +205,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
             for i, j in T.Parallel(block_M, dim):
                 o_accum_local[i, j] += po_local[i, j] * scale_local[i]
             T.copy(o_accum_local, o_shared)
-            T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
+            T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True)
 
     @T.prim_func
     def flashattn_mha_inference(
......
@@ -692,8 +692,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   PassContext pass_ctx = PassContext::Current();
   bool disable_tma_lower =
       pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
-  auto copy_inst =
-      GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer);
+  auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
+                               T.layout_map, analyzer);
   if (copy_inst == CopyInst::kBulkLoad1D ||
       copy_inst == CopyInst::kBulkStore1D) {
     auto bulk_copy = LowerBulkCopy1D(T, analyzer, copy_inst);
......
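
As a usage note (a hedged sketch, not from the commit): with the fix above, the example's flashattn builder keeps TMA lowering enabled for the kernel as a whole, and only the copies marked disable_tma=True fall back to SIMT. Assuming out_idx=[5] makes the compiled kernel allocate and return its sixth buffer argument, compiling it might look like the following; the shapes are illustrative, with seqlen_q deliberately not a multiple of block_M to exercise the padded seq_q case from the NOTE.

# Illustrative shapes only; seqlen_q = 130 is not a multiple of block_M = 64,
# which is the padded-dimension case the per-copy SIMT fallback works around.
batch, heads, seqlen_q, seqlen_kv, dim = 1, 16, 130, 2048, 128
block_M, block_N, is_causal = 64, 64, False

# Compiles the kernel; copies flagged with disable_tma=True in the diff above
# are lowered to SIMT, while the remaining copies may still use TMA.
kernel = flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N)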