gemma2_embedding.py 1.91 KB
Newer Older
1
from typing import Iterable, List, Optional, Tuple, Union
2
3
4
5
6
7
8
9
10

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

11
12
from .gemma2 import Gemma2Model
from .interfaces import SupportsPP
13

14
15

class Gemma2EmbeddingModel(nn.Module, SupportsPP):
    """A model that uses Gemma2 with additional embedding functionalities.

    This class encapsulates the Gemma2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Build the wrapped ``Gemma2Model`` and the pooler.

        Args:
            **kwargs: Forwarded unchanged to ``Gemma2Model``.
        """
        super().__init__()
        self.model = Gemma2Model(**kwargs)
        # Pooler configured with PoolingType.LAST and normalize=True.
        # NOTE(review): presumably this selects the last token's hidden
        # state and normalizes it -- confirm against Pooler's implementation.
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

        # Re-export the inner model's factory on this wrapper; presumably
        # expected by callers relying on the SupportsPP interface -- confirm.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped ``Gemma2Model``, passing every argument through.

        Returns:
            Whatever ``self.model`` returns: a tensor or
            ``IntermediateTensors``, per the return annotation.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool ``hidden_states`` with the pooler built in ``__init__``.

        Args:
            hidden_states: Output of ``forward`` to be pooled.
            pooling_metadata: Metadata passed straight to the pooler.

        Returns:
            The pooler's result (``PoolerOutput`` or ``None``).
        """
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Delegate checkpoint loading to the wrapped ``Gemma2Model``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        self.model.load_weights(weights)