Commit 32060ecd authored by Lei Wang, committed by LeiWang1999

[Enhancement] Improve flashattn function in example_gqa_decode.py (#329)

- Added a manual seed (torch.random.manual_seed(0)) for reproducibility in PyTorch.
- Refactored local variable allocations (T.alloc_var -> T.alloc_local) for better memory management.
- Enhanced parallelism in the flashattn combine step by broadcasting per-split LSE values across all 128 threads.
- Updated T.annotate_layout annotations for clarity and efficiency.

These changes optimize the flash attention split-KV combine and ensure consistent behavior across runs; a reference sketch of the combine math is given below, before the diff.
parent 5ee58ec7
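
For context, the epilogue these changes touch merges num_split partial attention outputs using their base-2 log-sum-exp (LSE) values. A minimal PyTorch sketch of the same math, assuming glse holds the per-split LSE values and output_partial the unnormalized partial outputs for a single (batch, head) pair (names follow the kernel's arguments; the shapes are simplified, not the kernel's exact layout):

    import torch

    def combine_splits(glse: torch.Tensor, output_partial: torch.Tensor) -> torch.Tensor:
        # glse:           [num_split]       base-2 log-sum-exp of each split
        # output_partial: [num_split, dim]  unnormalized partial outputs
        lse_max = glse.max()                                  # cf. T.reduce_max over splits
        lse_logsum = torch.exp2(glse - lse_max).sum()         # accumulate exp2 terms
        lse_logsum = torch.log2(lse_logsum) + lse_max         # cf. T.log2(...) + lse_max_local[0]
        scale = torch.exp2(glse - lse_logsum)                 # per-split weights
        return (output_partial * scale[:, None]).sum(dim=0)   # weighted sum of splits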
@@ -7,6 +7,7 @@ from einops import rearrange, einsum
 import argparse
 import itertools
+torch.random.manual_seed(0)


 def get_configs():
     block_N = [64, 128]
@@ -203,29 +204,35 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
         po_local = T.alloc_fragment([dim], dtype)
         o_accum_local = T.alloc_fragment([dim], accum_dtype)
         lse_local = T.alloc_fragment([num_split, 128], dtype)
-        lse_local_split = T.alloc_var(accum_dtype)
+        lse_local_split = T.alloc_local([1], accum_dtype)
         lse_logsum_local = T.alloc_local([1], accum_dtype)
         lse_max_local = T.alloc_fragment([128], accum_dtype)
         scale_local = T.alloc_local([1], accum_dtype)
         T.annotate_layout({
-            lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
+            lse_logsum_local:
+                T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
+            lse_max_local:
+                T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
+            # lse_local: (local_id, thread_id)
+            lse_local:
+                T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
         })
-        # T.clear(lse_logsum_local)
+        T.clear(lse_logsum_local)
         T.clear(o_accum_local)
-        for k in T.Parallel(num_split):
-            lse_local[k, 0] = glse[bz, by, k]
+        for k, j in T.Parallel(num_split, 128):
+            lse_local[k, j] = glse[bz, by, k]
         T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
         for k in T.Pipelined(num_split, num_stages=1):
-            lse_local_split = glse[bz, by, k]
-            lse_logsum_local[0] += T.exp2(lse_local_split - lse_max_local[0])
+            lse_local_split[0] = glse[bz, by, k]
+            lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
         lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
         for k in T.serial(num_split):
             for i in T.Parallel(dim):
                 po_local[i] = Output_partial[bz, by, k, i]
-            lse_local_split = glse[bz, by, k]
-            scale_local[0] = T.exp2(lse_local_split - lse_logsum_local[0])
+            lse_local_split[0] = glse[bz, by, k]
+            scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
             for i in T.Parallel(dim):
                 o_accum_local[i] += po_local[i] * scale_local[0]
         for i in T.Parallel(dim):
...
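
One detail worth noting: the rewritten broadcast loop fills all 128 lanes of lse_local instead of only lane 0, so T.reduce_max over the split axis leaves a valid maximum in every thread. A rough PyTorch analogue of that data movement (a sketch only; glse_slice stands in for glse[bz, by, :]):

    import torch

    num_split, lanes = 4, 128
    glse_slice = torch.randn(num_split)  # per-split LSE values for one (batch, head)

    # Replicate each split's LSE across all 128 lanes, as in
    #   for k, j in T.Parallel(num_split, 128): lse_local[k, j] = glse[bz, by, k]
    lse_local = glse_slice[:, None].expand(num_split, lanes)

    # Reducing over the split axis now yields the same maximum in every lane,
    # mirroring T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
    lse_max_local = lse_local.max(dim=0).values
    assert bool(torch.all(lse_max_local == glse_slice.max()))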