Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d49483e
Unverified
Commit
0d49483e
authored
Jun 06, 2025
by
Chengji Yao
Committed by
GitHub
Jun 06, 2025
Browse files
[TPU] fix kv cache dtype in model runner (#19244)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
90b78ec5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
3 deletions
+9
-3
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+9
-3
No files found.
vllm/v1/worker/tpu_model_runner.py
View file @
0d49483e
...
@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
...
@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange
)
PlaceholderRange
)
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
PallasMetadata
)
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
@@ -138,6 +139,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -138,6 +139,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
dtype
=
self
.
model_config
.
dtype
if
cache_config
.
cache_dtype
==
"auto"
:
self
.
kv_cache_dtype
=
self
.
dtype
else
:
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
self
.
_hidden_states_dtype
=
self
.
dtype
self
.
_hidden_states_dtype
=
self
.
dtype
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
...
@@ -480,7 +486,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -480,7 +486,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size
=
block_size
,
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
self
.
kv_cache_
dtype
,
sliding_window
=
attn_module
.
sliding_window
,
sliding_window
=
attn_module
.
sliding_window
,
use_mla
=
False
,
use_mla
=
False
,
)
)
...
@@ -489,7 +495,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -489,7 +495,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size
=
block_size
,
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
dtype
=
self
.
kv_cache_
dtype
,
use_mla
=
False
,
use_mla
=
False
,
)
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment