# jina_vl.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from typing import Optional

import torch
import torch.nn as nn
8
from transformers import BatchFeature
9

10
from vllm.config import ModelConfig, VllmConfig
11
12
13
14
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
15
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
16
from vllm.multimodal import MULTIMODAL_REGISTRY
17
from vllm.sequence import IntermediateTensors
18
19
20
21
22
23
24
25
26
27
28
29
30

from .interfaces import (SupportsCrossEncoding, SupportsMultiModal,
                         SupportsScoreTemplate)
from .qwen2_vl import (Qwen2VLDummyInputsBuilder,
                       Qwen2VLForConditionalGeneration,
                       Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# Module-level logger from vLLM's logging helper, named after this module.
# NOTE(review): not referenced by any code visible in this file.
logger = init_logger(__name__)


class JinaVLScorer(nn.Module):
    """Two-layer scoring head: ``out_proj(relu(dense(x)))``.

    Both projections are tensor-parallel linear layers created with the
    model's ``head_dtype`` and with a bias term.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Build the head from ``model_config.hf_config`` dimensions.

        Args:
            model_config: vLLM model config; ``hf_config.hidden_size`` and
                ``hf_config.num_labels`` size the layers and ``head_dtype``
                sets their parameter dtype.
        """
        super().__init__()

        config = model_config.hf_config
        head_dtype = model_config.head_dtype

        # hidden_size -> hidden_size projection.
        self.dense = ColumnParallelLinear(config.hidden_size,
                                          config.hidden_size,
                                          params_dtype=head_dtype,
                                          bias=True)
        # hidden_size -> num_labels projection that produces the scores.
        self.out_proj = RowParallelLinear(config.hidden_size,
                                          config.num_labels,
                                          params_dtype=head_dtype,
                                          bias=True)

    def forward(self, x, **kwargs):
        """Return ``out_proj(relu(dense(x)))``.

        Each parallel linear layer returns a pair; only the first element
        (the layer output) is kept. Extra keyword arguments are ignored.
        """
        x, _ = self.dense(x)
        x = torch.relu(x)
        x, _ = self.out_proj(x)
        return x


class JinaVLMultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Qwen2-VL multimodal processor adapted to the Jina VL score template."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Reverse every multimodal item sequence, then defer to the parent.

        The score template for JinaVLForRanking puts the document prompt
        before the query prompt, while ``mm_data`` holds the items in the
        opposite order (query first, then document). Each value is reversed
        in place with ``.reverse()``, so the caller's containers are mutated
        as a side effect.
        """
        # NOTE(review): the mutation is intentionally in place; code after
        # this call may observe the reversed order — confirm before
        # switching to a reversed copy.
        for items in mm_data.values():
            items.reverse()

        processed = super()._call_hf_processor(prompt, mm_data, mm_kwargs,
                                               tok_kwargs)
        return processed


@MULTIMODAL_REGISTRY.register_processor(JinaVLMultiModalProcessor,
                                        info=Qwen2VLProcessingInfo,
                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
                                      SupportsCrossEncoding,
                                      SupportsMultiModal,
                                      SupportsScoreTemplate):
    """Jina VL reranker: a Qwen2-VL backbone with a ``JinaVLScorer`` head.

    Query/document pairs are rendered with ``get_score_template`` (document
    first, then query), ``post_process_tokens`` appends a fixed target token
    id, and the "classify"/"score" poolers run the pooled hidden state
    through the scorer head.
    """

    is_pooling_model = True

    # Prefix rewrites from checkpoint parameter names to vLLM module names.
    weight_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Score head: checkpoint indices 0 and 2 (presumably a
            # Linear / activation / Linear stack) map onto JinaVLScorer.
            "score.0.": "score.dense.",
            "score.2.": "score.out_proj.",
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            # Without this entry, post-v4.52 "model.visual.*" keys would be
            # caught by the generic "model." rule below and mis-mapped to
            # "language_model.model.visual.*".
            "model.visual.": "visual.",
            # Identity entry kept for backward compatibility.
            "visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            # NOTE(review): generic fallback; keep it after the more specific
            # "model.*" prefixes above — confirm WeightsMapper match order.
            "model.": "language_model.model.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Qwen2-VL backbone, the scorer head and the poolers.

        Args:
            vllm_config: Full vLLM config; ``model_config.pooler_config``
                must be set.
            prefix: Module-name prefix; ``"qwen2_vl"`` is appended for the
                backbone.
        """
        super().__init__(vllm_config=vllm_config,
                         prefix=maybe_prefix(prefix, "qwen2_vl"))
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.score = JinaVLScorer(vllm_config.model_config)

        # "classify" and "score" both apply the scorer head to the pooled
        # output; "encode" uses the encode pooler without a classifier.
        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            Pooler.for_classify(pooler_config, classifier=self.score),
            "score":
            Pooler.for_classify(pooler_config, classifier=self.score),
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an image item.

        ``i`` (the item index) is unused. Raises ``ValueError`` for any
        modality that does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"

        raise ValueError("Only image modality is supported")

    @classmethod
    def get_score_template(cls, query: str, document: str) -> Optional[str]:
        """Render a pair as document first, query second."""
        return f"**Document**:\n{document}\n**Query**:\n{query}"

    @classmethod
    def post_process_tokens(cls, prompt: TokensPrompt) -> None:
        """Append the score target token id (100) to the prompt in place."""
        # add score target token at the end of prompt tokens
        # NOTE(review): 100 is assumed to be the token the head was trained
        # to read its score from — confirm against the model card.
        prompt['prompt_token_ids'].append(100)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the Qwen2-VL backbone and return its hidden states.

        Scoring is not applied here; it happens in ``self.pooler``.
        """
        hidden_states = super().forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights after renaming them via ``weight_mapper``.

        Returns whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.weight_mapper)