Unverified Commit 599264ca authored by Tong WU's avatar Tong WU Committed by GitHub
Browse files

[Bugfix] Fix CopyNode Lower method to include disable_tma flag in GetCopyInst (#888)

* Fix CopyNode Lower method to include disable_tma flag in GetCopyInst call

* Refactor flash attention implementation to disable TMA for specific copy and allow TMA for other operations

* attempt to fix lint
parent f58bcd43
...@@ -8,7 +8,7 @@ from functools import partial ...@@ -8,7 +8,7 @@ from functools import partial
num_split = 4 num_split = 4
@tilelang.jit(out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}) @tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim] shape_q = [batch, seqlen_q, heads, dim]
...@@ -124,7 +124,9 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -124,7 +124,9 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid = by // heads bid = by // heads
sid = bz sid = bz
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared) # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# disable relevant tma copy and use SIMT as fallback for now
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -147,7 +149,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -147,7 +149,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :]) T.copy(
O_shared,
Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
disable_tma=True)
@T.macro @T.macro
def combine( def combine(
...@@ -188,7 +193,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -188,7 +193,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
for k in T.Pipelined(num_split, num_stages=2): for k in T.Pipelined(num_split, num_stages=2):
T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared) T.copy(
Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
po_shared,
disable_tma=True)
T.copy(po_shared, po_local) T.copy(po_shared, po_local)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
lse_local_split[i] = lse_local[k, i] lse_local_split[i] = lse_local[k, i]
...@@ -197,7 +205,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -197,7 +205,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i] o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared) T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True)
@T.prim_func @T.prim_func
def flashattn_mha_inference( def flashattn_mha_inference(
......
...@@ -692,8 +692,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -692,8 +692,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower = bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value(); pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
auto copy_inst = auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
GetCopyInst(target, disable_tma_lower, T.layout_map, analyzer); T.layout_map, analyzer);
if (copy_inst == CopyInst::kBulkLoad1D || if (copy_inst == CopyInst::kBulkLoad1D ||
copy_inst == CopyInst::kBulkStore1D) { copy_inst == CopyInst::kBulkStore1D) {
auto bulk_copy = LowerBulkCopy1D(T, analyzer, copy_inst); auto bulk_copy = LowerBulkCopy1D(T, analyzer, copy_inst);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment