jina_vl.py 4.96 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping

import torch
import torch.nn as nn
7
from transformers import BatchFeature
8

9
from vllm.config import ModelConfig, VllmConfig
10
11
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
12
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
13
from vllm.model_executor.layers.pooler import DispatchPooler
14
from vllm.multimodal import MULTIMODAL_REGISTRY
15
from vllm.sequence import IntermediateTensors
16

17
18
19
20
21
22
23
from .interfaces import SupportsCrossEncoding, SupportsMultiModal, SupportsScoreTemplate
from .qwen2_vl import (
    Qwen2VLDummyInputsBuilder,
    Qwen2VLForConditionalGeneration,
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
)
24
25
26
27
28
29
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# Module-level logger for this model file (not referenced by the code visible
# here; kept for diagnostics).
logger = init_logger(__name__)


class JinaVLScorer(nn.Module):
    """Classification head that maps hidden states to ``num_labels`` scores.

    Architecture: ``dense`` (hidden_size -> hidden_size) -> ReLU ->
    ``out_proj`` (hidden_size -> num_labels). Both layers have a bias, and
    their parameters are created in ``model_config.head_dtype``. No final
    activation is applied, so the outputs are raw scores.
    """

    def __init__(self, model_config: "ModelConfig", prefix: str = ""):
        """Build the two projection layers.

        Args:
            model_config: vLLM model config. Its HF text config supplies
                ``hidden_size`` and ``num_labels``, and ``head_dtype`` sets
                the parameter dtype of both layers.
            prefix: Parameter-name prefix. ``.dense`` / ``.out_proj`` are
                appended for the two sub-layers.
        """
        super().__init__()
        config = model_config.hf_config.get_text_config()
        head_dtype = model_config.head_dtype
        # Column-parallel followed by row-parallel: presumably this lets the
        # intermediate activation stay sharded across tensor-parallel ranks
        # (relies on the vLLM layer defaults -- confirm).
        self.dense = ColumnParallelLinear(
            config.hidden_size,
            config.hidden_size,
            params_dtype=head_dtype,
            bias=True,
            prefix=f"{prefix}.dense",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.num_labels,
            params_dtype=head_dtype,
            bias=True,
            prefix=f"{prefix}.out_proj",
        )

    def forward(self, x, **kwargs):
        """Return raw label scores for ``x`` (``dense`` -> ReLU -> ``out_proj``).

        ``**kwargs`` are accepted for call-site compatibility and ignored.
        """
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used.
        x, _ = self.dense(x)
        x = torch.relu(x)
        x, _ = self.out_proj(x)
        return x


class JinaVLMultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Qwen2-VL multimodal processor that reorders items for the JinaVL template."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Reverse each modality's item sequence, then call the Qwen2-VL processor.

        Args:
            prompt: Text prompt to process.
            mm_data: Multimodal data keyed by name. List/tuple values are
                passed on in reversed order. The mapping itself is not
                modified.
            mm_kwargs: Forwarded unchanged to the parent implementation.
            tok_kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            The ``BatchFeature`` produced by the parent implementation.
        """
        # NOTE: We reverse the order of the mm_data items because the
        # query prompt is placed after the document prompt in the score
        # template for JinaVLForRanking model, but in mm_data they are
        # stored in the opposite order (query first, then document).
        # Build reversed copies instead of reversing in place: ``mm_data`` is
        # typed as a read-only ``Mapping`` and its lists belong to the caller.
        # Values that are not lists/tuples are forwarded unchanged.
        reordered = {
            key: value[::-1] if isinstance(value, (list, tuple)) else value
            for key, value in mm_data.items()
        }
        return super()._call_hf_processor(prompt, reordered, mm_kwargs, tok_kwargs)


@MULTIMODAL_REGISTRY.register_processor(
    JinaVLMultiModalProcessor,
    info=Qwen2VLProcessingInfo,
    dummy_inputs=Qwen2VLDummyInputsBuilder,
)
class JinaVLForSequenceClassification(
    Qwen2VLForConditionalGeneration,
    SupportsCrossEncoding,
    SupportsMultiModal,
    SupportsScoreTemplate,
):
    """Qwen2-VL backbone used as a (query, document) cross-encoding scorer.

    The backbone's ``forward`` returns hidden states. Scoring is done by
    ``self.pooler``, a sequence-classification ``DispatchPooler`` that uses
    :class:`JinaVLScorer` as its classifier.
    """

    is_pooling_model = True

    # Token id appended to every scoring prompt by ``post_process_tokens``
    # (the "score target" token).
    score_target_token_id = 100

    # Checkpoint-name -> vLLM-name prefix mapping used by ``load_weights``.
    weight_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Checkpoint head is indexed 0 and 2 (presumably a Sequential of
            # Linear, activation, Linear -- confirm); map onto JinaVLScorer.
            "score.0.": "score.dense.",
            "score.2.": "score.out_proj.",
            # mapping for new names in checkpoint saved after transformers v4.52
            # (listed before the generic "model." prefix below).
            "model.language_model.": "language_model.model.",
            "visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Qwen2-VL backbone plus the scoring head and pooler.

        Args:
            vllm_config: Full vLLM config. ``model_config.pooler_config`` must
                be set.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl")
        )
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.score = JinaVLScorer(
            vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
        )
        self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Any modality name starting with ``"image"`` is accepted. ``i`` (the
        item index) does not affect the result.

        Raises:
            ValueError: If ``modality`` is not an image modality.
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"

        raise ValueError("Only image modality is supported")

    @classmethod
    def get_score_template(cls, query: str, document: str) -> str | None:
        """Format a (query, document) pair, placing the document first."""
        return f"**Document**:\n{document}\n**Query**:\n{query}"

    @classmethod
    def post_process_tokens(cls, prompt: TokensPrompt) -> None:
        """Append the score target token to ``prompt`` in place."""
        # add score target token at the end of prompt tokens
        prompt["prompt_token_ids"].append(cls.score_target_token_id)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the Qwen2-VL backbone and return its hidden states.

        No scoring happens here; the pooler applies ``self.score`` later.
        """
        hidden_states = super().forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights after renaming them via ``weight_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.weight_mapper)