Commit ee989f6d authored by laibao
Browse files

refactor(rocm): 提取 unified flash 的 block_size 判定逻辑

parent ea9b8584
...@@ -305,17 +305,17 @@ class RocmPlatform(Platform): ...@@ -305,17 +305,17 @@ class RocmPlatform(Platform):
f"is not MLA type while requested for MLA backend." f"is not MLA type while requested for MLA backend."
) )
use_unified_flash = ( is_non64_block_multiple_64 = (
block_size is not None block_size != 64
and block_size != 64
and block_size % 64 == 0 and block_size % 64 == 0
)
use_unified_flash = (
is_non64_block_multiple_64
and head_size == 256 and head_size == 256
) )
if ( if (
envs.VLLM_USE_FLASH_ATTN_PA envs.VLLM_USE_FLASH_ATTN_PA
and block_size is not None and is_non64_block_multiple_64
and block_size != 64
and block_size % 64 == 0
and head_size != 256 and head_size != 256
): ):
logger.info_once( logger.info_once(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment