"testing/vscode:/vscode.git/clone" did not exist on "bf90a5f58c1ce9a3f20144368d72b02ed5fbeae6"
Commit 8678aac0 authored by Chenghua, committed by LeiWang1999

[Examples] Expand tuning configurations for FlashAttention example (#204)

* [Example] Modify tuning configurations for FlashAttention example

* [Examples] formatting example_gqa_fwd_bshd.py
parent 45559a1f
@@ -12,20 +12,53 @@ import argparse
 from functools import partial
 
 
-def get_configs():
-    block_M = [128]
-    block_N = [128]
-    num_stages = [2]
-    threads = [256]
-    _configs = list(itertools.product(block_M, block_N, num_stages, threads))
-
-    configs = [{
-        'block_M': c[0],
-        'block_N': c[1],
-        'num_stages': c[2],
-        'threads': c[3]
-    } for c in _configs]
-    return configs
+class FlashAttentionTuneSpace:
+
+    def __init__(
+        self,
+        block_sizes=(64, 128, 256),
+        thread_options=(128, 256, 512),
+        num_stages_range=(2, 3),
+        max_shared_mem=100 * 1024,
+        warp_alignment=16,
+        dim=128,
+        dtype_bytes=2,
+    ):
+        self.block_sizes = block_sizes
+        self.thread_options = thread_options
+        self.num_stages_range = num_stages_range
+        self.max_shared_mem = max_shared_mem
+        self.warp_alignment = warp_alignment
+        self.dim = dim
+        self.dtype_bytes = dtype_bytes
+
+
+def get_configs(user_config=None):
+    config = user_config or FlashAttentionTuneSpace()
+    valid_configs = []
+
+    for block_M, block_N in itertools.product(config.block_sizes, repeat=2):
+        for threads in config.thread_options:
+            assert threads % 32 == 0
+            warp_count = threads // 32
+            warp_M = block_M // warp_count
+            warp_N = block_N // warp_count
+            if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0):
+                continue
+
+            shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N)
+            if shared_mem > config.max_shared_mem:
+                continue
+
+            for num_stages in config.num_stages_range:
+                valid_configs.append({
+                    "block_M": block_M,
+                    "block_N": block_N,
+                    "num_stages": num_stages,
+                    "threads": threads,
+                })
+    return valid_configs
 
 
 def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
......
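
For context, the new get_configs keeps only tile shapes whose per-warp fragments are aligned to warp_alignment and whose estimated shared-memory footprint (2 * dtype_bytes * dim * (block_M + block_N)) fits the budget, instead of returning the single hard-coded configuration. A minimal usage sketch follows; the import path and the concrete values below are hypothetical illustrations, not part of the commit:

# Sketch only: narrow the search space for a head dimension of 64 on a GPU
# with roughly 48 KB of shared memory usable per block (assumed values).
from example_gqa_fwd_bshd import FlashAttentionTuneSpace, get_configs  # hypothetical import path

space = FlashAttentionTuneSpace(
    block_sizes=(64, 128),
    thread_options=(128, 256),
    dim=64,
    max_shared_mem=48 * 1024,
)
for cfg in get_configs(space):
    # Each entry is a dict of tuning parameters consumed by the kernel template,
    # e.g. {'block_M': 64, 'block_N': 64, 'num_stages': 2, 'threads': 128}
    print(cfg)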