"tests/models/language/generation/test_bart.py" did not exist on "7eb4a51c5f34a52427d170d0e654e0a346c6d69c"
blip.py 11.9 KB
Newer Older
1
2
"""Minimal implementation of BlipVisionModel, intended to be used only
within a vision-language model."""
3
from typing import Iterable, Optional, Set, Tuple, Union
4
5
6
7
8

import torch
import torch.nn as nn
from transformers import Blip2VisionConfig, BlipVisionConfig

9
from vllm.attention.layer import MultiHeadAttention
10
from vllm.distributed import divide, get_tensor_model_parallel_world_size
11
12
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
13
                                               QKVParallelLinear,
14
15
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return how many patches fit along one side of a square image.

    ``image_size`` must be an exact multiple of ``patch_size``; this is
    checked with an ``assert`` (and is therefore skipped under ``python -O``).
    """
    patches_per_side, leftover = divmod(image_size, patch_size)
    assert leftover == 0
    return patches_per_side


def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches: patches-per-side squared."""
    side = get_blip_patch_grid_length(image_size=image_size,
                                      patch_size=patch_size)
    return side * side


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
    """Class token + patch embeddings + learned position embeddings.

    Turns a batch of images into a token sequence: a learned class token
    followed by one token per non-overlapping ``patch_size`` x ``patch_size``
    patch, with a learned absolute position embedding added to every token.
    """

    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # kernel_size == stride == patch_size: each output position sees one
        # non-overlapping patch of the (3-channel) input image.
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images as a token sequence.

        Args:
            pixel_values: batched image tensor fed to a 3-input-channel
                ``nn.Conv2d``; presumably (batch, 3, image_size, image_size)
                -- TODO confirm with callers.

        Returns:
            Tensor of shape (batch, 1 + number_of_patches, embed_dim), in the
            dtype of ``patch_embedding.weight``.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))  # [batch, width, grid, grid]
        # [batch, width, grid, grid] -> [batch, grid * grid, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        # Cast the class token to the same dtype as the patch embeddings (as
        # is already done for the position table) so torch.cat does not mix
        # dtypes and the output stays in target_dtype.
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # Slice before casting so only the positions actually used are
        # converted.
        seq_len = embeddings.size(1)
        position_embeds = self.position_embedding[:, :seq_len, :].to(
            target_dtype)
        return embeddings + position_embeds


73
class BlipAttention(nn.Module):
    """Multi-head self-attention for the BLIP vision encoder.

    Uses a fused QKV projection and an output projection from vLLM's
    parallel linear layers; the heads are divided across tensor-parallel
    ranks before being handed to ``MultiHeadAttention``.
    """

    def __init__(
        self,
        config: Union[BlipVisionConfig, Blip2VisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.head_dim, remainder = divmod(self.embed_dim, self.num_heads)
        if remainder != 0:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        # Kept from the config; forward() below does not apply it.
        self.dropout = config.attention_dropout

        # Fused query/key/value projection.
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        # Output projection back to embed_dim.
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.projection",
        )

        # Number of heads handled by this tensor-parallel rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # NOTE(review): appears unused in the visible source; it reshapes
        # with the full (unpartitioned) head count.
        split_heads = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return split_heads.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Self-attend over ``hidden_states`` (Batch x Time x Channel).

        Returns an ``(attn_output, None)`` pair; the second element is always
        ``None`` because no attention weights are returned.
        """
        fused_qkv, _ = self.qkv(hidden_states)
        q, k, v = fused_qkv.chunk(3, dim=-1)
        context = self.attn(q, k, v)
        attn_output, _ = self.projection(context)
        return attn_output, None
132
133


134
135
class BlipMLP(nn.Module):
    """Feed-forward block of a BLIP encoder layer: fc1 -> activation -> fc2.

    ``fc1`` is a ``ColumnParallelLinear`` (hidden -> intermediate) and
    ``fc2`` a ``RowParallelLinear`` (intermediate -> hidden), both biased.
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.fc1 = ColumnParallelLinear(hidden,
                                        intermediate,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(intermediate,
                                     hidden,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, project back down."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class BlipEncoderLayer(nn.Module):
    """One pre-norm transformer block of the BLIP vision encoder.

    ``forward`` computes ``x = x + attn(layer_norm1(x))`` followed by
    ``x = x + mlp(layer_norm2(x))``.
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Self-attention sub-block. BlipAttention is always used here; this
        # constructor selects no alternative attention implementation.
        self.self_attn = BlipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks, each with a residual."""
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # BlipAttention returns (output, None); the second slot is ignored.
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class BlipEncoder(nn.Module):
    """Stack of ``BlipEncoderLayer`` blocks applied one after another.

    Args:
        config: BlipVisionConfig providing ``num_hidden_layers`` and the
            per-layer hyperparameters.
        quant_config: optional quantization config forwarded to every layer.
        num_hidden_layers_override: when given, build this many layers
            instead of ``config.num_hidden_layers``.
        prefix: parameter-name prefix; layer ``i`` is registered under
            ``{prefix}.layers.{i}``.
    """

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        depth = (config.num_hidden_layers
                 if num_hidden_layers_override is None else
                 num_hidden_layers_override)

        self.layers = nn.ModuleList(
            BlipEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{idx}")
            for idx in range(depth))

    def forward(self, inputs_embeds: torch.Tensor):
        """Feed ``inputs_embeds`` through every layer in order."""
        out = inputs_embeds
        for block in self.layers:
            out = block(out)
        return out


class BlipVisionModel(nn.Module):
    """BLIP vision tower: embeddings, encoder stack, optional final LayerNorm.

    Args:
        config: vision config for the tower.
        quant_config: optional quantization config forwarded to the encoder.
        num_hidden_layers_override: build only this many encoder layers;
            must not exceed ``config.num_hidden_layers``.
        require_post_norm: whether to build ``post_layernorm``. When None,
            it is built only if the encoder uses all
            ``config.num_hidden_layers`` layers.
        prefix: parameter-name prefix for the encoder.

    Raises:
        ValueError: if more encoder layers are requested than the config has.
    """
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        num_hidden_layers = config.num_hidden_layers
        requested_layers = (num_hidden_layers
                            if num_hidden_layers_override is None else
                            num_hidden_layers_override)
        # Validate before building the encoder so an invalid override does
        # not allocate layers only to throw them away.
        if requested_layers > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {requested_layers} layers.")

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        # If possible, skip post_layernorm to conserve memory: by default it
        # is only built when the full encoder depth is used.
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values``, run the encoder, then the optional norm."""
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return hidden_states

        return self.post_layernorm(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names are matched against
                ``self.named_parameters()``.

        Returns:
            The set of parameter names that were loaded.
        """
        # Separate q/k/v projection weights, if a checkpoint has them, are
        # loaded as shards of BlipAttention's fused ``qkv`` parameter. (This
        # previously targeted a nonexistent ``qkv_proj`` parameter, which
        # would have raised KeyError.)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv", "q_proj", "q"),
            ("qkv", "k_proj", "k"),
            ("qkv", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm was not built (see __init__); drop its weights.
            if (name.startswith("post_layernorm")
                    and self.post_layernorm is None):
                continue

            # Drop weights for layers cut off by num_hidden_layers_override.
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params