Unverified commit 70ad3f9e, authored by Michael Goin, committed by GitHub
Browse files

[Bugfix][TPU] Fix V1 TPU worker for sliding window (#16059)


Signed-off-by: Michael Goin <mgoin64@gmail.com>
parent d6fc629f
......@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
......@@ -137,7 +137,7 @@ class TPUWorker:
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
-if isinstance(layer_spec, FullAttentionSpec):
+if isinstance(layer_spec, AttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None` to force Dynamo to pass
......@@ -147,7 +147,8 @@ class TPUWorker:
device=self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
-raise NotImplementedError
+raise NotImplementedError(
+f"Unsupported KV cache spec '{type(layer_spec)}'")
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment