Commit 67d0b677 authored by 徐畅, committed by LeiWang1999

[BugFix] Fix precision issue in GQA decode when block_N exceeds seqlen/num_split (#575)

* [CI] Add flash_decoding example to CI

* Add output of ref latency

* format example_gqa_decode.py

* [BugFix] Fix precision issue in GQA decode when block_N exceeds seqlen/num_split

* format example_gqa_decode.py
parent 837b6398
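For context on the bug this diff addresses: when `block_N` is larger than `seqlen_kv // num_split`, each split owns fewer KV rows than one tile, so a copy sliced by `block_N` walks past the split boundary and the softmax sees stale or out-of-range columns. A minimal sketch of the index arithmetic, using made-up sizes (only `block_N = 128` and `num_split = 16` come from the heuristic config below; `seqlen_kv = 256` is hypothetical):

```python
# Hypothetical sizes chosen only to illustrate the overflow; not from the example.
seqlen_kv, num_split, block_N = 256, 16, 128

split_len = seqlen_kv // num_split          # 16 rows of K/V per split
valid_block_N = min(block_N, split_len)     # the fix clamps the tile to 16

sid, k = 0, 0                               # first split, first tile
old_hi = split_len * sid + (k + 1) * block_N        # 128 -> reads far past split 0
new_hi = split_len * sid + (k + 1) * valid_block_N  # 16  -> stays inside split 0
print(split_len, old_hi, new_hi)
```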
@@ -41,6 +41,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
     def kernel_func(block_N, block_H, num_split, num_stages, threads):
         part_shape = [batch, heads, num_split, dim]
         valid_block_H = min(block_H, kv_group_num)
+        valid_block_N = min(block_N, seqlen_kv // num_split)

         @T.macro
         def flash_attn(
@@ -147,15 +148,16 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
+                T.fill(K_shared, 0)
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
-                    T.copy(
-                        K[bid, (seqlen_kv // num_split) * sid +
-                          k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                          cur_kv_head, :], K_shared)
-                    T.copy(
-                        mask[bid, (seqlen_kv // num_split) * sid +
-                          k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                          cur_kv_head], mask_local)
+                    T.copy(
+                        K[bid, (seqlen_kv // num_split) * sid +
+                          k * valid_block_N:(seqlen_kv // num_split) * sid +
+                          (k + 1) * valid_block_N, cur_kv_head, :], K_shared)
+                    T.copy(
+                        mask[bid, (seqlen_kv // num_split) * sid +
+                          k * valid_block_N:(seqlen_kv // num_split) * sid +
+                          (k + 1) * valid_block_N, cur_kv_head], mask_local)
                     T.clear(acc_s)
                     T.gemm(
                         Q_shared,
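The new `T.fill(K_shared, 0)` before the pipelined loop matters because the clamped copy only overwrites the first `valid_block_N` rows of the shared tile; the remaining rows would otherwise hold leftover data. A rough NumPy sketch of the same idea (hypothetical shapes, not TileLang):

```python
import numpy as np

block_N, valid_block_N, dim = 128, 16, 64      # hypothetical tile sizes
K_shared = np.random.randn(block_N, dim)       # pretend this still holds old data
K_split = np.random.randn(valid_block_N, dim)  # the rows this split actually owns

K_shared[:] = 0.0                              # analogue of T.fill(K_shared, 0)
K_shared[:valid_block_N] = K_split             # analogue of the clamped T.copy
# Rows valid_block_N..block_N-1 are now zeros instead of leftovers, and the
# (j < seqlen_kv // num_split) guard added later ignores them anyway.
```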
@@ -164,8 +166,9 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
                         transpose_B=True,
                         policy=T.GemmWarpPolicy.FullRow)
                     for i, j in T.Parallel(block_H, block_N):
-                        acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-                                                     -T.infinity(accum_dtype))
+                        acc_s[i, j] = T.if_then_else(
+                            (mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j],
+                            -T.infinity(accum_dtype))
                     T.copy(scores_max, scores_max_prev)
                     T.fill(scores_max, -T.infinity(accum_dtype))
                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
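Guarding the score tile with `j < seqlen_kv // num_split` in addition to the user mask keeps the padded columns from contaminating the row max and the exp-sum of the online softmax. A small PyTorch sketch of the same masking step, with hypothetical shapes:

```python
import torch

block_H, block_N, split_len = 4, 8, 5         # hypothetical; split_len < block_N
acc_s = torch.randn(block_H, block_N)         # raw Q @ K^T scores for one tile
mask_local = torch.randint(0, 2, (block_N,))  # per-column user mask

col = torch.arange(block_N)
keep = (mask_local != 0) & (col < split_len)  # same predicate as the kernel
acc_s = torch.where(keep, acc_s, torch.tensor(float("-inf")))

# The running row max now ignores both masked-out and out-of-range columns.
row_max = acc_s.max(dim=-1).values
print(row_max)
```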
@@ -181,16 +184,17 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
                         acc_o[i, j] *= scores_scale[i]
-                    T.copy(
-                        V[bid, (seqlen_kv // num_split) * sid +
-                          k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                          cur_kv_head, :], V_shared)
+                    T.copy(
+                        V[bid, (seqlen_kv // num_split) * sid +
+                          k * valid_block_N:(seqlen_kv // num_split) * sid +
+                          (k + 1) * valid_block_N, cur_kv_head, :], V_shared)
                     T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] /= logsum[i]
                 for i in T.Parallel(block_H):
                     logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
-                T.copy(logsum[:valid_block_H],
-                       glse[bid, hid * valid_block_H:(hid + 1) * valid_block_H, sid])
+                for i in T.Parallel(block_H):
+                    if i < valid_block_H:
+                        glse[bid, hid * valid_block_H + i, sid] = logsum[i]
                 T.copy(acc_o[:valid_block_H, :], O_shared)
                 T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
                                                 sid, :])
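Writing `glse` through an explicit `T.Parallel` loop with an `i < valid_block_H` guard touches exactly the rows that belong to this block of query heads, which matters when `block_H` exceeds the number of query heads per KV head. In plain Python the addressing looks roughly like this (`heads = 32`, `groups = 8`, `block_H = 64` are the example's defaults; the rest is hypothetical):

```python
heads, groups, block_H = 32, 8, 64
kv_group_num = heads // groups              # 4 query heads per KV head
valid_block_H = min(block_H, kv_group_num)  # 4

hid = 2                                     # which head block this CTA handles
logsum = [float(i) for i in range(block_H)]
glse_row = [0.0] * heads                    # stand-in for glse[bid, :, sid]

for i in range(block_H):
    if i < valid_block_H:                   # same guard as the kernel loop
        glse_row[hid * valid_block_H + i] = logsum[i]
# Only heads 8..11 are written; the other lanes of the block store nothing.
```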
@@ -327,7 +331,7 @@ def ref_program(query, key, value, mask, glse, Output_partial):

 def flash_split_ref(Q, K, V, mask):
-    num_split = 8
+    num_split = 16
     batch = Q.size(0)
     nheads = Q.size(1)
     groups = K.size(2)
@@ -395,7 +399,7 @@ def flash_split_ref(Q, K, V, mask):

 def reduce_ref(Q, K, V, mask, glse, Output_partial):
-    num_split = 8
+    num_split = 16
     o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
     lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)  # [batch, heads]
     lse_max = glse.max(dim=2, keepdim=False).values
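Both references now use `num_split = 16` so their `glse`/`Output_partial` shapes line up with the kernel configuration below. The reduction they perform is a standard log-sum-exp combine across splits; a compact PyTorch sketch of that combine (shapes assumed as `[batch, heads, num_split]` for the log-sums and `[batch, heads, num_split, dim]` for the partial outputs; `exp2` is used because the kernel writes base-2 logs via `T.log2`, though the exact scaling of the example's `reduce_ref` is not reproduced here):

```python
import torch

def combine_splits(lse, out_partial):
    # lse: [batch, heads, num_split]; out_partial: [batch, heads, num_split, dim]
    lse_max = lse.max(dim=2, keepdim=True).values     # stabilise the exponentials
    w = torch.exp2(lse - lse_max)                     # per-split weights
    o = (w.unsqueeze(-1) * out_partial).sum(dim=2)    # weighted sum over splits
    return o / w.sum(dim=2, keepdim=True)             # renormalise

lse = torch.randn(1, 32, 16)
out_partial = torch.randn(1, 32, 16, 128)
print(combine_splits(lse, out_partial).shape)         # torch.Size([1, 32, 128])
```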
@@ -410,6 +414,37 @@ def reduce_ref(Q, K, V, mask, glse, Output_partial):
     return o.to(torch.float16)


+def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None):
+    glse_, Output_partial_ = flash_split_ref(Q, K, V, mask)
+    return reduce_ref(Q, K, V, mask, glse_, Output_partial_)
+
+
+def print_red_warning(msg):
+    print(f"\033[91m{msg}\033[0m")
+
+
+def calc_sim(x, y, name="tensor"):
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print_red_warning(f'{name} all zero')
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
+    sim = calc_sim(x, y, name)
+    diff = 1. - sim
+    if not (0 <= diff <= eps):
+        print_red_warning(f'{name} Error: {diff}')
+        if assert_:
+            raise AssertionError(f'{name} Error: {diff}')
+    else:
+        if print_:
+            print(f'passed: {name} diff={diff}')
+
+
 def main(batch: int = 1,
          heads: int = 32,
          groups: int = 8,
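The new `calc_sim` helper measures agreement as sim = 2·Σxy / (Σx² + Σy²), which equals 1 exactly when x == y, and `assert_similar` flags tensors whose deviation 1 − sim exceeds `eps`. A standalone usage sketch with dummy data:

```python
import torch

x = torch.randn(4, 8)
y = x + 1e-3 * torch.randn_like(x)      # small perturbation

xd, yd = x.double(), y.double()
sim = 2 * (xd * yd).sum() / (xd * xd + yd * yd).sum()
print(f"diff = {1.0 - sim:.3e}")        # tiny, well below the default eps=1e-2
```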
@@ -435,24 +470,45 @@ def main(batch: int = 1,
             return {
                 "block_N": 128,
                 "block_H": 64,
-                "num_split": 8,
+                "num_split": 16,
                 "num_stages": 0,
                 "threads": 128
-            }
+            }, sm_version
         else:
             return {
                 "block_N": 128,
                 "block_H": 64,
-                "num_split": 8,
+                "num_split": 16,
                 "num_stages": 2,
                 "threads": 128
-            }
+            }, sm_version

-    config = get_heuristic_config()
+    config, sm_version = get_heuristic_config()
     program = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config)
-    kernel = tilelang.compile(program, out_idx=[6])
+    if sm_version == 90:
+        kernel = tilelang.compile(
+            program, out_idx=[6], pass_configs={"tl.disable_tma_lower": True})
+    else:
+        kernel = tilelang.compile(program, out_idx=[6])
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)

+    q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
+    k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
+    v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
+    mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
+    glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
+    Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
+
+    o = kernel(q, k, v, mask, glse, Output_partial)
+    o_ref = ref_program(q, k, v, mask, glse, Output_partial)
+    o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
+    assert_similar(o, o_ref)
+    assert_similar(o_ref_split, o_ref)
+    torch.testing.assert_close(o, o_ref, rtol=0.01, atol=0.01)
+    torch.testing.assert_close(o_ref_split, o_ref, rtol=0.01, atol=0.01)
+
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+    profiler.assert_allclose(ref_split_program, rtol=0.01, atol=0.01)
     print("All checks pass.")
     latency = profiler.do_bench(ref_program, warmup=500)
     print("Ref: {:.2f} ms".format(latency))