# lfm2_siglip2.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of Siglip2VisionModel intended to be only used
within a vision language model."""

from collections.abc import Iterable

import torch
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit


class Siglip2VisionEmbeddings(nn.Module):
    """Patch and position embeddings for packed (unpadded) image tiles.

    Patchified pixel values are projected with a linear layer and summed
    with a learned square grid of position embeddings that is bilinearly
    resized to each tile's (height, width).
    """

    def __init__(self, config: Siglip2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        # The input is already flattened into patches, so the projection is
        # a linear layer over `num_channels * patch_size**2` features.
        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )
        self.num_patches = config.num_patches
        # Side length used to view the embedding table as a square grid.
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values_packed: torch.FloatTensor,
        spatial_shapes: torch.LongTensor,
    ) -> torch.Tensor:
        """Embed patchified pixel values in packed (unpadded) form.

        Args:
            pixel_values_packed: (1, total_tokens, patch_dim) or
                (total_tokens, patch_dim), packed in tile order.
            spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.

        Returns:
            (1, total_tokens, embed_dim) packed embeddings.

        Raises:
            ValueError: If the number of packed tokens differs from the sum
                of `height * width` over `spatial_shapes`.
        """
        assert spatial_shapes.device.type == "cpu", (
            "Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
            "variable-length packing."
        )

        # Accept both the batched (1, N, D) and the flat (N, D) layout.
        if pixel_values_packed.dim() == 3:
            assert pixel_values_packed.shape[0] == 1
            pixel_values_flat = pixel_values_packed[0]
        else:
            pixel_values_flat = pixel_values_packed

        lengths = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(dtype=torch.int64)
        lengths_list = lengths.tolist()
        total_tokens = int(sum(lengths_list))
        if total_tokens != pixel_values_flat.shape[0]:
            raise ValueError(
                "Packed pixel_values token count does not match spatial_shapes: "
                f"{pixel_values_flat.shape[0]} vs {total_tokens}."
            )

        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values_flat.to(dtype=target_dtype))

        # View the flat embedding table as a (side, side, embed_dim) grid so
        # it can be spatially resized per tile.
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size, self.position_embedding_size, -1
        )
        packed_pos_embeds = self.resize_positional_embeddings_packed(
            positional_embeddings,
            spatial_shapes,
            lengths_list=lengths_list,
        )

        embeddings = patch_embeds + packed_pos_embeds
        return embeddings.unsqueeze(0)

    @staticmethod
    def resize_positional_embeddings_packed(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        lengths_list: list[int],
    ) -> torch.Tensor:
        """Resize positional embeddings per image and return a packed tensor.

        Args:
            positional_embeddings: (height, width, embed_dim) base grid.
            spatial_shapes: (batch_size, 2) on CPU, (height, width) per image.
            lengths_list: flattened token length per image (height * width).

        Returns:
            (total_tokens, embed_dim) packed positional embeddings, concatenated
            in the same order as `lengths_list`.
        """
        assert spatial_shapes.device.type == "cpu"

        embed_dim = positional_embeddings.shape[-1]
        source_dtype = positional_embeddings.dtype

        total_tokens = int(sum(lengths_list))
        packed_pos_embeds = torch.empty(
            (total_tokens, embed_dim),
            device=positional_embeddings.device,
            dtype=source_dtype,
        )

        # (height, width, embed_dim) -> (1, embed_dim, height, width)
        pos_4d = positional_embeddings.permute(2, 0, 1).unsqueeze(0)

        # Upcast to float32 on CPU because antialias is not supported for
        # bfloat16/float16 on CPU.
        if pos_4d.device.type == "cpu":
            pos_4d = pos_4d.to(torch.float32)

        # Convert the shapes to Python ints once instead of once per tile.
        shapes = spatial_shapes.tolist()
        # Tiles frequently share a shape; interpolate each distinct
        # (height, width) only once and reuse the result.
        resized_by_shape: dict[tuple[int, int], torch.Tensor] = {}

        offset = 0
        for i, length in enumerate(lengths_list):
            if length <= 0:
                continue
            height, width = shapes[i]
            resized = resized_by_shape.get((height, width))
            if resized is None:
                resized = F.interpolate(
                    pos_4d,
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                # (1, embed_dim, h, w) -> (h * w, embed_dim), row-major.
                resized = resized.reshape(embed_dim, height * width).transpose(0, 1)
                resized = resized.to(source_dtype)
                resized_by_shape[(height, width)] = resized
            packed_pos_embeds[offset : offset + length] = resized
            offset += length

        return packed_pos_embeds


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    The fused QKV and output projections use the parallel linear layers;
    tensor parallelism is disabled for them when the ViT runs in data
    parallel mode, in which case every rank keeps all heads.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        use_data_parallel = is_vit_use_data_parallel()
        # With data parallelism each rank holds the full set of heads.
        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0
        self.num_heads_per_partition = self.num_heads // tp_size

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            scale=self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to `hidden_states`.

        Args:
            hidden_states: (batch_size, q_len, embed_dim) input.
            cu_seqlens: Cumulative sequence lengths, forwarded to the
                attention layer.
            max_seqlen: Maximum sequence length, forwarded to the attention
                layer.

        Returns:
            Output of the output projection, reshaped from the attention
            result to (batch_size, q_len, -1) before projecting.
        """
        qkv, _ = self.qkv_proj(
            hidden_states
        )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
        bsz, q_len, _ = qkv.shape

        # Split the fused projection into q/k/v and expose the head axis.
        head_shape = (bsz, q_len, self.num_heads_per_partition, self.head_dim)
        query_states, key_states, value_states = (
            part.view(head_shape) for part in qkv.chunk(3, dim=-1)
        )

        # Attention is delegated to the MMEncoderAttention layer, which
        # receives the packed-sequence boundaries via `cu_seqlens`.
        out = self.attn(
            query=query_states,
            key=key_states,
            value=value_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        out = out.reshape(bsz, q_len, -1)
        attn_output, _ = self.out_proj(out)
        return attn_output


class Siglip2MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    `fc1` is column-parallel and `fc2` is row-parallel; tensor parallelism
    is disabled for both when the ViT runs in data parallel mode.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        use_data_parallel = is_vit_use_data_parallel()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward block on `hidden_states`.

        The parallel linear layers return a tuple; only the first element
        (the projected tensor) is used.
        """
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


@support_torch_compile(
    dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
    enable_if=should_torch_compile_mm_vit,
)
class Siglip2EncoderLayer(nn.Module):
    """Pre-norm encoder block: self-attention then MLP, each with a residual."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.embed_dim = hidden_size
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
            cu_seqlens: Cumulative sequence lengths tensor.
            max_seqlen: Maximum sequence length.

        Returns:
            Tensor produced by adding each sub-block's output to its input.
        """
        # Attention sub-block: normalize, attend, add back the input.
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: normalize, project, add back the input.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class Siglip2Encoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`Siglip2EncoderLayer`] blocks
    applied one after another.

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        blocks = []
        for layer_idx in range(config.num_hidden_layers):
            blocks.append(
                Siglip2EncoderLayer(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
            )
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        """Feed `inputs_embeds` through every layer in order.

        `cu_seqlens` and `max_seqlen` are passed unchanged to each layer.
        """
        current = inputs_embeds
        for block in self.layers:
            current = block(
                current,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
        return current


class Siglip2VisionTransformer(nn.Module):
    """Vision tower: packed embeddings -> encoder stack -> final LayerNorm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        embed_dim = config.hidden_size
        self.config = config
        self.embeddings = Siglip2VisionEmbeddings(config)
        # Keep the import local to avoid circular dependencies during model init.
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("Siglip2Encoder", is_encoder=True):
            self.encoder = Siglip2Encoder(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.encoder",
            )
        # Sanity check that the encoder did not end up with more layers
        # than the config describes.
        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def get_input_embeddings(self):
        """Return the embedding module used for the pixel inputs."""
        return self.embeddings

    def forward(
        self,
        pixel_values_packed: torch.FloatTensor,
        spatial_shapes: torch.LongTensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        r"""Encode packed image patches.

        Args:
            pixel_values_packed: Packed patchified pixel values, passed to
                the embedding module.
            spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
                Tensor containing the spatial dimensions (height, width)
                of the input images.
            cu_seqlens: Cumulative sequence lengths, passed to the encoder.
            max_seqlen: Maximum sequence length, passed to the encoder.

        Returns:
            Encoder output after the final layer normalization.
        """
        hidden_states = self.embeddings(pixel_values_packed, spatial_shapes)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return self.post_layernorm(encoder_outputs)


class Siglip2Model(torch.nn.Module):
    """Wrapper that holds the vision transformer under `vision_model`."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
        )

    def forward(
        self,
        pixel_values_packed: torch.FloatTensor,
        spatial_shapes: torch.LongTensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int | torch.Tensor,
    ) -> torch.Tensor:
        """Forward all arguments, by keyword, to `self.vision_model`."""
        return self.vision_model(
            pixel_values_packed=pixel_values_packed,
            spatial_shapes=spatial_shapes,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Weights whose name contains `q_proj`, `k_proj` or `v_proj` are
        routed to the matching fused `qkv_proj` parameter with shard id
        "q", "k" or "v". Every other weight is loaded through the
        parameter's own `weight_loader` when it has one, otherwise through
        `default_weight_loader`.

        Args:
            weights: Iterable of (checkpoint name, tensor) pairs.

        Returns:
            Names of the parameters that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rewrite the checkpoint name to the fused parameter name.
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No q/k/v shard matched: load the tensor as a whole.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params