# llama_embedding.py
from typing import Iterable, List, Optional, Tuple, Union
2
3
4
5
6
7
8

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
9
from vllm.sequence import IntermediateTensors, PoolerOutput
10

11
from .interfaces import SupportsPP
12
from .llama import LlamaModel
13

14
15

class LlamaEmbeddingModel(nn.Module, SupportsPP):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions. Every method
    other than ``pooler`` simply delegates to the wrapped ``LlamaModel``.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations,
            constructed with ``PoolingType.LAST`` and ``normalize=True``.
        make_empty_intermediate_tensors: The wrapped model's callable of
            the same name, re-exposed on this wrapper.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Build the wrapped Llama model and the pooler.

        Args:
            **kwargs: Forwarded unchanged to the ``LlamaModel`` constructor.
        """
        super().__init__()
        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # Re-expose the inner model's factory on the wrapper.
        # NOTE(review): presumably required by the SupportsPP interface for
        # pipeline parallelism -- confirm against `.interfaces`.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped ``LlamaModel`` and return its output unchanged.

        All arguments are passed through positionally, in the order they
        are declared here; no pooling is applied at this stage.

        Returns:
            Whatever ``self.model(...)`` returns for the given inputs.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler to ``hidden_states``.

        Args:
            hidden_states: Tensor handed to the pooler as-is.
            pooling_metadata: Metadata handed to the pooler as-is.

        Returns:
            The result of ``self._pooler(hidden_states, pooling_metadata)``.
        """
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        """Load ``(name, tensor)`` weight pairs into the wrapped model.

        The iterable is forwarded untouched to ``self.model.load_weights``.
        """
        self.model.load_weights(weights)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Delegate KV-cache scale loading to the wrapped model.

        Args:
            quantization_param_path: Path forwarded unchanged to
                ``self.model.load_kv_cache_scales``.
        """
        self.model.load_kv_cache_scales(quantization_param_path)