vision.py 8.02 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC, abstractmethod
4
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
5

6
import torch
7
8
from transformers import PretrainedConfig

9
10
11
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
                                     get_global_forced_attn_backend)
12
from vllm.jsontree import JSONTree, json_map_leaves
13
from vllm.logger import init_logger
14
from vllm.platforms import _Backend, current_platform
15

16
17
from .interfaces import MultiModalEmbeddings

18
logger = init_logger(__name__)
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
_C = TypeVar("_C", bound=PretrainedConfig)


class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface answering size/token queries about a vision encoder.

    Each concrete subclass is parameterized by (and stores) the HF vision
    config type ``_C`` it understands.
    """

    def __init__(self, vision_config: _C) -> None:
        super().__init__()
        # Kept as-is; subclasses read their hyperparameters from it.
        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        """Return the maximum number of image tokens."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size as a single integer."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the length of the encoder's patch grid."""
        raise NotImplementedError


56
57
58
59
60
61
class VisionLanguageConfig(Protocol):
    """Structural type for any config object exposing ``vision_config``.

    :func:`get_vision_encoder_info` reads only this attribute from the
    config it receives.
    """

    # Sub-config of the vision tower; annotated ``Final`` so type checkers
    # treat it as read-only on implementers.
    vision_config: Final[PretrainedConfig]


def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Return the encoder info matching ``hf_config.vision_config``.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, Pixtral
            or SigLIP vision config.
    """
    # Imported lazily to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    vision_config = hf_config.vision_config

    # Checked in order; the first matching config type wins.
    dispatch = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    for config_cls, info_cls in dispatch:
        if isinstance(vision_config, config_cls):
            return info_cls(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
77
78
79
80
81
82
83
84
85
86
87
88
89


def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
    """
    Get the available attention backend for Vision Transformer.

    The backend is chosen in this order:

    1. The globally forced attention backend, if one is set.
    2. The backend named by ``envs.VLLM_ATTENTION_BACKEND``, if it is set
       and ``backend_name_to_enum`` resolves it to a non-None value.
    3. A platform default:

       - CUDA with ``has_device_capability(80)`` and ``support_fa``:
         ``FLASH_ATTN`` if ``is_flash_attn_2_available()``, otherwise
         ``XFORMERS`` (with a one-time warning).
       - Any other CUDA device: ``XFORMERS``.
       - Non-CUDA platforms: ``TORCH_SDPA``.

    Args:
        support_fa: Whether the calling ViT can use FlashAttention.
    """
    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
    forced_backend: Optional[_Backend] = get_global_forced_attn_backend()
    if forced_backend is not None:
        return forced_backend

    env_backend_name: Optional[str] = envs.VLLM_ATTENTION_BACKEND
    if env_backend_name is not None:
        env_backend = backend_name_to_enum(env_backend_name)
        if env_backend is not None:
            return env_backend

    if not current_platform.is_cuda():
        # Default to torch SDPA for other non-GPU platforms.
        return _Backend.TORCH_SDPA

    has_capability_80 = current_platform.has_device_capability(80)
    if not (has_capability_80 and support_fa):
        # Volta/Turing GPUs (below capability 80), or a ViT without FA
        # support, use xformers instead.
        return _Backend.XFORMERS

    from transformers.utils import is_flash_attn_2_available
    if is_flash_attn_2_available():
        return _Backend.FLASH_ATTN

    logger.warning_once(
        "Current `vllm-flash-attn` has a bug inside vision "
        "module, so we use xformers backend instead. You can "
        "run `pip install flash-attn` to use flash-attention "
        "backend.")
    return _Backend.XFORMERS


def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Given the outputs of a visual encoder module, which may be either the
    output of the last layer or a list of hidden states to be stacked,
    handle post normalization and resolve them into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer, or all hidden states
            (the encoder input followed by each loaded layer's output) when
            ``feature_sample_layers`` is provided.
        feature_sample_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list. Negative
            indices are relative to the full visual encoder.
        post_layer_norm: Post norm to apply to the output of the encoder's
            final layer.
        max_possible_layers: Total layers in the fully loaded visual encoder.

    Returns:
        The (optionally post-normed) last-layer output, or the selected
        hidden states concatenated along the last dimension.
    """
    if feature_sample_layers is None:
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it. Since
    # index 0 holds the encoder input, the final layer is index
    # `max_possible_layers` (or -1 relative to the full encoder); compare
    # against that, not against the number of *selected* layers.
    uses_last_layer = feature_sample_layers[-1] in (max_possible_layers, -1)
    if post_layer_norm is not None and uses_last_layer:
        # Normalize only the selected final hidden state, not the whole list.
        hs_pool[-1] = post_layer_norm(hs_pool[-1])
    return torch.cat(hs_pool, dim=-1)
154
155
156
157


def scatter_patch_features(
    features: torch.Tensor,
    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
    """
    Scatter the patch features into a contiguous tensor that corresponds
    to the embedding tokens defined by the multimodal processor.

    The rest of the values in the tensor are set to NaN so that they
    can be filtered out by :func:`select_patch_features`.

    Args:
        features: The patch features, concatenated across each image.
          Shape: `(num_patch, feature_depth)`
        embed_is_patch: A boolean mask indicating which image embeddings
          correspond to patch tokens for each image.
          Shape: `(num_images, num_embeds)`

    Returns:
        One tensor per image, each of shape `(num_embeds_i, feature_depth)`,
        where `num_embeds_i` is the number of elements in that image's mask.

    Note:
        The original code only considers patch tokens as feature
        tokens, but our processor considers all image-related tokens
        as feature tokens because the feature tokens need to be
        consecutive in `input_ids`.

    Example:
        A simplified example for one image:

        .. code-block::

            Embedding tokens (from HF processor):
            [<start> <patch> <patch>  <col>  <patch> <patch>  <col>  <end> ]

            embed_is_patch (from HF processor):
            [ False   True    True    False    True    True   False  False ]

            Encoder outputs (from model):
            [  p1      p2      p3      p4   ]

            The resulting embedding tensor is:
            [  nan     p1      p2      nan      p3      p4     nan    nan  ]
    """
    # One output chunk per image, sized by that image's mask.
    num_embeds_per_image = [
        e_is_patch.numel() for e_is_patch in embed_is_patch
    ]
    if isinstance(embed_is_patch, torch.Tensor):
        # `reshape` (unlike `view`) also accepts non-contiguous masks,
        # e.g. transposed or sliced ones.
        embed_is_patch_flat = embed_is_patch.reshape(-1)
    else:
        # Flatten each mask so the result lines up with the `numel()`
        # split sizes above, whatever the per-image mask rank.
        embed_is_patch_flat = torch.cat(
            [e_is_patch.reshape(-1) for e_is_patch in embed_is_patch])

    # Non-patch positions stay NaN so they can be filtered out later.
    embeds_flat = features.new_full(
        (sum(num_embeds_per_image), features.shape[-1]),
        fill_value=torch.nan,
    )
    embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)

    return embeds_flat.split(num_embeds_per_image)


def select_patch_features(
        multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
    """
    Given the outputs of :func:`scatter_patch_features`, return only
    the values that correspond to patch features.

    Each tensor leaf is filtered independently. Its non-NaN elements are
    kept and reshaped to ``(-1, *leaf.shape[1:])``, which drops the NaN
    placeholder rows written for non-patch tokens. This assumes genuine
    feature values are never NaN; a partially-NaN row would misalign the
    reshape.
    """

    def _drop_nan_rows(embeds: torch.Tensor) -> torch.Tensor:
        keep_mask = ~embeds.isnan()
        return embeds[keep_mask].view(-1, *embeds.shape[1:])

    embeds_tree = cast(JSONTree[torch.Tensor], multimodal_embeddings)
    filtered_tree = json_map_leaves(_drop_nan_rows, embeds_tree)
    return cast(MultiModalEmbeddings, filtered_tree)