Unverified Commit 34e3494e authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix failing `MyGemma2Embedding` test (#13820)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent f75aa727
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.gemma2 import Gemma2Model
...@@ -37,16 +36,12 @@ class MyGemma2Embedding(nn.Module): ...@@ -37,16 +36,12 @@ class MyGemma2Embedding(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment