Unverified Commit 0d49483e authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU] fix kv cache dtype in model runner (#19244)


Signed-off-by: default avatarChengji Yao <chengjiyao@google.com>
parent 90b78ec5
...@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, ...@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange) PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasMetadata) PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
...@@ -138,6 +139,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -138,6 +139,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype self._hidden_states_dtype = self.dtype
self.is_multimodal_model = model_config.is_multimodal_model self.is_multimodal_model = model_config.is_multimodal_model
...@@ -480,7 +486,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -480,7 +486,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=attn_module.dtype, dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window, sliding_window=attn_module.sliding_window,
use_mla=False, use_mla=False,
) )
...@@ -489,7 +495,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -489,7 +495,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=attn_module.dtype, dtype=self.kv_cache_dtype,
use_mla=False, use_mla=False,
) )
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment