Unverified commit 013adca0 authored by Wenhao Xie, committed by GitHub

[Bugfix] Fix incorrect synchronization bug in minference example (#786)

* fix

* lint
parent e5b61e9b
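
A note for readers of the diff below: the new kernel splits the 256 threads of each block into a producer half (tid >= 128, which issues the K/V copies) and a consumer half (tid < 128, which runs the GEMMs), and synchronizes the two halves through double-buffered mbarriers whose wait parity is computed as (bi & 3) >> 1 (the producer waits on the complementary parity, ((bi & 3) >> 1) ^ 1). The following plain-Python sketch is not part of the commit and uses helper names of our own; it only illustrates that arithmetic: each barrier slot of the two-deep buffer sees its parity alternate 0, 1, 0, 1 across the iterations that reuse it.

# Illustration only (plain Python, not TileLang): the parity schedule used by
# the double-buffered mbarrier handshake in the new kernel. Helper names are ours.

def consumer_parity(bi: int) -> int:
    # Parity the consumer passes to T.mbarrier_wait_parity at iteration bi.
    return (bi & 3) >> 1             # 0, 0, 1, 1, 0, 0, 1, 1, ...

def producer_parity(bi: int) -> int:
    # The producer waits on the "buffer free" barriers with the opposite parity.
    return consumer_parity(bi) ^ 1   # 1, 1, 0, 0, 1, 1, 0, 0, ...

if __name__ == "__main__":
    for buf in (0, 1):  # buffer index bi % 2 selects one of the two barrier slots
        parities = [consumer_parity(bi) for bi in range(8) if bi % 2 == buf]
        # Each slot is revisited with alternating parity, which lets a wait
        # distinguish the current fill of the buffer from the previous one.
        assert parities == [0, 1, 0, 1]
    print([consumer_parity(bi) for bi in range(8)])  # [0, 0, 1, 1, 0, 0, 1, 1]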
@@ -10,9 +10,7 @@ import triton.language as tl
 import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
-from tilelang.testing import torch_assert_close
 
 tilelang.disable_cache()
@@ -27,7 +25,9 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
     scale = (1.0 / dim)**0.5 * 1.44269504
     shape = [batch, heads, seq_len, dim]
-    count_shape = [batch, heads, (seq_len + block_M - 1) // block_M]
+    seq_blocks = (seq_len + block_M - 1) // block_M
+    count_shape = [batch, heads, seq_blocks]
     offset_shape = count_shape + [slash_size]
     index_shape = count_shape + [vertical_size]
@@ -47,7 +47,7 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
             V: T.Tensor(shape, dtype),
             K_shared: T.SharedBuffer([block_N, dim], dtype),
             V_shared: T.SharedBuffer([block_N, dim], dtype),
-            column_index: T.SharedBuffer([vertical_size], int_dtype),
+            column_index: T.SharedBuffer([vertical_size_round], int_dtype),
             column_count: T.int32,
             k: T.int32,
             bz: T.int32,
@@ -80,8 +80,9 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
             scores_sum: T.FragmentBuffer([block_M], accum_dtype),
             logsum: T.FragmentBuffer([block_M], accum_dtype),
+            count: T.int32,
    ):
-        T.ptx_wait_group(1)
+        T.ptx_wait_group(count)
         for i, j in T.Parallel(block_M, block_N):
             acc_s[i, j] = T.if_then_else(k + j < column_count, 0, -T.infinity(acc_s.dtype))
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -106,7 +107,7 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
             logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
 
     @T.prim_func
-    def vs_sparse_flashattn(
+    def vs_sparse_flashattn_ws(
             Q: T.Tensor(shape, dtype),
             K: T.Tensor(shape, dtype),
             V: T.Tensor(shape, dtype),
@@ -116,13 +117,16 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
             ColumnCount: T.Tensor(count_shape, int_dtype),
             ColumnIndex: T.Tensor(index_shape, int_dtype),
     ):
-        with T.Kernel(
-                T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bc, by, bz):
+        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bc, by, bz):
             bx = T.ceildiv(seq_len, block_M) - 1 - bc
             Q_shared = T.alloc_shared([block_M, dim], dtype)
-            K_shared = T.alloc_shared([block_N, dim], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+            K_shared = T.alloc_shared([2, block_N, dim], dtype)
+            V_shared = T.alloc_shared([2, block_N, dim], dtype)
+            K_shared_1 = T.alloc_shared([block_N, dim], dtype)
+            V_shared_1 = T.alloc_shared([block_N, dim], dtype)
+            K_shared_2 = T.alloc_shared([block_N, dim], dtype)
+            V_shared_2 = T.alloc_shared([block_N, dim], dtype)
             O_shared = T.alloc_shared([block_M, dim], dtype)
             acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
@@ -137,10 +141,11 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
             column_count = T.alloc_local([1], int_dtype)
             column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared")
-            K_shared_1 = T.alloc_shared([block_N, dim], dtype)
-            V_shared_1 = T.alloc_shared([block_N, dim], dtype)
-            K_shared_2 = T.alloc_shared([block_N, dim], dtype)
-            V_shared_2 = T.alloc_shared([block_N, dim], dtype)
+            T.create_list_of_mbarrier([128] * 9)
+
+            T.annotate_layout({
+                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
+            })
 
             block_count[0] = BlockCount[bz, by, bx]
             column_count[0] = ColumnCount[bz, by, bx]
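
Reader's aid, not part of the committed file: the hunk above creates nine mbarriers with T.create_list_of_mbarrier([128] * 9), presumably one arrival count per 128-thread group. Reading the producer/consumer code in the next hunk, the barrier indices appear to be used as follows; the mapping below is our plain-Python annotation, not code from the example.

# Our reading of the nine mbarrier slots; pairs exist because K_shared and
# V_shared are double buffered ([2, block_N, dim]).
MBARRIER_ROLES = {
    0: "K_shared[0] filled (producer -> consumer)",
    1: "K_shared[1] filled (producer -> consumer)",
    2: "V_shared[0] filled (producer -> consumer)",
    3: "V_shared[1] filled (producer -> consumer)",
    4: "K_shared[0] consumed, free to refill (consumer -> producer)",
    5: "K_shared[1] consumed, free to refill (consumer -> producer)",
    6: "V_shared[0] consumed, free to refill (consumer -> producer)",
    7: "V_shared[1] consumed, free to refill (consumer -> producer)",
    8: "Q_shared loaded (producer -> consumer, waited once with parity 0)",
}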
@@ -153,81 +158,103 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
                 if vi < vertical_size:
                     column_index[vi] = ColumnIndex[bz, by, bx, vi]
-            T.fill(acc_o, 0)
-            T.fill(logsum, 0)
-            T.fill(scores_max, -T.infinity(accum_dtype))
-            T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
-            for bi in T.Pipelined(block_count[0], num_stages=num_stages):
-                k = block_offset[bi]
-                T.copy(K[bz, by, k:k + block_N, :], K_shared)
-                for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(scores_max, scores_max_prev)
-                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
-                for i in T.Parallel(block_M):
-                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
-                for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
-                for i, j in T.Parallel(block_M, dim):
-                    acc_o[i, j] = acc_o[i, j] * scores_scale[i]
-                T.copy(acc_s, acc_s_cast)
-                T.copy(V[bz, by, k:k + block_N, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
-                T.reduce_sum(acc_s, scores_sum, dim=1)
-                for i in T.Parallel(block_M):
-                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
-            if column_count[0] != 0:
-                Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
-                for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
-                    k = bi * block_N
-                    if bi % 2 == 0:
-                        Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
-                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum)
-                    else:
-                        Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
-                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum)
-                if T.ceildiv(column_count[0], block_N) % 2 == 0:
-                    Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum)
-                else:
-                    Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum)
-            for i, j in T.Parallel(block_M, dim):
-                acc_o[i, j] /= logsum[i]
-            T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
-    return vs_sparse_flashattn
+            tid = T.get_thread_binding()
+            if tid >= 128:
+                T.annotate_producer_reg_dealloc()
+                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.mbarrier_arrive(mbarrier=8)
+                for bi in T.serial(block_count[0]):
+                    k = block_offset[bi]
+                    T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1))
+                    T.copy(K[bz, by, k:k + block_N, :], K_shared[bi % 2, :, :])
+                    T.mbarrier_arrive(mbarrier=bi % 2)
+                    T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1))
+                    T.copy(V[bz, by, k:k + block_N, :], V_shared[bi % 2, :, :])
+                    T.mbarrier_arrive(mbarrier=bi % 2 + 2)
+            else:
+                T.annotate_consumer_reg_alloc()
+                T.fill(acc_o, 0)
+                T.fill(logsum, 0)
+                T.fill(scores_max, -T.infinity(accum_dtype))
+                T.mbarrier_wait_parity(mbarrier=8, parity=0)
+                for bi in T.serial(block_count[0]):
+                    k = block_offset[bi]
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))
+                    T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1))
+                    T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                    T.mbarrier_arrive(mbarrier=bi % 2 + 4)
+                    T.copy(scores_max, scores_max_prev)
+                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                    for i in T.Parallel(block_M):
+                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
+                    for i, j in T.Parallel(block_M, dim):
+                        acc_o[i, j] = acc_o[i, j] * scores_scale[i]
+                    T.copy(acc_s, acc_s_cast)
+                    T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1))
+                    T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow)
+                    T.mbarrier_arrive(mbarrier=bi % 2 + 6)
+                    T.reduce_sum(acc_s, scores_sum, dim=1)
+                    for i in T.Parallel(block_M):
+                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                if column_count[0] != 0:
+                    Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], 0, bz, by)
+                    for bi in T.serial(T.ceildiv(column_count[0], block_N) - 1):
+                        k = bi * block_N
+                        if bi % 2 == 0:
+                            Prefetch(K, V, K_shared_2, V_shared_2, column_index, column_count[0], k + block_N, bz, by)
+                            Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum, 1)
+                        else:
+                            Prefetch(K, V, K_shared_1, V_shared_1, column_index, column_count[0], k + block_N, bz, by)
+                            Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, k, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum, 1)
+                    if T.ceildiv(column_count[0], block_N) % 2 == 0:
+                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_2, V_shared_2, scores_scale, scores_sum, logsum, 0)
+                    else:
+                        Compute(acc_s, acc_s_cast, acc_o, scores_max, scores_max_prev, T.ceildiv(column_count[0], block_N) * block_N - block_N, column_count[0], Q_shared, K_shared_1, V_shared_1, scores_scale, scores_sum, logsum, 0)
+                for i, j in T.Parallel(block_M, dim):
+                    acc_o[i, j] /= logsum[i]
+                T.copy(acc_o, O_shared)
+                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+    return vs_sparse_flashattn_ws

 return kernel_func(block_M, block_N, num_stages, threads)
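
The Compute macro's new trailing count argument (1 inside the ping-pong loop above, 0 for the final tile) is forwarded to T.ptx_wait_group, which we take to map onto PTX's cp.async.wait_group semantics: block until at most count copy groups are still in flight. The toy model below is ours and only illustrates that behaviour; it is not TileLang code.

# Toy model (plain Python) of cp.async-style wait_group(n): retire committed
# copy groups, oldest first, until at most n remain outstanding.
class AsyncCopyModel:

    def __init__(self):
        self.pending = []  # committed but not-yet-complete copy groups

    def commit_group(self, name):
        self.pending.append(name)

    def wait_group(self, n):
        while len(self.pending) > n:
            self.pending.pop(0)  # oldest groups complete first
        return list(self.pending)


copies = AsyncCopyModel()
copies.commit_group("K/V tile i")    # issued by the previous Prefetch call
copies.commit_group("K/V tile i+1")  # Prefetch just issued for the next tile
print(copies.wait_group(1))          # ['K/V tile i+1']: tile i is ready, the next copy keeps overlapping
print(copies.wait_group(0))          # []: on the last tile, wait for everything before computing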
@@ -466,7 +493,7 @@ def vertical_slash_sparse_attention(
     s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(
         dim=-1, descending=True)[0]
-    seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
+    seqlens = torch.tensor([context_size] * query.shape[0], dtype=torch.int32, device=query.device)
     sm_scale = head_dim**-0.5
     block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
         seqlens,
@@ -524,7 +551,6 @@ def main(argv=None):
     parser.add_argument("--slash_size", type=int, default=200)
     args = parser.parse_args(argv)
 
-    # vs_list = [[1000, 200], [1000, 600], [800, 600]]
     BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim
@@ -555,12 +581,10 @@ def main(argv=None):
     _attn = vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)
 
-    triton_out = _attn(True)
     tilelang_out = _attn(False)
+    triton_out = _attn(True)
 
-    torch_assert_close(triton_out, tilelang_out, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.0)
-    print("Pass topk sparse attention test with qlen == klen")
+    torch.testing.assert_close(triton_out, tilelang_out, atol=1e-2, rtol=1e-2)
 
     triton_time = do_bench(lambda: _attn(True))
     tilelang_time = do_bench(lambda: _attn(False))
...