Commit 3933f18a authored by Isotr0py, committed by GitHub

[Bugfix] Avoid too small block m/n for FlexAttention kernel option (#27853)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent e5ef4dfc
@@ -896,6 +896,8 @@ def get_kernel_options(
             return kernel_options
         else:
             preferred_block = 32 if query.dtype == torch.float32 else 64
+            block_lower_bound = 16
             block_m_candidate = ensure_divisible(preferred_block, block_m)
             block_n_candidate = ensure_divisible(preferred_block, block_n)
@@ -910,6 +912,9 @@ def get_kernel_options(
                         max(1, block_n_candidate // 2), block_n
                     )
+            block_m_candidate = max(block_m_candidate, block_lower_bound)
+            block_n_candidate = max(block_n_candidate, block_lower_bound)
             kernel_options["BLOCK_M"] = block_m_candidate
             kernel_options["BLOCK_N"] = block_n_candidate
......
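The effect of the added clamp can be illustrated with a minimal, standalone sketch. It is not the code from this commit: `ensure_divisible_sketch` is a hypothetical stand-in for `ensure_divisible` (whose body is not shown in the diff), and the `halvings` parameter is an assumed placeholder for the elided loop that repeatedly halves the candidate. Only the lower bound of 16 and the final `max(...)` are taken from the diff itself.

```python
import math


def ensure_divisible_sketch(candidate: int, block_size: int) -> int:
    """Hypothetical stand-in: shrink `candidate` to a value that divides `block_size`."""
    candidate = min(candidate, block_size)
    return math.gcd(candidate, block_size)


def pick_block(
    preferred: int,
    block_size: int,
    halvings: int = 0,
    lower_bound: int = 16,
) -> int:
    """Pick a kernel block size, then clamp it from below as the diff does."""
    candidate = ensure_divisible_sketch(preferred, block_size)
    # Assumed placeholder for the elided loop that halves the candidate.
    for _ in range(halvings):
        candidate = ensure_divisible_sketch(max(1, candidate // 2), block_size)
    # The step added by this commit: never go below the lower bound.
    return max(candidate, lower_bound)


if __name__ == "__main__":
    print(pick_block(64, 128))               # no halving needed
    print(pick_block(64, 128, halvings=3))   # would reach 8 without the clamp
    print(pick_block(64, 128, halvings=10))  # would reach 1 without the clamp
```

Without the final `max(...)`, repeated halving could drive the candidate down to 1, which is the "too small block m/n" case the commit title refers to.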