Unverified Commit f7a6bd0f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix missing `kv_caches` and `attn_metadata` in `OpenVINOCausalLM` (#14271)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 0ca3b8e0
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# ruff: noqa: SIM117 # ruff: noqa: SIM117
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import Optional
import openvino as ov import openvino as ov
import torch import torch
...@@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM ...@@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM
from torch import nn from torch import nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor, from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states) _prune_hidden_states)
...@@ -24,7 +24,7 @@ from vllm.platforms import current_platform ...@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
def _flattenize_inputs(inputs): def _flatten_inputs(inputs):
""" """
Helper function for making nested inputs flattens Helper function for making nested inputs flattens
""" """
...@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs): ...@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
if input_data is None: if input_data is None:
continue continue
if isinstance(input_data, (list, tuple)): if isinstance(input_data, (list, tuple)):
flatten_inputs.extend(_flattenize_inputs(input_data)) flatten_inputs.extend(_flatten_inputs(input_data))
elif isinstance(input_data, dict): elif isinstance(input_data, dict):
flatten_inputs.extend(_flattenize_inputs(list( flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
input_data.values())))
else: else:
flatten_inputs.append(input_data) flatten_inputs.append(input_data)
return flatten_inputs return flatten_inputs
...@@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module): ...@@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
attn_metadata: OpenVINOAttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
flatten_kv_cache = _flattenize_inputs(kv_caches) flat_kv_caches = _flatten_inputs(kv_caches)
attn_metadata = get_forward_context().attn_metadata
inputs = [ inputs = [
input_ids, input_ids,
positions, positions,
*flatten_kv_cache, *flat_kv_caches,
attn_metadata.past_lens, attn_metadata.past_lens,
attn_metadata.subsequence_begins, attn_metadata.subsequence_begins,
attn_metadata.block_indices, attn_metadata.block_indices,
......
...@@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
input_tokens, input_tokens,
"positions": "positions":
input_positions, input_positions,
"kv_caches":
kv_caches,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
device=self.device), device=self.device),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment