jina.py 3.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/jinaai/jina-reranker-v3/blob/main/modeling.py
from collections.abc import Iterable

import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from ..layers.pooler import DispatchPooler
from ..layers.pooler.tokwise import (
    StepPool,
    TokenPooler,
    TokenPoolingMethodOutputItem,
)
from .interfaces import SupportsLateInteraction
from .qwen3 import Qwen3Model
from .utils import AutoWeightsLoader, maybe_prefix


class JinaForRanking(nn.Module, SupportsLateInteraction):
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.projector_dim: int = config.embedding_size

        self.vllm_config = vllm_config
        self.quant_config = quant_config
        self.model = Qwen3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.projector = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2, bias=False),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, self.projector_dim, bias=False),
        )

        self.pooler = DispatchPooler(
            {
                "token_embed": TokenPooler(
                    pooling=JinaForRankingPool(self.projector),
                )
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self, skip_prefixes=(["lm_head."]))
        return loader.load_weights(weights)


class JinaForRankingPool(StepPool):
    def __init__(self, projector: nn.Sequential):
        super().__init__()

        self.doc_token_id = 151670
        self.query_token_id = 151671
        self.projector = projector

    def get_supported_tasks(self) -> set[PoolingTask]:
        return {"token_embed"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        pooled_data_lst = super().forward(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()

        embeds_list = list[torch.Tensor | None]()
        for data, token_ids in zip(pooled_data_lst, prompt_token_ids):
            # for unfinished chunked prefill
            if data is None:
                embeds_list.append(None)
            else:
                docs_indexes = torch.where(torch.eq(token_ids, self.doc_token_id))[0]
                query_indexes = torch.where(torch.eq(token_ids, self.query_token_id))[0]

                # The JinaForRanking model concatenates docs first, then query.
                # Let's stay consistent with this novel design.
                indexes = torch.cat([docs_indexes, query_indexes])
                embeds = self.projector(data[indexes])
                embeds_list.append(embeds)

        return embeds_list