# SPDX-License-Identifier: Apache-2.0

import math
from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
                                             BartParallelLMHead,
                                             BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import AutoWeightsLoader


class Florence2LanguageModel(nn.Module):
    """Encoder-decoder text backbone assembled from the BART building blocks.

    Holds a ``BartScaledWordEmbedding`` table (``shared``), a ``BartEncoder``
    and a ``BartDecoder``. When the HF config enables ``tie_word_embeddings``
    the ``embed_tokens.weight`` of both stacks is bound to ``shared.weight``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the shared embedding, the encoder and the decoder.

        Args:
            vllm_config: Source of the HF model config plus the cache and
                quantization configs handed to both stacks.
            prefix: Name prefix; ``.encoder`` / ``.decoder`` is appended for
                the respective stack.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size

        # Both stacks receive the same cache / quantization settings.
        stack_kwargs = dict(cache_config=vllm_config.cache_config,
                            quant_config=vllm_config.quant_config)

        # Submodules are registered in the order shared -> encoder -> decoder.
        self.shared = BartScaledWordEmbedding(self.vocab_size,
                                              hf_config.d_model)
        self.encoder = BartEncoder(hf_config,
                                   prefix=f"{prefix}.encoder",
                                   **stack_kwargs)
        self.decoder = BartDecoder(hf_config,
                                   prefix=f"{prefix}.decoder",
                                   **stack_kwargs)

        if self.config.tie_word_embeddings:
            # Point both stacks' token embeddings at the shared table.
            for stack in (self.encoder, self.decoder):
                stack.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""Run the encoder (when it has input) followed by the decoder.

        Args:
            input_ids
                Token ids of the *decoder* input sequence.
            positions
                Position indices of the *decoder* input tokens.
            encoder_input_ids
                Token ids of the *encoder* input sequence.
            encoder_positions
                Position indices of the *encoder* input tokens.
            kv_caches
                Layer-wise list of KV cache tensors.
            attn_metadata
                vLLM attention metadata structure.
        Returns:
            Whatever the decoder returns for these inputs.
        """
        if encoder_input_ids.numel() > 0:
            # Only invoke the encoder when at least one encoder token
            # was supplied.
            encoder_hidden_states = self.encoder(
                input_ids=encoder_input_ids,
                positions=encoder_positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata)
        else:
            encoder_hidden_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)


class Florence2LanguageForConditionalGeneration(nn.Module):
    """``Florence2LanguageModel`` plus LM head, logits processor and sampler.

    Also implements checkpoint loading for this sub-tree, including the
    mapping of separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint
    tensors onto ``qkv_proj`` parameters.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the text model and the modules used to produce tokens.

        Args:
            vllm_config: Source of the HF model config; also forwarded
                unchanged to ``Florence2LanguageModel``.
            prefix: Name prefix; ``.model`` is appended for the inner model.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")

        # The LM head is given sqrt(d_model) as scale only when the config
        # asks for scaled embeddings.
        if hf_config.scale_embedding:
            embed_scale = math.sqrt(hf_config.d_model)
        else:
            embed_scale = 1.0

        self.vocab_size = hf_config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          hf_config.d_model,
                                          embed_scale=embed_scale)
        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                hf_config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the wrapped ``Florence2LanguageModel``.

        Extra keyword arguments are accepted and ignored.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices.
            kv_caches
                Layer-wise list of KV cache tensors.
            attn_metadata
                vLLM attention metadata structure.
        Returns:
            Output torch.Tensor of the inner model.
        """
        hidden_states = self.model(input_ids, positions, encoder_input_ids,
                                   encoder_positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Return the sampler's output for ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were written. For q/k/v shards this
            is the name after substituting ``qkv_proj``.

        Raises:
            KeyError: If a (possibly rewritten) name is not among this
                module's named parameters.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        qkv_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )

        named_params = dict(self.named_parameters())
        loaded: Set[str] = set()

        for ckpt_name, tensor in weights:
            # First mapping entry whose shard name occurs in the checkpoint
            # name, if any.
            shard = next(
                (entry for entry in qkv_shards if entry[1] in ckpt_name),
                None)

            if shard is not None:
                fused_name, shard_name, shard_id = shard
                target = ckpt_name.replace(shard_name, fused_name)
                param = named_params[target]
                param.weight_loader(param, tensor, shard_id)
            else:
                if "final_logits_bias" in ckpt_name:
                    continue
                # With tied embeddings the per-stack embed_tokens entries
                # are skipped.
                if (self.config.tie_word_embeddings
                        and "embed_tokens" in ckpt_name):
                    continue
                target = ckpt_name
                param = named_params[target]
                load_fn = getattr(param, "weight_loader",
                                  default_weight_loader)
                load_fn(param, tensor)

            loaded.add(target)

        return loaded
194
195
196
197


class Florence2ForConditionalGeneration(nn.Module):
    """Top-level Florence-2 module; currently wraps only the language model.

    Forward, logits computation and sampling are all delegated to
    ``Florence2LanguageForConditionalGeneration``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the language model from the config's ``text_config``.

        Args:
            vllm_config: Full vLLM config; a copy carrying
                ``hf_config.text_config`` is passed to the language model.
            prefix: Name prefix; ``.language_model`` is appended.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        text_vllm_config = vllm_config.with_hf_config(hf_config.text_config)

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=text_vllm_config,
            prefix=f"{prefix}.language_model",
        )

    @property
    def sampler(self):
        """Sampler owned by the wrapped language model."""
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the language model.

        ``intermediate_tensors`` and any extra keyword arguments are accepted
        but not used here.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            kv_caches
                Layer-wise list of KV cache tensors.
            attn_metadata
                vLLM attention metadata structure.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids (keyword-only).
            encoder_positions
                torch.Tensor of *encoder* position indices (keyword-only).
        Returns:
            Output torch.Tensor of the language model.
        """
        hidden_states = self.language_model(input_ids, positions,
                                            encoder_input_ids,
                                            encoder_positions, kv_caches,
                                            attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Forward to the language model's ``compute_logits``."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Forward to the language model's ``sample``."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights through ``AutoWeightsLoader``.

        The image-related names below are handed to the loader as
        ``skip_prefixes``; no vision module exists in this class yet.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        vision_prefixes = [
            "image_projection",
            "vision_tower",
            "image_proj_norm",
            "image_pos_embed",
            "visual_temporal_embed",
        ]
        auto_loader = AutoWeightsLoader(self, skip_prefixes=vision_prefixes)
        return auto_loader.load_weights(weights)