pipeline_bagel.py 25.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
BagelPipeline implementation for vLLM-Omni.
"""

from __future__ import annotations

import json
import os
from collections.abc import Iterable
from dataclasses import dataclass
from math import isqrt

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import AutoTokenizer, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
from vllm.logger import init_logger
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.transformers_utils.configs.bagel import BagelConfig

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific

from .autoencoder import AutoEncoder, AutoEncoderParams
from .bagel_transformer import Bagel, NaiveCache, Qwen2MoTConfig, Qwen2MoTForCausalLM

logger = init_logger(__name__)


@dataclass
class BagelGenParams:
    """Sampling knobs that ``BagelPipeline.forward`` passes to ``Bagel.generate_image``."""

    # Forwarded as ``num_timesteps`` (the pipeline sets it from ``num_inference_steps``).
    num_timesteps: int = 50
    # Forwarded as ``timestep_shift``.
    timestep_shift: float = 1.0


def add_special_tokens(tokenizer):
    """Ensure Bagel's chat / vision marker tokens exist on ``tokenizer``.

    Any of ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and
    ``<|vision_end|>`` not already listed in ``tokenizer.special_tokens_map``
    (string values or list values) is passed to ``tokenizer.add_tokens``.
    The tokenizer is modified in place and also returned.

    Returns:
        A ``(tokenizer, new_token_ids, num_new_tokens)`` tuple where
        ``new_token_ids`` maps ``bos_token_id``/``eos_token_id``/
        ``start_of_image``/``end_of_image`` to the ids of the four markers and
        ``num_new_tokens`` is the value returned by ``add_tokens``.
    """
    known = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            known.append(entry)
        elif isinstance(entry, list):
            known.extend(entry)

    markers = ("<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>")
    added = tokenizer.add_tokens([marker for marker in markers if marker not in known])

    id_names = ("bos_token_id", "eos_token_id", "start_of_image", "end_of_image")
    marker_ids = {name: tokenizer.convert_tokens_to_ids(marker) for name, marker in zip(id_names, markers)}

    return tokenizer, marker_ids, added


def get_bagel_post_process_func(od_config: OmniDiffusionConfig):
    """Return the output post-processing hook for Bagel.

    ``BagelPipeline.forward`` already produces a ``PIL.Image.Image``, so the
    hook is an identity function; ``od_config`` is accepted for interface
    compatibility and not used.
    """

    def post_process_func(x):
        # Nothing to convert: hand the pipeline output back unchanged.
        result = x
        return result

    return post_process_func


@dataclass
class _VaeCfg:
    """VAE settings handed to ``BagelConfig`` (filled from ``config.json``'s ``vae_config``)."""

    # Read from vae_config["z_channels"] when present.
    z_channels: int = 16
    # Read from vae_config["downsample"] when present.
    downsample: int = 8


@dataclass
class _VitCfg:
    """ViT settings handed to ``BagelConfig`` (later synced from the SigLIP model config)."""

    # Read from vit_config["patch_size"] / the SigLIP config's patch_size.
    patch_size: int = 14
    # Read from vit_config["hidden_size"] / the SigLIP config's hidden_size.
    hidden_size: int = 1152


def default_ae_params() -> AutoEncoderParams:
    """Build the hard-coded ``AutoEncoderParams`` used for Bagel's VAE.

    These values are fixed here rather than read from the checkpoint; a fresh
    ``ch_mult`` list is created on every call.
    """
    settings = {
        "resolution": 256,
        "in_channels": 3,
        "out_ch": 3,
        "downsample": 8,
        "ch": 128,
        "ch_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "z_channels": 16,
        "scale_factor": 0.3611,
        "shift_factor": 0.1159,
    }
    return AutoEncoderParams(**settings)


class SiglipNaViTWrapper(nn.Module):
    """Run a SigLIP vision transformer over packed, variable-length patch sequences.

    Patches from several images are concatenated along one axis; ``cu_seqlens``
    gives the boundaries, and an additive attention mask restricts attention to
    positions inside the same ``[cu_seqlens[i], cu_seqlens[i + 1])`` span.
    """

    def __init__(self, vision_model):
        super().__init__()
        # Accept either a SiglipVisionModel (which exposes `.vision_model`) or the
        # inner vision transformer directly.
        self.vision_model = vision_model.vision_model if hasattr(vision_model, "vision_model") else vision_model

        # The conv patch embedding is reused as a linear projection over
        # pre-flattened patches, so keep handles to its weight and bias.
        patch_conv = self.vision_model.embeddings.patch_embedding
        self.patch_embed_weight = patch_conv.weight
        self.patch_embed_bias = patch_conv.bias

    def forward(self, packed_pixel_values, packed_flattened_position_ids, cu_seqlens, max_seqlen):
        """Encode packed patches and return the encoder's last hidden state.

        Args:
            packed_pixel_values: flattened patches; the last dim must equal the
                flattened conv kernel size (assumed ``(num_patches, C*P*P)`` — TODO confirm).
            packed_flattened_position_ids: ids for the position-embedding table.
            cu_seqlens: cumulative sequence boundaries; converted to a list.
            max_seqlen: accepted for interface compatibility; not used here.

        Returns:
            ``last_hidden_state`` with the leading batch dim of 1 squeezed away.
        """
        out_features = self.patch_embed_weight.shape[0]
        flat_weight = self.patch_embed_weight.view(out_features, -1)
        embeds = F.linear(packed_pixel_values, flat_weight, self.patch_embed_bias)
        embeds = embeds + self.vision_model.embeddings.position_embedding(packed_flattened_position_ids)

        total_len = embeds.shape[0]
        # Start fully masked (dtype minimum), then open each block-diagonal span.
        block_mask = torch.full(
            (1, 1, total_len, total_len),
            torch.finfo(embeds.dtype).min,
            device=embeds.device,
            dtype=embeds.dtype,
        )
        bounds = cu_seqlens.tolist()
        for lo, hi in zip(bounds[:-1], bounds[1:]):
            block_mask[..., lo:hi, lo:hi] = 0.0

        encoded = self.vision_model.encoder(inputs_embeds=embeds.unsqueeze(0), attention_mask=block_mask)
        return encoded.last_hidden_state.squeeze(0)


class BagelPipeline(nn.Module):
    """Bagel generation pipeline (MoT) packaged for vllm-omni diffusion engine.

    This pipeline is self-contained and uses the ported Bagel core files.

    ``__init__`` reads ``config.json``, ``llm_config.json`` and ``vit_config.json``
    from the checkpoint directory. If ``od_config.model`` is not an existing local
    path, it downloads the repo first. It then assembles the tokenizer, the SigLIP
    ViT, the Qwen2-MoT language model and the VAE into a ``Bagel`` core.
    Checkpoint tensors are loaded separately through :meth:`load_weights`.
    """

    def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
        super().__init__()
        self.od_config = od_config
        self.device = get_local_device()

        model = od_config.model
        if os.path.exists(model):
            # `od_config.model` already points at a local checkpoint directory.
            model_path = model
        else:
            # Download everything required (ema.safetensors, ae.safetensors, tokenizer files, configs).
            model_path = download_weights_from_hf_specific(model, od_config.revision, ["*"])

        # Load Bagel top-level config (VAE / ViT / latent settings).
        cfg_path = os.path.join(model_path, "config.json")
        with open(cfg_path, encoding="utf-8") as f:
            bagel_cfg = json.load(f)

        vae_cfg_dict = bagel_cfg.get("vae_config") or {}
        vae_cfg = _VaeCfg(
            z_channels=int(vae_cfg_dict.get("z_channels", 16)),
            downsample=int(vae_cfg_dict.get("downsample", 8)),
        )

        # LLM config: Bagel MoT requires explicitly setting layer_module.
        llm_cfg_path = os.path.join(model_path, "llm_config.json")
        llm_config = Qwen2MoTConfig.from_json_file(llm_cfg_path)
        llm_config.qk_norm = True
        llm_config.tie_word_embeddings = False
        # Allow overriding from vllm-omni config if user wants MoE/vanilla.
        llm_config.layer_module = od_config.override_transformer_cls_name or "Qwen2MoTDecoderLayer"

        # Tokenizer and special tokens.
        # Bagel uses a Qwen2 tokenizer variant; prefer trust_remote_code to get the
        # correct tokenizer implementation from the checkpoint repo when available.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            local_files_only=True,
            trust_remote_code=True,
        )

        # ViT sizes: start from the optional top-level `vit_config`; they are
        # replaced below by the values of the SigLIP model built from vit_config.json.
        vit_cfg_dict = bagel_cfg.get("vit_config") or {}
        vit_cfg = _VitCfg(
            patch_size=int(vit_cfg_dict.get("patch_size", 14)),
            hidden_size=int(vit_cfg_dict.get("hidden_size", 1152)),
        )
        vit_config_path = os.path.join(model_path, "vit_config.json")
        vit_conf = SiglipVisionConfig.from_json_file(vit_config_path)
        self.vit_model = SiglipVisionModel(vit_conf)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_path, local_files_only=True)

        if self.vit_model:
            # Wrap for packed (NaViT-style) inputs and sync sizes from the model config.
            self.vit_model = SiglipNaViTWrapper(self.vit_model)
            vit_cfg.hidden_size = self.vit_model.vision_model.config.hidden_size
            vit_cfg.patch_size = self.vit_model.vision_model.config.patch_size

        self.tokenizer, self.new_token_ids, _ = add_special_tokens(self.tokenizer)

        # Make sure every tokenizer id (including the added special tokens) has
        # an embedding row in the language model.
        tok_len = len(self.tokenizer)
        required_max_id = max(int(v) for v in self.new_token_ids.values())
        llm_config.vocab_size = max(
            int(getattr(llm_config, "vocab_size", tok_len)),
            int(tok_len),
            int(required_max_id + 1),
        )

        self.language_model = Qwen2MoTForCausalLM(llm_config)
        ae_params: AutoEncoderParams = default_ae_params()
        self.vae = AutoEncoder(ae_params)

        self.bagel = Bagel(
            language_model=self.language_model,
            vit_model=self.vit_model,
            config=BagelConfig(
                llm_config=llm_config,
                vae_config=vae_cfg,
                vit_config=vit_cfg,
                vit_max_num_patch_per_side=int(bagel_cfg.get("vit_max_num_patch_per_side", 70)),
                connector_act=str(bagel_cfg.get("connector_act", "gelu_pytorch_tanh")),
                interpolate_pos=bool(bagel_cfg.get("interpolate_pos", False)),
                latent_patch_size=int(bagel_cfg.get("latent_patch_size", 2)),
                max_latent_size=int(bagel_cfg.get("max_latent_size", 32)),
                timestep_shift=float(bagel_cfg.get("timestep_shift", 1.0)),
            ),
        )

        # Let vLLM loader download and stream all *.safetensors under model root.
        self.weights_sources = [
            DiffusersPipelineLoader.ComponentSource(
                model_or_path=od_config.model,
                subfolder=None,
                revision=od_config.revision,
                prefix="",
                fall_back_to_pt=False,
            )
        ]

        self.to(self.device)

    def _autocast(self):
        """Return an autocast context in ``od_config.dtype`` (disabled on CPU)."""
        return torch.autocast(
            device_type=self.device.type,
            enabled=self.device.type != "cpu",
            dtype=self.od_config.dtype,
        )

    def _to_device(self, inputs: dict) -> dict:
        """Move every tensor value of ``inputs`` to ``self.device`` in place and return the dict."""
        for key, value in inputs.items():
            if torch.is_tensor(value):
                inputs[key] = value.to(self.device)
        return inputs

    @staticmethod
    def _decode_image_from_latent(
        bagel: Bagel, vae: AutoEncoder, latent: torch.Tensor, image_shape: tuple[int, int]
    ) -> Image.Image:
        """Unpatchify a packed latent sequence and decode it to a PIL image with ``vae``.

        ``latent`` is reshaped to ``(1, h, w, p, p, c)`` with ``h, w`` the image
        size divided by ``bagel.latent_downsample`` and ``p``/``c`` the latent
        patch size / channel count, then rearranged to ``(1, c, h*p, w*p)``.
        The decoded image is mapped from [-1, 1] to [0, 255] and cast to uint8.
        """
        H, W = image_shape
        h, w = H // bagel.latent_downsample, W // bagel.latent_downsample
        p = bagel.latent_patch_size
        c = bagel.latent_channel
        latent = latent.reshape(1, h, w, p, p, c)
        latent = torch.einsum("nhwpqc->nchpwq", latent)
        latent = latent.reshape(1, c, h * p, w * p)

        # Cast to VAE dtype (e.g. bfloat16) as latents might remain float32 from generation loop
        vae_dtype = next(vae.parameters()).dtype
        latent = latent.to(vae_dtype)

        image = vae.decode(latent)
        image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
        return Image.fromarray(image.to(torch.uint8).cpu().numpy())

    def _resolve_image_shape(self, req: OmniDiffusionRequest) -> tuple[int, int]:
        """Return the requested ``(height, width)``; a missing side defaults to the checkpoint maximum.

        Raises:
            ValueError: if a side is not positive or exceeds
                ``max_latent_size * latent_downsample``.
        """
        max_hw = int(self.bagel.max_latent_size * self.bagel.latent_downsample)
        params = req.sampling_params
        height = max_hw if params.height is None else int(params.height)
        width = max_hw if params.width is None else int(params.width)
        if height <= 0 or width <= 0:
            raise ValueError(f"Requested resolution {height}x{width} must have positive height and width.")
        if height > max_hw or width > max_hw:
            raise ValueError(
                f"Requested resolution {height}x{width} exceeds Bagel checkpoint limit "
                f"{max_hw}x{max_hw} (max_latent_size={self.bagel.max_latent_size}, "
                f"latent_downsample={self.bagel.latent_downsample})."
            )
        return height, width

    @staticmethod
    def _collect_input_images(first_prompt):
        """Extract conditioning images from a dict prompt's ``multi_modal_data["image"]``.

        A single image is wrapped in a list and string entries are opened with
        ``PIL.Image.open``. Returns ``None`` for string prompts, or the (possibly
        empty/``None``) value found in the prompt.
        """
        if isinstance(first_prompt, str):
            return None
        images = (first_prompt.get("multi_modal_data") or {}).get("image")
        if images and not isinstance(images, list):
            images = [images]
        if images:
            images = [Image.open(image) if isinstance(image, str) else image for image in images]
        return images

    def _prefill_images(self, gen_context: dict, images: list) -> None:
        """Prefill ``gen_context``'s KV cache with the VAE tokens, then the ViT tokens, of ``images``."""

        def vit_transforms(img):
            # SigLIP processor returns dict with pixel_values; we want the tensor
            return self.image_processor(images=img, return_tensors="pt").pixel_values[0]

        def vae_transforms(img):
            if img.mode != "RGB":
                img = img.convert("RGB")
            # Convert to [-1, 1] tensor (H, W, C) -> (C, H, W)
            arr = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0
            return arr.permute(2, 0, 1)

        # 1. VAE tokens.
        gen_input_vae, newlens_vae, new_rope_vae = self.bagel.prepare_vae_images(
            curr_kvlens=gen_context["kv_lens"],
            curr_rope=gen_context["ropes"],
            images=images,
            transforms=vae_transforms,
            new_token_ids=self.new_token_ids,
        )
        self._to_device(gen_input_vae)
        with self._autocast():
            gen_context["past_key_values"] = self.bagel.forward_cache_update_vae(
                self.vae, gen_context["past_key_values"], **gen_input_vae
            )
        gen_context["kv_lens"] = newlens_vae
        gen_context["ropes"] = new_rope_vae

        # 2. ViT tokens.
        gen_input_img, newlens_img, new_rope_img = self.bagel.prepare_vit_images(
            curr_kvlens=gen_context["kv_lens"],
            curr_rope=gen_context["ropes"],
            images=images,
            transforms=vit_transforms,
            new_token_ids=self.new_token_ids,
        )
        self._to_device(gen_input_img)
        with self._autocast():
            gen_context["past_key_values"] = self.bagel.forward_cache_update_vit(
                gen_context["past_key_values"], **gen_input_img
            )
        gen_context["kv_lens"] = newlens_img
        gen_context["ropes"] = new_rope_img

    def _prefill_text(self, gen_context: dict, prompt: str) -> None:
        """Prefill ``gen_context``'s KV cache with the tokenized ``prompt``.

        Raises:
            ValueError: if a token id has no row in the language model's embedding table.
        """
        generation_input, newlens, new_rope = self.bagel.prepare_prompts(
            curr_kvlens=gen_context["kv_lens"],
            curr_rope=gen_context["ropes"],
            prompts=[prompt],
            tokenizer=self.tokenizer,
            new_token_ids=self.new_token_ids,
        )
        # Fail fast with a clear error instead of CUDA gather OOB.
        max_tid = int(generation_input["packed_text_ids"].max().item())
        emb_n = int(self.language_model.model.embed_tokens.weight.shape[0])
        if max_tid >= emb_n:
            raise ValueError(
                "Tokenizer/model vocab mismatch: max token id "
                f"{max_tid} >= embed_tokens size {emb_n}. "
                "This usually means you're not using the tokenizer shipped with the Bagel checkpoint, "
                "or llm_config.vocab_size is smaller than the tokenizer vocab."
            )
        self._to_device(generation_input)
        with self._autocast():
            gen_context["past_key_values"] = self.bagel.forward_cache_update_text(
                gen_context["past_key_values"], **generation_input
            )
        gen_context["kv_lens"] = newlens
        gen_context["ropes"] = new_rope

    @torch.inference_mode()
    def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
        """Generate one image for the first prompt of ``req``.

        The KV cache is either taken from ``req.sampling_params.past_key_values``
        (injected remotely) or built here by prefilling optional conditioning
        images and then the text prompt. ``Bagel.generate_image`` produces the
        latents, and the first sample is decoded with the VAE.

        Returns:
            ``DiffusionOutput`` whose ``output`` is a ``PIL.Image.Image``.

        Raises:
            ValueError: for an out-of-range resolution, a tokenizer/embedding
                vocab mismatch, or invalid position ids.
        """
        if len(req.prompts) > 1:
            # One message string: a second positional string would be treated as a
            # %-format argument and make logging fail to render the warning.
            logger.warning(
                "This model only supports a single prompt, not a batched request. "
                "Taking only the first prompt for now."
            )
        # TODO: In online mode, sometimes it receives [{"prompts": None}, {...}], so cannot use .get("...", "")
        # TODO: May be some data formatting operations on the API side. Hack for now.
        first_prompt = req.prompts[0]
        prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")

        image_shape = self._resolve_image_shape(req)

        # Map request params to Bagel gen params (defaults follow Bagel inferencer)
        gen_params = BagelGenParams(
            num_timesteps=int(req.sampling_params.num_inference_steps or 50),
            timestep_shift=3.0,
        )

        gen_context = {
            "kv_lens": [0],
            "ropes": [0],
            "past_key_values": NaiveCache(self.bagel.config.llm_config.num_hidden_layers),
        }

        # [Omni] Check for injected KV Cache from remote transfer
        injected_kv = req.sampling_params.past_key_values
        if injected_kv is not None:
            logger.info("Using injected KV Cache (direct)")
            gen_context["past_key_values"] = injected_kv
            # kv_lens / ropes are set to dim 0 of layer 0's key tensor
            # (assumes a (seq, ...) layout — TODO confirm against the producer).
            seq_len = injected_kv.key_cache[0].shape[0]
            gen_context["kv_lens"] = [seq_len]
            gen_context["ropes"] = [seq_len]
        else:
            images = self._collect_input_images(first_prompt)
            if images and self.image_processor and self.vae:
                self._prefill_images(gen_context, images)
            self._prefill_text(gen_context, prompt)

        seed = req.sampling_params.seed
        if seed is not None:
            torch.manual_seed(seed)
            if self.device.type == "cuda":
                torch.cuda.manual_seed(seed)

        # Prepare latent query and run flow
        generation_input = self.bagel.prepare_vae_latent(
            curr_kvlens=gen_context["kv_lens"],
            curr_rope=gen_context["ropes"],
            image_sizes=[image_shape],
            new_token_ids=self.new_token_ids,
        )
        # Fail fast for special tokens used by the image path as well.
        max_tid_img = int(generation_input["packed_text_ids"].max().item())
        emb_n = int(self.language_model.model.embed_tokens.weight.shape[0])
        if max_tid_img >= emb_n:
            raise ValueError(
                "Tokenizer/model vocab mismatch (image path): max token id "
                f"{max_tid_img} >= embed_tokens size {emb_n}. "
                "This indicates the tokenizer token IDs do not match the checkpoint embeddings."
            )
        # Position ids must be non-negative; negative ids can trigger CUDA gather OOB inside RoPE.
        min_pid = int(generation_input["packed_position_ids"].min().item())
        if min_pid < 0:
            raise ValueError(f"Invalid packed_position_ids: min={min_pid} (must be >= 0)")
        # Latent position embedding bounds check: ids must be < max_latent_size^2.
        max_lat_pid = int(generation_input["packed_vae_position_ids"].max().item())
        max_lat_pid_allowed = int(self.bagel.max_latent_size * self.bagel.max_latent_size) - 1
        if max_lat_pid > max_lat_pid_allowed:
            raise ValueError(
                "Invalid packed_vae_position_ids (latent position embedding OOB): "
                f"max={max_lat_pid} > allowed_max={max_lat_pid_allowed}. "
                f"Requested image_shape={image_shape}, max_latent_size={self.bagel.max_latent_size}."
            )
        self._to_device(generation_input)

        with self._autocast():
            latents = self.bagel.generate_image(
                past_key_values=gen_context["past_key_values"],
                num_timesteps=gen_params.num_timesteps,
                timestep_shift=gen_params.timestep_shift,
                **generation_input,
            )

        # Decode first sample
        img = self._decode_image_from_latent(self.bagel, self.vae, latents[0], image_shape)
        return DiffusionOutput(output=img)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Map checkpoint tensors onto this pipeline's parameters and load them.

        Checkpoint names are first normalized:
        - a leading ``module.`` and/or ``model.`` is stripped;
        - ``vae_model.*`` is mapped to ``vae.*``;
        - bare ``encoder.*`` / ``decoder.*`` keys are mapped to ``vae.*``.

        The normalized names are then matched against ``state_dict()`` keys,
        with ``bagel.*`` aliases tried for Bagel-core layers. A tensor is kept
        only if its shape matches the target, with two exceptions:
        - a square ``latent_pos_embed.pos_embed`` of a different size resizes
          that parameter and updates ``max_latent_size``;
        - a 2-D (flattened) SigLIP ``patch_embedding.weight`` with the right
          element count is reshaped to the target shape.

        Tensors with no matching target are dropped silently.

        Returns:
            The result of ``AutoWeightsLoader.load_weights`` for the kept tensors.
        """
        state = self.state_dict()
        allowed = set(state.keys())
        shapes = {k: tuple(v.shape) for k, v in state.items()}

        def _normalize_name(name: str) -> str:
            # Common wrappers/prefixes in checkpoints.
            for pfx in ("module.", "model."):
                if name.startswith(pfx):
                    name = name[len(pfx) :]
            # Common component renames across repos.
            if name.startswith("vae_model."):
                name = "vae." + name[len("vae_model.") :]
            # Bagel `ae.safetensors` commonly stores AE weights without a top-level prefix.
            # Map them into this pipeline's `vae.*` namespace.
            if name.startswith("encoder.") or name.startswith("decoder."):
                name = "vae." + name
            return name

        def _iter_candidate_names(name: str) -> Iterable[str]:
            """Yield candidate parameter names in this pipeline for a checkpoint key.

            The upstream Bagel repo typically stores Bagel-core layers (time_embedder,
            latent_pos_embed, vae2llm, llm2vae, etc.) at the top-level of the model,
            while this vllm-omni integration nests them under `self.bagel`.
            """
            n = _normalize_name(name)
            yield n

            # Map Bagel core layers from top-level -> `bagel.*` namespace.
            for pfx in ("time_embedder.", "latent_pos_embed.", "vae2llm.", "llm2vae."):
                if n.startswith(pfx):
                    yield "bagel." + n
                    break

            # Map connector and vit_pos_embed to `bagel.*`
            for pfx in ("connector.", "vit_pos_embed."):
                if n.startswith(pfx):
                    yield "bagel." + n
                    break

            if n.startswith("vit_model."):
                yield "bagel." + n  # matches self.bagel.vit_model
            elif n.startswith("vision_model."):
                yield "bagel.vit_model." + n
            elif n.startswith("model.vision_model."):
                yield "bagel.vit_model." + n[len("model.") :]

        def _filtered_weights():
            total = 0
            kept = 0
            shape_mismatch = 0
            for name, tensor in weights:
                total += 1
                picked = None
                for cand in _iter_candidate_names(name):
                    if cand in allowed:
                        # Only accept if tensor shape matches target param/buffer shape.
                        if tuple(tensor.shape) == shapes.get(cand):
                            picked = cand
                            break
                        else:
                            if cand.endswith("bagel.latent_pos_embed.pos_embed") and tensor.ndim == 2:
                                npos, hdim = tensor.shape
                                side = isqrt(int(npos))
                                if side * side == int(npos) and hdim == int(self.bagel.hidden_size):
                                    param = self.bagel.latent_pos_embed.pos_embed
                                    # Resize in-place to keep the same Parameter object.
                                    param.data = param.data.new_empty((npos, hdim))
                                    # Update model bookkeeping so position-id generation matches.
                                    self.bagel.max_latent_size = int(side)
                                    if hasattr(self.bagel, "config"):
                                        setattr(self.bagel.config, "max_latent_size", int(side))
                                    if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"):
                                        self.bagel.latent_pos_embed.max_num_patch_per_side = int(side)
                                    shapes[cand] = (npos, hdim)
                                    picked = cand
                                    break
                            # Handle flattened patch embedding for SigLIP
                            if cand.endswith("embeddings.patch_embedding.weight") and tensor.ndim == 2:
                                # Checkpoint has (Hidden, C*P*P), model expects (Hidden, C, P, P)
                                target_shape = shapes.get(cand)
                                if target_shape is not None and tensor.numel() == int(np.prod(target_shape)):
                                    # reshape (not view) also copes with non-contiguous tensors.
                                    tensor = tensor.reshape(target_shape)
                                    picked = cand
                                    break

                            shape_mismatch += 1
                            # Keep this quiet; shape mismatches are expected for ignored modules.
                if picked is not None:
                    kept += 1
                    yield picked, tensor
                # else: ignore extra weights (e.g. connector/vision/und)
            logger.info_once(
                "BagelPipeline weight filter kept %d/%d tensors (shape mismatches seen: %d)",
                kept,
                total,
                shape_mismatch,
            )

        loader = AutoWeightsLoader(self)
        return loader.load_weights(_filtered_weights())