bagel.py 20.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.

BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""

from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
33
    BaseDummyInputsBuilder,
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.utils.tensor_schema import TensorSchema

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
51
    StageMissingLayer,
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)

logger = init_logger(__name__)


class BagelImagePixelInputs(TensorSchema):
    """
    Schema for raw pixel-value image inputs of the BAGEL model.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Discriminator tag; annotated as the literal string "pixel_values".
    type: Literal["pixel_values"]
    # Image tensor. The shape is documented in the trailing comment only; the
    # annotation itself is a plain torch.Tensor.
    pixel_values: torch.Tensor  # Shape: (bn, 3, h, w)


# Alias for the accepted image input schema; pixel values are the only
# variant defined in this module.
BagelImageInputs: TypeAlias = BagelImagePixelInputs


class BagelVisionMLP(nn.Module):
    """Two-layer MLP connector applied to vision features.

    The layers run as ``fc1 -> activation -> fc2``. Both linear layers are
    created with a bias term and share the same quantization config.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        act_layer: str = "gelu_pytorch_tanh",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two projections and the activation between them.

        Args:
            in_features: Width of the incoming features.
            hidden_features: Width of the intermediate representation.
            out_features: Width of the returned features.
            act_layer: Activation name resolved through ``get_act_fn``.
            quant_config: Optional quantization config passed to both layers.
            prefix: Module-name prefix; ``.fc1`` / ``.fc2`` are appended.
        """
        super().__init__()
        # Keyword arguments common to both projections.
        shared_kwargs = dict(bias=True, quant_config=quant_config)

        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.fc1",
            **shared_kwargs,
        )
        self.act = get_act_fn(act_layer)
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            prefix=f"{prefix}.fc2",
            **shared_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` through fc1, the activation, and fc2."""
        # Each linear layer returns a 2-tuple; only the first element is used.
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected


class PositionEmbedding(nn.Module):
    """Fixed 2D sin-cos position table for vision tokens.

    The table has ``max_num_patch_per_side ** 2`` rows of width
    ``hidden_size``. It is registered as a non-persistent buffer, so it is
    recomputed at construction time and is not part of the state dict.
    """

    def __init__(self, max_num_patch_per_side: int, hidden_size: int):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        # The table is not learnable: it is computed once with NumPy and
        # stored as a float32 buffer.
        table = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side)
        self.register_buffer(
            "pos_embed",
            torch.from_numpy(table).float(),
            persistent=False,
        )

    @staticmethod
    def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int):
        """Build the ``(grid_size ** 2, embed_dim)`` sin-cos table."""
        import numpy as np

        axis = np.arange(grid_size, dtype=np.float32)
        # meshgrid returns the column (w) coordinates first, then the rows (h).
        mesh = np.stack(np.meshgrid(axis, axis), axis=0)
        mesh = mesh.reshape([2, 1, grid_size, grid_size])
        return PositionEmbedding._get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

    @staticmethod
    def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
        """Encode ``grid[0]`` and ``grid[1]`` with half of ``embed_dim`` each."""
        import numpy as np

        assert embed_dim % 2 == 0
        half_dim = embed_dim // 2
        # First half of the channels comes from grid[0], second from grid[1].
        halves = [
            PositionEmbedding._get_1d_sincos_pos_embed_from_grid(half_dim, plane)
            for plane in (grid[0], grid[1])
        ]
        return np.concatenate(halves, axis=1)

    @staticmethod
    def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
        """Return ``[sin | cos]`` features of width ``embed_dim`` per position."""
        import numpy as np

        assert embed_dim % 2 == 0
        # Frequencies 1 / 10000**(k / (embed_dim / 2)), computed in float64.
        exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
        omega = 1.0 / 10000**exponents

        # Outer product of flattened positions and frequencies.
        angles = pos.reshape(-1)[:, None] * omega[None, :]
        return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        Look up rows of the position table.

        Args:
            position_ids: Flattened position IDs, shape (N,) where each ID
                         corresponds to a position in the flattened grid
        Returns:
            Position embeddings of shape (N, hidden_size)
        """
        table = self.pos_embed
        # Index tensors must live on the same device as the table.
        return table[position_ids.to(table.device)]


class BagelProcessingInfo(BaseProcessingInfo):
    """Model-specific processing metadata for BAGEL."""

    def get_hf_processor(self, **kwargs: object) -> BagelProcessor:
        """Build a ``BagelProcessor`` from the image processor and tokenizer.

        Extra keyword arguments are forwarded to ``BagelProcessor``.
        """
        from vllm.transformers_utils.processor import cached_get_image_processor

        model_config = self.ctx.model_config
        return BagelProcessor(
            image_processor=cached_get_image_processor(
                model_config.model,
                revision=model_config.revision,
                trust_remote_code=model_config.trust_remote_code,
            ),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare images as the only supported modality."""
        # NOTE(review): None presumably means "no fixed per-prompt limit";
        # confirm against BaseProcessingInfo.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token budget.

        The budget is ``vit_max_num_patch_per_side ** 2``; ``seq_len`` and
        ``mm_counts`` are not consulted.
        """
        patches_per_side = self.get_hf_config().vit_max_num_patch_per_side
        return {"image": patches_per_side**2}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Count the whole patches that fit in an image of the given size."""
        patch_size = self.get_hf_config().vit_config.patch_size
        # Floor division: partial patches at the border are not counted.
        rows, cols = image_height // patch_size, image_width // patch_size
        return rows * cols


class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
    """Produce dummy prompts and images for BAGEL profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|image_pad|>`` placeholder per requested image."""
        return "<|image_pad|>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return square dummy images sized by ``vit_config.image_size``.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items requested per modality.
            mm_options: Optional per-modality overrides; only the ``"image"``
                entry is read.
        """
        image_count = mm_counts.get("image", 0)
        side = self.info.get_hf_config().vit_config.image_size

        # A missing or empty options mapping means no overrides.
        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]):
    """Prompt and field processing for BAGEL multimodal inputs."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always report that the HF processor does not expand placeholders."""
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptReplacement]:
        """Expand each image placeholder token to the per-image token count.

        Raises:
            ValueError: If ``<|image_pad|>`` is missing from the tokenizer
                vocabulary.
        """
        hf_config = self.info.get_hf_config()

        # The placeholder is matched by token ID, so resolve it up front.
        vocab = self.info.get_tokenizer().get_vocab()
        image_token_id = vocab.get("<|image_pad|>")
        if image_token_id is None:
            raise ValueError(
                "Image token '<|image_pad|>' not found in tokenizer vocabulary"
            )

        def expand_placeholder(item_idx: int):
            # Every image gets the same count, independent of item_idx:
            # the square of the maximum patches per side.
            patches_per_side = hf_config.vit_max_num_patch_per_side
            return [image_token_id] * (patches_per_side**2)

        image_update = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=expand_placeholder,
        )
        return [image_update]

    def _get_mm_fields_config(
        self,
        hf_inputs: Any,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe ``pixel_values`` as a field batched over images."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}


# Register the processor, processing-info and dummy-input classes defined
# above for this model class.
@MULTIMODAL_REGISTRY.register_processor(
    BagelMultiModalProcessor,
    info=BagelProcessingInfo,
    dummy_inputs=BagelDummyInputsBuilder,
)
class BagelForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
    """
    BAGEL: A unified multimodal model for image understanding and generation.

    For vLLM, we focus on the image understanding (vision-to-text) capabilities.
    The image generation part is not supported in vLLM; generation-related
    checkpoint weights are skipped in ``load_weights``.
    """

    # Weight mapping from HF to vLLM.
    # Every entry currently maps a prefix to itself, so weight names under
    # these prefixes are left unchanged by the mapper.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.": "language_model.",
            "vit_model.": "vit_model.",
            "connector.": "connector.",
            "vit_pos_embed.": "vit_pos_embed.",
        }
    )

350
351
352
353
354
355
356
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an item of ``modality``.

        Args:
            modality: Modality name; only names starting with ``"image"``
                are accepted.
            i: Item index; not used, every image gets the same placeholder.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return "<|image_pad|>"

357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model and, when enabled, the vision stack.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config`` must
                be a config class named ``BagelConfig``.
            prefix: Module-name prefix used when naming sub-modules.

        Raises:
            ValueError: If the HF config class is not named ``BagelConfig``.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        # The check is by class name rather than isinstance: with
        # trust_remote_code=True the config class comes from
        # transformers_modules.
        config_cls_name = type(config).__name__
        if config_cls_name != "BagelConfig":
            raise ValueError(
                f"Expected BagelConfig, got {config_cls_name}. "
                "Make sure the model config is properly loaded."
            )

        self.config = config
        self.multimodal_config = multimodal_config

        # Language model (Qwen2), configured from BagelConfig.llm_config.
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.llm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        if not config.visual_und:
            # Visual understanding disabled: register placeholders under the
            # same attribute names, in the same order as the real modules.
            for attr_name in ("vit_model", "connector", "vit_pos_embed"):
                setattr(self, attr_name, StageMissingLayer("image_tower"))
        else:
            # Patch vit_config in place: the checkpoint has 26 layers (0-25)
            # while the config says 27, and the head is not in the checkpoint.
            vit_config = config.vit_config
            if vit_config.num_hidden_layers == 27:
                logger.warning(
                    "Overriding vit_config.num_hidden_layers from 27 to 26 "
                    "to match the Bagel model checkpoint."
                )
                vit_config.num_hidden_layers = 26
            if not hasattr(vit_config, "vision_use_head"):
                logger.warning(
                    "Setting vit_config.vision_use_head to False as it is not "
                    "present in the Bagel model checkpoint."
                )
                vit_config.vision_use_head = False

            with self._mark_tower_model(vllm_config, "image"):
                # Vision encoder (SigLIP).
                self.vit_model = SiglipVisionModel(
                    config=vit_config,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "vit_model"),
                )

                # Connector MLP: vision hidden size -> LLM hidden size.
                llm_hidden_size = config.llm_config.hidden_size
                self.connector = BagelVisionMLP(
                    in_features=config.vit_config.hidden_size,
                    hidden_features=llm_hidden_size,
                    out_features=llm_hidden_size,
                    act_layer=config.connector_act,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "connector"),
                )

                # Position table for vision tokens, at LLM width.
                self.vit_pos_embed = PositionEmbedding(
                    max_num_patch_per_side=config.vit_max_num_patch_per_side,
                    hidden_size=llm_hidden_size,
                )

        # Re-export the language model's factory under this module.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> BagelImageInputs | None:
        """Wrap ``pixel_values`` from ``kwargs`` in the image input schema.

        Returns:
            ``None`` when ``pixel_values`` is absent or ``None``; otherwise a
            ``BagelImagePixelInputs`` tagged ``"pixel_values"``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None:
            return BagelImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
            )

        return None

    def _process_image_input(
        self, image_input: BagelImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images and add 2D position embeddings.

        Pixel values go through the vision encoder and the connector; the
        sin-cos position embeddings are then added to every image.

        Args:
            image_input: Parsed image inputs; ``pixel_values`` is either
                ``(batch_size * num_images, 3, H, W)`` or
                ``(batch_size, num_images, 3, H, W)``.

        Returns:
            One embedding tensor per image (the first dimension of the
            encoded batch is split into a tuple).

        Raises:
            ValueError: If the number of patches produced by the encoder does
                not match the grid derived from ``vit_config``.
        """
        pixel_values = image_input["pixel_values"]

        # Handle potential extra batch dimension by merging the leading
        # (batch_size, num_images) dimensions into one.
        if pixel_values.ndim == 5:
            pixel_values = pixel_values.flatten(0, 1)

        # Vision encoder followed by the connector.
        vision_embeds = self.connector(self.vit_model(pixel_values))

        # The patch grid is derived from the configured image size, i.e. the
        # inputs are assumed to be image_size x image_size.
        vit_config = self.config.vit_config
        num_patches_per_side = vit_config.image_size // vit_config.patch_size

        # Row-major IDs into the (max_side x max_side) position table.
        coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
        position_ids = (
            coords[:, None] * self.config.vit_max_num_patch_per_side + coords
        ).flatten()

        # Every image uses the same grid, so look the embeddings up once and
        # broadcast over the batch instead of gathering a copy per image.
        pos_embeds = self.vit_pos_embed(position_ids)
        if pos_embeds.shape[0] != vision_embeds.shape[1]:
            raise ValueError(
                f"Vision encoder produced {vision_embeds.shape[1]} patches per "
                f"image but the configured grid has {pos_embeds.shape[0]}."
            )

        # The position table is stored in float32. Cast it to the embedding
        # dtype so that type promotion does not silently upcast the image
        # embeddings when the model runs in half precision.
        pos_embeds = pos_embeds.to(
            device=vision_embeds.device, dtype=vision_embeds.dtype
        )
        vision_embeds = vision_embeds + pos_embeds.unsqueeze(0)

        # Split by image.
        return tuple(vision_embeds)

500
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given multimodal keyword inputs.

        An empty list is returned when no image input is present.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for BAGEL.

        The call is delegated to ``self.language_model.model``; extra keyword
        arguments are accepted but not used.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a batch.
            positions: Flattened (concatenated) position ids corresponding to a batch.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        effective_embeds = inputs_embeds if intermediate_tensors is None else None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=effective_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights from checkpoint.

        Generation-only weights are dropped and linear-layout patch-embedding
        weights are converted before loading. Weights are streamed to the
        loader one at a time rather than collected in a list first, so the
        whole checkpoint is not held in memory at once.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        # Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
        loader = AutoWeightsLoader(self, skip_prefixes=["vit_pos_embed.pos_embed"])
        return loader.load_weights(
            self._iter_understanding_weights(weights),
            mapper=self.hf_to_vllm_mapper,
        )

    def _iter_understanding_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily yield the checkpoint weights needed for understanding."""
        # Skip generation-related weights since we only support text2text and
        # image2text. Filter out all image generation components:
        # - 'moe_gen': MoE generation weights
        # - 'latent_pos_embed': Latent position embeddings for VAE
        # - 'llm2vae', 'vae2llm': LLM-VAE projections
        # - 'time_embedder': Timestep embeddings for diffusion
        generation_keywords = (
            "moe_gen",
            "latent_pos_embed",
            "llm2vae",
            "vae2llm",
            "time_embedder",
        )
        # VAE encoder/decoder, matched as name prefixes so the vision encoder
        # (which lives under another prefix) is not caught.
        vae_prefixes = ("decoder.", "encoder.")

        for name, tensor in weights:
            if any(keyword in name for keyword in generation_keywords):
                continue
            if name.startswith(vae_prefixes):
                continue

            if "patch_embedding.weight" in name and tensor.ndim == 2:
                # A 2-D patch embedding is a flattened
                # (out, patch, patch, channels) weight; restore the
                # (out, channels, patch, patch) layout when the sizes match.
                out_channels, in_features = tensor.shape
                patch_size = self.config.vit_config.patch_size
                in_channels = self.config.vit_config.num_channels
                if in_features == in_channels * patch_size * patch_size:
                    tensor = (
                        tensor.reshape(
                            out_channels, patch_size, patch_size, in_channels
                        )
                        .permute(0, 3, 1, 2)
                        .contiguous()
                    )

            yield name, tensor