# SPDX-License-Identifier: Apache-2.0

import math
from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
                                             BartParallelLMHead,
                                             BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import AutoWeightsLoader


class Florence2LanguageModel(nn.Module):
    """Encoder-decoder text backbone assembled from the BART building blocks.

    The module owns a shared scaled word embedding plus a ``BartEncoder``
    and a ``BartDecoder``. When the config sets ``tie_word_embeddings``,
    the encoder and decoder token-embedding weights are rebound to the
    shared embedding's weight.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the shared embedding, the encoder and the decoder.

        Args:
            vllm_config
                Supplies the HF model config, the cache config and the
                quantization config forwarded to the encoder/decoder.
            prefix
                Module-name prefix; ``.encoder`` / ``.decoder`` are appended
                for the respective sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
        self.encoder = BartEncoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

        if self.config.tie_word_embeddings:
            # Rebind (not copy) so all three embeddings use one weight.
            self.encoder.embed_tokens.weight = self.shared.weight
            self.decoder.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions
                Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions)

        # With no encoder tokens the decoder receives
        # ``encoder_hidden_states=None``. Whatever the decoder returns is
        # passed back to the caller unchanged.
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states)

        return decoder_outputs


class Florence2LanguageForConditionalGeneration(nn.Module):
    """Text model with an LM head, logits processor and sampler attached.

    Wraps ``Florence2LanguageModel`` and adds the pieces needed to turn its
    hidden states into logits and sampled tokens, plus a checkpoint loader
    that folds separate q/k/v projection weights into ``qkv_proj``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the wrapped model, LM head, logits processor and sampler.

        Args:
            vllm_config
                Supplies the HF model config; also forwarded to the wrapped
                ``Florence2LanguageModel``.
            prefix
                Module-name prefix; ``.model`` is appended for the wrapped
                language model.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config

        self.config = config
        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")

        # sqrt(d_model) when the config enables embedding scaling, else 1.
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kwargs
                Accepted and ignored.
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over ``hidden_states`` with the LM head.

        Returns:
            Whatever the logits processor returns for
            ``(lm_head, hidden_states, sampling_metadata)``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Sample next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names containing ``q_proj`` / ``k_proj`` / ``v_proj`` are rewritten
        to ``qkv_proj`` and loaded as the ``q`` / ``k`` / ``v`` shard of that
        parameter through the parameter's own ``weight_loader``. Entries
        whose name contains ``final_logits_bias`` are skipped, as are
        ``embed_tokens`` entries when ``tie_word_embeddings`` is set.
        Everything else is loaded with the parameter's ``weight_loader`` if
        it has one, otherwise with ``default_weight_loader``.

        Args:
            weights
                Iterable of ``(name, tensor)`` pairs.
        Returns:
            The set of parameter names (after any ``qkv_proj`` rewrite)
            that were loaded.
        Raises:
            KeyError: if a (rewritten) name is not a parameter of this
                module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # ``name`` is deliberately rebound: the rewritten name is
                # what gets recorded in ``loaded_params`` below.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping matched. The
                # ``continue`` statements skip the ``loaded_params.add``.
                if "final_logits_bias" in name:
                    continue
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Florence2ForConditionalGeneration(nn.Module):
    """Top-level Florence-2 wrapper; currently text-only.

    All work is delegated to ``Florence2LanguageForConditionalGeneration``,
    built from the ``text_config`` of the HF config. No vision backbone is
    constructed yet (see the TODO in ``__init__``), so ``load_weights``
    skips the vision-related checkpoint prefixes.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model from the HF config's ``text_config``.

        Args:
            vllm_config
                Engine configuration; a copy with ``hf_config`` replaced by
                ``config.text_config`` is handed to the language model.
            prefix
                Module-name prefix; ``.language_model`` is appended.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=f"{prefix}.language_model",
        )

    @property
    def sampler(self):
        """The sampler owned by the wrapped language model."""
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            intermediate_tensors
                Accepted and ignored.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kwargs
                Accepted and ignored.
        Returns:
            Output torch.Tensor
        """
        return self.language_model(input_ids, positions, encoder_input_ids,
                                   encoder_positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Delegate token sampling to the wrapped language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Args:
            weights
                Iterable of ``(name, tensor)`` pairs.
        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        # These prefixes belong to components this class does not build.
        skip_prefixes = [
            'image_projection', "vision_tower", "image_proj_norm",
            "image_pos_embed", "visual_temporal_embed"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)