Unverified commit b12a63cf authored by Tong WU, committed by GitHub

[Bugfix] Ensure correct handling for cases where `seq_q<seq_kv` in flash attention examples (#864)

* fix flash attention examples for `seqlen_q < seqlen_kv` cases

* lint
parent 3b21a67d
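
For context (a summary inferred from the diff, not text from the commit itself): these examples treat the seq_q queries as the last seq_q positions of the seq_kv-long key/value sequence, so with past_len = seq_kv - seq_q, query row i may attend key column j whenever j <= i + past_len. The old reference mask ignored that offset, as this small illustration (hypothetical sizes) shows:

import torch

# Old reference mask with seq_q=2, seq_kv=4 (illustrative sizes): row i
# keeps only columns j <= i, so query 1 (global position 3) is wrongly
# blocked from keys 2 and 3 even though both are in its past.
print(torch.tril(torch.ones(2, 4)))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.]])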
@@ -34,6 +34,9 @@ def flashattn(batch,
     dtype = "float16"
     accum_dtype = "float"
 
+    past_len = seq_kv - seq_q
+    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+
     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def flashattn(batch,
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
@@ -135,8 +137,10 @@ def flashattn(batch,
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
@@ -159,7 +163,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
         attention_weights = F.softmax(scores, dim=-1)
...
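
A minimal cross-check of the fixed ref_program mask (illustrative sizes, not part of the commit): torch.tril with a diagonal offset of past_len keeps exactly the entries with j <= i + past_len.

import torch

seq_q, seq_kv = 2, 6                # illustrative sizes with seq_q < seq_kv
past_len = seq_kv - seq_q           # 4

# Shifted causal mask, as built in the fixed ref_program.
mask = torch.tril(torch.ones(seq_q, seq_kv), past_len)

# Cross-check against the explicit predicate j <= i + past_len.
i = torch.arange(seq_q).unsqueeze(1)
j = torch.arange(seq_kv).unsqueeze(0)
assert torch.equal(mask.bool(), j <= i + past_len)
print(mask)
# tensor([[1., 1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1., 1.]])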
@@ -34,6 +34,9 @@ def flashattn(batch,
     dtype = "float16"
     accum_dtype = "float"
 
+    past_len = seq_kv - seq_q
+    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
+
     @T.macro
     def MMA0(
             K: T.Tensor(kv_shape, dtype),
@@ -45,7 +48,6 @@ def flashattn(batch,
             by: T.int32,
             bz: T.int32,
     ):
-        past_len = seq_kv - seq_q
         T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
         if is_causal:
             for i, j in T.Parallel(block_M, block_N):
@@ -135,8 +137,10 @@ def flashattn(batch,
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = (
-                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+                T.min(
+                    T.ceildiv(seq_kv, block_N), T.ceildiv(
+                        (bx + 1) * block_M +
+                        past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
 
             for k in T.Pipelined(
                     loop_range,
@@ -164,7 +168,7 @@ def ref_program(Q, K, V, is_causal):
     if is_causal:
         seq_q = Q.size(2)
         seq_kv = K.size(2)
-        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device))
+        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
         mask = mask.unsqueeze(0).unsqueeze(0)
         scores = scores.masked_fill(mask == 0, float('-inf'))
         attention_weights = F.softmax(scores, dim=-1)
...
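
And a quick sanity check of the corrected kernel loop bound (plain-Python stand-ins for T.ceildiv/T.min, illustrative sizes): the last query of block bx sits at row (bx + 1) * block_M - 1 and may attend keys up to column (bx + 1) * block_M - 1 + past_len, so without the + past_len term the causal loop stops several KV blocks too early whenever seq_q < seq_kv.

def ceildiv(a, b):                      # plain-Python stand-in for T.ceildiv
    return -(-a // b)

seq_q, seq_kv = 128, 512                # illustrative sizes
block_M = block_N = 64
past_len = seq_kv - seq_q               # 384

for bx in range(ceildiv(seq_q, block_M)):   # query-block index
    old = min(ceildiv(seq_kv, block_N), ceildiv((bx + 1) * block_M, block_N))
    new = min(ceildiv(seq_kv, block_N), ceildiv((bx + 1) * block_M + past_len, block_N))
    print(bx, old, new)
# bx=0: old=1, new=7 -> the old bound skipped 6 KV blocks that are still visible
# bx=1: old=2, new=8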