vision.py 5.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
6

7
import torch
8
9
from transformers import PretrainedConfig

10
11
12
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
                                     get_global_forced_attn_backend)
13
from vllm.logger import init_logger
14
from vllm.platforms import _Backend, current_platform
zhuwenwen's avatar
zhuwenwen committed
15
from vllm.utils import SUPPORT_TC
16
17

# Module-level logger; used below to emit a one-time warning when the
# flash-attention backend cannot be selected for the vision module.
logger = init_logger(__name__)
18

19
20
21
22
23
# Type variable for the HF config type that `VisionEncoderInfo` is
# parameterized over; restricted to `PretrainedConfig` subclasses.
_C = TypeVar("_C", bound=PretrainedConfig)


class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract description of a vision encoder, built from an HF config.

    Concrete subclasses implement the accessors below for one specific
    vision encoder family.

    Attributes:
        hf_config: The config object handed to the constructor.
        vision_config: The `vision_config` attribute read from `hf_config`.
    """

    def __init__(self, hf_config: _C) -> None:
        super().__init__()

        # Keep both the full config and its vision sub-config so that
        # subclasses can read either one directly.
        self.hf_config = hf_config
        self.vision_config = hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(self, *, image_width: int,
                             image_height: int) -> int:
        """Return the number of image tokens for the given image size.

        Both arguments are keyword-only.
        """
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size of the encoder as a single integer."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the patch size of the encoder as a single integer."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the patch grid length of the encoder.

        NOTE(review): presumably the number of patches along one side of
        the image -- confirm against the concrete implementations.
        """
        raise NotImplementedError


52
53
54
55
56
57
class VisionLanguageConfig(Protocol):
    """Structural type for any config exposing a `vision_config` attribute."""

    # The vision sub-config; declared `Final`, i.e. not meant to be rebound.
    vision_config: Final[PretrainedConfig]


def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Build the `VisionEncoderInfo` matching `hf_config.vision_config`.

    Args:
        hf_config: A config object exposing a `vision_config` attribute.

    Returns:
        A CLIP, Pixtral-HF or SigLIP encoder info object wrapping
        `hf_config`, chosen by the type of its vision config.

    Raises:
        NotImplementedError: If the vision config is not an instance of any
            of the supported vision config classes.
    """
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    # Ordered (config class, info class) pairs; the first match wins, in the
    # same order as the checks have always been performed.
    dispatch = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    for config_cls, info_cls in dispatch:
        if isinstance(hf_config.vision_config, config_cls):
            return info_cls(hf_config)

    msg = f"Unsupported vision config: {type(hf_config.vision_config)}"
    raise NotImplementedError(msg)
72
73
74
75
76
77
78
79
80
81
82
83
84


def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
    """
    Get the available attention backend for Vision Transformer.

    The backend is resolved in this order:

    1. The globally forced attention backend, if one is set.
    2. The backend named by `envs.VLLM_ATTENTION_BACKEND`, if it is set and
       maps to a known backend.
    3. A platform default:
       - non CUDA/ROCm platforms: `TORCH_SDPA`;
       - CUDA/ROCm with `SUPPORT_TC` false: `TORCH_SDPA`;
       - CUDA/ROCm with device capability >= 80 and `support_fa`:
         `FLASH_ATTN` when flash-attn 2 is available (or on ROCm),
         otherwise `XFORMERS` with a one-time warning;
       - everything else: `XFORMERS`.

    Args:
        support_fa: Whether the caller's ViT attention implementation can
            use the flash-attention backend.

    Returns:
        The selected attention backend.
    """
    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
    if selected_backend is None:
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)
    if selected_backend is not None:
        # An explicit choice (forced or via env var) always wins.
        return selected_backend

    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        # Default to torch SDPA for other non-GPU platforms.
        return _Backend.TORCH_SDPA

    if not SUPPORT_TC:
        # Previously this assignment was unconditionally overwritten by the
        # capability check below, so TORCH_SDPA could never be selected here.
        # Return early so the choice actually takes effect.
        # NOTE(review): SUPPORT_TC presumably flags tensor-core support --
        # confirm against vllm.utils.
        return _Backend.TORCH_SDPA

    device_available = current_platform.has_device_capability(80)
    if not (device_available and support_fa):
        # For Volta and Turing GPUs, use xformers instead.
        return _Backend.XFORMERS

    from transformers.utils import is_flash_attn_2_available
    if is_flash_attn_2_available() or current_platform.is_rocm():
        return _Backend.FLASH_ATTN

    logger.warning_once(
        "Current `vllm-flash-attn` has a bug inside vision "
        "module, so we use xformers backend instead. You can "
        "run `pip install flash-attn` to use flash-attention "
        "backend.")
    return _Backend.XFORMERS


def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Given the outputs a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        feature_sample_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        post_layer_norm: Post norm to apply to the output of the encoder.
        max_possible_layers: Total layers in the fully loaded visual encoder.

    Returns:
        If `feature_sample_layers` is None, `encoder_outputs` itself (after
        `post_layer_norm` when one is given). Otherwise the selected hidden
        states concatenated along the last dimension, with `post_layer_norm`
        applied only to the final selected state and only when that state is
        the output of the encoder's last layer.
    """
    if feature_sample_layers is None:
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it.
    # With the layout described above (index 0 holds the encoder inputs),
    # the last layer of the full encoder sits at index `max_possible_layers`
    # when addressed with a non-negative index, or at -1 otherwise. The
    # number of sampled layers is unrelated to that index.
    uses_last_layer = feature_sample_layers[-1] in (max_possible_layers, -1)
    if post_layer_norm is not None and uses_last_layer:
        # Normalize only the selected final hidden state; the norm takes a
        # single tensor, not the whole list of encoder outputs.
        hs_pool[-1] = post_layer_norm(hs_pool[-1])
    return torch.cat(hs_pool, dim=-1)