Unverified Commit 08bfedc1 authored by Yubo Wang's avatar Yubo Wang Committed by GitHub
Browse files

[Bugfix] Fix extract_hidden_states crash with quantized KV cache dtype (#39160)


Signed-off-by: default avatarYubo Wang <yubowang2019@gmail.com>
parent 0102bd2f
...@@ -9,6 +9,7 @@ extract_hidden_states speculative decoding method. ...@@ -9,6 +9,7 @@ extract_hidden_states speculative decoding method.
""" """
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import replace
from typing import ClassVar from typing import ClassVar
import torch import torch
...@@ -352,6 +353,10 @@ class ExtractHiddenStatesModel(nn.Module): ...@@ -352,6 +353,10 @@ class ExtractHiddenStatesModel(nn.Module):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
# Hidden states dtype should be independent of KV cache dtype.
if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
cache_config = replace(cache_config, cache_dtype="auto")
# Create a single cache-only attention layer # Create a single cache-only attention layer
# Note: We set num_heads <- self.num_hidden_states # Note: We set num_heads <- self.num_hidden_states
# and head_size <- hidden_size so that we can insert # and head_size <- hidden_size so that we can insert
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment