blip.py 14.9 KB
Newer Older
1
2
"""Minimal implementation of BlipVisionModel intended to be only used 
within a vision language model."""
3
from typing import Iterable, Optional, Set, Tuple, Union
4
5
6
7
8
9

import torch
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig

10
from vllm.attention.layer import MultiHeadAttention
11
from vllm.config import ModelConfig
12
from vllm.distributed import divide, get_tensor_model_parallel_world_size
13
from vllm.inputs import DecoderOnlyInputs, token_inputs
14
15
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
16
                                               QKVParallelLinear,
17
18
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
20
21
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
22
from vllm.sequence import SequenceData
23
24
25
26
27
28
29
30
31
32
33
34
35
36


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the patch grid.

    Args:
        image_size: Image side length in pixels.
        patch_size: Patch side length in pixels.

    Returns:
        ``image_size // patch_size``.

    Raises:
        ValueError: If ``patch_size`` is not positive or does not evenly
            divide ``image_size``.
    """
    # Explicit checks instead of ``assert`` so validation is not stripped
    # when Python runs with ``-O``.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if image_size % patch_size != 0:
        raise ValueError(f"image_size ({image_size}) must be divisible by "
                         f"patch_size ({patch_size})")
    return image_size // patch_size


def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total patch count of the square grid (side length squared)."""
    side = get_blip_patch_grid_length(image_size=image_size,
                                      patch_size=patch_size)
    return side**2


def get_blip_image_feature_size(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    """Return the number of patch features produced for one image.

    The value is the patch count derived from ``hf_config.image_size`` and
    ``hf_config.patch_size``. It does not include the extra class-token
    position.

    Args:
        hf_config: BLIP / BLIP-2 vision config.

    Returns:
        Number of patches in the grid.
    """
    return get_blip_num_patches(image_size=hf_config.image_size,
                                patch_size=hf_config.patch_size)


def get_max_blip_image_tokens(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    """Return the maximum number of placeholder tokens a single image uses.

    The per-image feature size depends only on the config, not on the image
    itself. The maximum is therefore the same value as
    ``get_blip_image_feature_size(hf_config)``.
    """
    return get_blip_image_feature_size(hf_config)


def dummy_seq_data_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build a dummy prompt sequence containing image placeholder tokens.

    The sequence is one contiguous run of
    ``image_feature_size * num_images`` copies of ``image_token_id``. It is
    followed by enough ``0`` tokens to reach ``seq_len``.

    Args:
        hf_config: Vision config used to derive the per-image feature size.
        seq_len: Requested total sequence length.
        num_images: Number of images to reserve placeholder tokens for.
        image_token_id: Token id used for the image placeholders.
        image_feature_size_override: If not ``None``, used instead of the
            feature size computed from ``hf_config``.

    Returns:
        The ``SequenceData`` from ``SequenceData.from_prompt_token_counts``.
    """
    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    num_image_tokens = image_feature_size * num_images
    # NOTE(review): if seq_len < num_image_tokens the padding count below is
    # negative; the outcome then depends on from_prompt_token_counts --
    # confirm callers always pass a large enough seq_len.
    return SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )


def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy black RGB image data for the BLIP vision encoder.

    Args:
        hf_config: Vision config; ``image_size`` is the default width and
            height.
        num_images: Number of images to return.
        image_width_override: If not ``None``, used as the image width.
        image_height_override: If not ``None``, used as the image height.

    Returns:
        ``{"image": image}`` when ``num_images == 1``, otherwise
        ``{"image": [image] * num_images}``. The list repeats the same
        ``PIL.Image`` object.
    """
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_blip(
    model_config: ModelConfig,
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Expand image placeholder tokens in a decoder-only prompt.

    Args:
        model_config: Model config; its ``tokenizer`` is loaded via
            ``cached_get_tokenizer``.
        hf_config: Vision config used to derive the per-image feature size.
        inputs: Prompt inputs, read via ``multi_modal_data``,
            ``multi_modal_placeholders``, ``prompt`` and
            ``prompt_token_ids``.
        image_token_id: Token id of the image placeholder in the prompt.
        image_feature_size_override: If not ``None``, used instead of the
            feature size computed from ``hf_config``.

    Returns:
        ``inputs`` unchanged in two cases: there is no image data, or image
        placeholders are already present. Otherwise, a new ``token_inputs``
        object with the rewritten prompt and token ids, the original
        ``multi_modal_data``, and the placeholder ranges under ``"image"``.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    if "multi_modal_placeholders" in inputs and "image" in inputs[
            "multi_modal_placeholders"]:
        # The inputs already have placeholders.
        return inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    # Presumably expands each ``image_token_id`` occurrence to
    # ``image_feature_size`` tokens and reports the resulting ranges --
    # see repeat_and_pad_placeholder_tokens for the exact contract.
    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
        tokenizer,
        inputs.get("prompt"),
        inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": ranges})


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
    """Convert pixel values into the token sequence fed to the BLIP encoder.

    The embedding is built in three steps:
    - A convolution with ``kernel_size == stride == patch_size`` embeds
      non-overlapping patches.
    - A learned class embedding is prepended.
    - Learned position embeddings are added.
    """

    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # kernel_size == stride, so each output position sees one
        # non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # One extra position for the prepended class embedding.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: Image batch. The conv layer expects 3 input
                channels, presumably laid out as (batch, 3, height, width).

        Returns:
            Tensor of shape (batch, 1 + number of patches, embed_dim).
        """
        batch_size = pixel_values.shape[0]
        # Run the conv in the weight dtype so mixed-dtype inputs still work.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        # [*, width, grid, grid] -> [*, grid * grid, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # Slice to the actual sequence length of this batch.
        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings


class BlipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Union[BlipVisionConfig, Blip2VisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the fused QKV projection, output projection and attention.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            # Blip2VisionConfig defines ``qkv_bias``; BlipVisionConfig does
            # not, and HF's BLIP-1 attention builds qkv as a biased
            # nn.Linear -- hence the default of True.
            bias=getattr(config, "qkv_bias", True),
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.projection",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        # Heads are split evenly across tensor-parallel ranks.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape to (bsz, num_heads, seq_len, head_dim); not used by
        # forward() in this class.
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns:
            ``(attn_output, None)``. The second element is always ``None``,
            presumably kept for parity with HF's (output, attn_weights)
            tuple.
        """

        qkv_states, _ = self.qkv(hidden_states)
        # Local shard holds equally sized q, k, v slices on the last dim.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.projection(out)

        return attn_output, None


class BlipMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    fc1 is column-parallel and fc2 is row-parallel under tensor parallelism.
    """

    def __init__(
        self,
        config: Union[BlipVisionConfig, Blip2VisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2 (the parallel-linear bias
        outputs are discarded)."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class BlipEncoderLayer(nn.Module):
    """One pre-norm transformer block.

    The block applies ``x + attn(ln1(x))`` and then ``x + mlp(ln2(x))``.
    """

    def __init__(
        self,
        config: Union[BlipVisionConfig, Blip2VisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Attention always goes through the tensor-parallel BlipAttention;
        # this layer has no alternative attention path.
        self.self_attn = BlipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each with a residual add."""
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self 
    attention layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipVisionConfig or Blip2VisionConfig.
        quant_config: Optional quantization config passed to every layer.
        num_hidden_layers_override: If not None, build this many layers
            instead of `config.num_hidden_layers`.
        prefix: Parameter-name prefix; layers are named
            `{prefix}.layers.{idx}`.
    """

    def __init__(
        self,
        config: Union[BlipVisionConfig, Blip2VisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        """Run ``inputs_embeds`` through every layer in order."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class BlipVisionModel(nn.Module):
    """BLIP vision tower.

    Consists of patch embeddings, a (possibly truncated) transformer encoder
    and an optional final LayerNorm.
    """
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        """Build the vision tower.

        Args:
            config: Vision config.
            quant_config: Optional quantization config for encoder layers.
            num_hidden_layers_override: If not None, build only this many
                encoder layers.
            require_post_norm: Whether to build ``post_layernorm``. When
                None, it is built only if all encoder layers are kept.
            prefix: Parameter-name prefix.

        Raises:
            ValueError: If more layers are requested than the config has.
        """
        super().__init__()
        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values``, encode them and apply ``post_layernorm``
        if it was built."""
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return hidden_states

        return self.post_layernorm(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names are relative to
                this module (e.g. ``encoder.layers.0.self_attn.qkv.weight``).

        Returns:
            Names of the parameters that were loaded. Skipped tensors
            (unused ``post_layernorm``, truncated layers) are not included.
        """
        # Maps split q/k/v checkpoint tensors onto shards of the fused
        # ``qkv`` projection (BlipAttention.qkv). Checkpoints that already
        # store a fused ``qkv`` tensor do not match these names and take the
        # generic path below.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv", "q_proj", "q"),
            ("qkv", "k_proj", "k"),
            ("qkv", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in BlipVisionModel
            if (name.startswith("post_layernorm")
                    and self.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params