"vscode:/vscode.git/clone" did not exist on "3bff7958eb6d717cfbdcf69d01317800083e729a"
Unverified Commit 8374387b authored by Vadim Gimpelson's avatar Vadim Gimpelson Committed by GitHub
Browse files

[FlashInfer] Revert block_size 16 + head_size 256 workaround on Blackwell (#36987)


Signed-off-by: default avatarVadim Gimpelson <vadim.gimpelson@gmail.com>
parent 912fbe95
......@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.registry import AttentionBackendEnum
......@@ -148,17 +147,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
).page_size_bytes
else:
kernel_block_alignment_size = 16
if (
current_platform.is_device_capability_family(100)
and model_config.get_head_size() == 256
and (
attention_config.backend is None
or attention_config.backend == AttentionBackendEnum.FLASHINFER
)
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that`
# head size 256 and block size 16 is not supported on blackwell.
kernel_block_alignment_size = 32
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
......
......@@ -630,15 +630,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_indices = self._make_buffer(max_num_pages)
self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 and block size 16 is not supported on blackwell.
assert kv_cache_spec.block_size != 16, (
"There is a bug in FlashInfer "
"block_size 16 head size 256 support. Please avoid this combination by "
"passing --block-size 32 or --block-size 64."
)
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
) -> CpuGpuBuffer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment