Unverified Commit e816a881 authored by yzong-rh's avatar yzong-rh Committed by GitHub
Browse files

[Bugfix] Fix FlashInfer crash with kv_cache_dtype_skip_layers (#39002)


Signed-off-by: default avatarYifan Zong <yzong@redhat.com>
parent e281cb72
......@@ -39,7 +39,7 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
......@@ -53,7 +53,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
block_size: int,
......@@ -63,7 +62,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
......@@ -81,13 +79,14 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.block_size = block_size
# Initialize attn MetadataBuilder
# Initialize attn MetadataBuilder (match Attention.get_kv_cache_spec)
self.builder = self.attn.attn_backend.get_builder_cls()(
kv_cache_spec=AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
dtype=self.attn.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.attn.kv_cache_dtype),
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,
......@@ -126,7 +125,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
# Create dummy KV cache
raw_tensor = torch.zeros(
2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
dtype=self.kv_cache_dtype,
dtype=self.attn.kv_cache_torch_dtype,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
......@@ -348,7 +347,6 @@ def test_attention_quant_pattern(
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused,
block_size=block_size,
......@@ -376,7 +374,6 @@ def test_attention_quant_pattern(
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w,
......
......@@ -63,7 +63,11 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVQuantMode,
UniformTypeKVCacheSpecs,
)
from vllm.v1.utils import CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
......@@ -600,12 +604,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.head_dim = self.kv_cache_spec.head_size
self.page_size = self.kv_cache_spec.block_size
if self.kv_cache_spec.kv_quant_mode != KVQuantMode.NONE:
self.cache_dtype = self.cache_config.cache_dtype
if is_quantized_kv_cache(self.cache_dtype):
# Cannot use self.kv_cache_spec.dtype here because kv_cache_spec
# storage dtype may not be the same as the op dtype (uint8 vs fp8_e4m3)
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype
)
else:
self.cache_dtype = "auto"
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment