isaac.py 2.74 KB
Newer Older
oscardev256's avatar
oscardev256 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from transformers import Qwen3Config
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig


class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
    """SigLIP2 vision-tower configuration extended for Isaac's pixel shuffle.

    Behaves like ``Siglip2VisionConfig`` but additionally records the
    pixel-shuffle scale factor used by Isaac, and exposes ``num_patches`` as
    an explicit constructor argument.

    Args:
        pixel_shuffle_scale_factor: Pixel-shuffle scale factor, stored on the
            config as ``pixel_shuffle_scale_factor`` (defaults to 1).
        num_patches: Stored on the config as ``num_patches`` (defaults to 256).
        **kwargs: Any other ``Siglip2VisionConfig`` constructor arguments.
    """

    model_type = "pixel_shuffle_siglip2"
    base_config_key = "vision_config"

    def __init__(
        self,
        pixel_shuffle_scale_factor: int = 1,
        num_patches: int = 256,
        **kwargs,
    ):
        # ``num_patches`` is already a Siglip2VisionConfig argument, so let the
        # parent own it instead of assigning it a second time afterwards.
        super().__init__(num_patches=num_patches, **kwargs)

        # Isaac-specific extension field.
        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor


class IsaacConfig(Qwen3Config):
    """Configuration class for Isaac multimodal model.

    Extends ``Qwen3Config`` with the vision-tower configuration and the
    image-processing parameters used by Isaac.

    Args:
        vision_config: Vision tower configuration. Accepts a
            ``PixelShuffleSiglip2VisionConfig`` (or another config object), a
            ``dict`` of ``PixelShuffleSiglip2VisionConfig`` constructor kwargs,
            or ``None`` for defaults. In the ``dict``/``None`` forms, a missing
            ``pixel_shuffle_scale_factor`` is taken from
            ``pixel_shuffle_scale`` and a missing ``num_patches`` from
            ``vision_max_num_patches``. A config object only gets those
            attributes filled in when it lacks them.
        vision_patch_size: Patch size, stored as ``video_patch_size``. If a
            ``video_patch_size`` keyword is also supplied (which is what a
            serialized config contains), that value wins so that configs
            survive a save/load round-trip.
        vision_max_num_patches: Stored as ``vision_max_num_patches``.
        vision_min_num_patches: Stored as ``vision_min_num_patches``.
        pixel_shuffle_scale: Stored as ``pixel_shuffle_scale``.
        max_sequence_length: Stored as ``max_sequence_length``.
        vision_token: Stored as ``vision_token``.
        vision_attn_implementation: Stored as ``vision_attn_implementation``.
        **kwargs: Forwarded to ``Qwen3Config``.
    """

    model_type = "isaac"
    sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig}

    def __init__(
        self,
        vision_config=None,
        vision_patch_size: int = 16,
        vision_max_num_patches: int = 256,
        vision_min_num_patches: int | None = None,
        pixel_shuffle_scale: int = 1,
        max_sequence_length: int = 16384,
        vision_token: str = "<image>",
        vision_attn_implementation: str | None = None,
        **kwargs,
    ):
        # The patch size is stored under the attribute name
        # ``video_patch_size``, so that is the key a serialized config carries
        # when it is reloaded. Pop it before ``super().__init__`` (which would
        # otherwise just set it as an extra attribute) and prefer it over the
        # ``vision_patch_size`` default; otherwise the saved value would be
        # clobbered by 16 on reload.
        video_patch_size = kwargs.pop("video_patch_size", vision_patch_size)

        super().__init__(**kwargs)

        # EventStreamProcessor parameters (for backward compatibility)
        self.video_patch_size = video_patch_size
        self.vision_max_num_patches = vision_max_num_patches
        self.vision_min_num_patches = vision_min_num_patches
        self.pixel_shuffle_scale = pixel_shuffle_scale

        # Processing parameters
        self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token

        # Build the vision config. PixelShuffleSiglip2VisionConfig always
        # defines ``pixel_shuffle_scale_factor`` and ``num_patches``, so the
        # top-level fallbacks must be applied *before* construction. Otherwise
        # they would never take effect, and the top-level and vision-level
        # values could silently disagree.
        if isinstance(vision_config, dict):
            # Copy so the caller's dict is not mutated.
            vision_kwargs = dict(vision_config)
            vision_kwargs.setdefault("pixel_shuffle_scale_factor", pixel_shuffle_scale)
            vision_kwargs.setdefault("num_patches", vision_max_num_patches)
            self.vision_config = PixelShuffleSiglip2VisionConfig(**vision_kwargs)
        elif vision_config is None:
            self.vision_config = PixelShuffleSiglip2VisionConfig(
                pixel_shuffle_scale_factor=pixel_shuffle_scale,
                num_patches=vision_max_num_patches,
            )
        else:
            # A ready-made config object: keep its own values and only fill in
            # attributes it lacks (compatibility with pretrained checkpoints).
            self.vision_config = vision_config
            if not hasattr(self.vision_config, "pixel_shuffle_scale_factor"):
                self.vision_config.pixel_shuffle_scale_factor = pixel_shuffle_scale
            if not hasattr(self.vision_config, "num_patches"):
                self.vision_config.num_patches = vision_max_num_patches

        self.vision_attn_implementation = vision_attn_implementation


# Public API of this module: the two configuration classes.
__all__ = ["IsaacConfig", "PixelShuffleSiglip2VisionConfig"]