Unverified Commit 34e3494e authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix failing `MyGemma2Embedding` test (#13820)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent f75aa727
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
......@@ -37,16 +36,12 @@ class MyGemma2Embedding(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment