"tests/vscode:/vscode.git/clone" did not exist on "86d9fc29cb39bbb09b5ac3202293eb6198666f18"
vision.py 5.27 KB
Newer Older
1
from abc import ABC, abstractmethod
2
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
3

4
import torch
5
6
from transformers import PretrainedConfig

7
8
9
10
11
12
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
                                     get_global_forced_attn_backend)
from vllm.platforms import _Backend, current_platform
from vllm.utils import print_warning_once

13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Type of the HF vision config wrapped by a `VisionEncoderInfo` subclass.
_C = TypeVar("_C", bound=PretrainedConfig)


class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface for querying a vision encoder's image geometry.

    Each concrete subclass wraps one kind of vision config (the ``_C``
    type parameter) and must implement every abstract query below.
    """

    def __init__(self, vision_config: _C) -> None:
        super().__init__()
        # Kept on the instance so concrete subclasses can consult it.
        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(self, *, image_width: int,
                             image_height: int) -> int:
        """Return the token count for an image of the given width/height."""
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        """Return the maximum token count a single image can produce."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the patch grid length (presumably patches per side --
        confirm against the concrete subclasses)."""
        raise NotImplementedError


49
50
51
52
53
54
class VisionLanguageConfig(Protocol):
    """Structural type for configs that carry a nested ``vision_config``.

    Used as the parameter type of ``get_vision_encoder_info``; statically,
    any object exposing a (read-only) ``vision_config`` attribute matches.
    """
    vision_config: Final[PretrainedConfig]


def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Build the ``VisionEncoderInfo`` matching ``hf_config.vision_config``.

    Args:
        hf_config: A config object exposing a ``vision_config`` attribute.

    Returns:
        An encoder-info wrapper constructed around
        ``hf_config.vision_config``.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, Pixtral
            or SigLIP vision config.
    """
    # Imported lazily to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    config = hf_config.vision_config

    # Checked in order; the first isinstance match wins.
    dispatch = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    for config_cls, info_cls in dispatch:
        if isinstance(config, config_cls):
            return info_cls(config)

    raise NotImplementedError(f"Unsupported vision config: {type(config)}")
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143


def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
    """
    Get the available attention backend for Vision Transformer.

    Resolution order:
      1. A globally forced backend, if one is set.
      2. The backend named by ``VLLM_ATTENTION_BACKEND``, if set.
      3. A platform default:
         - FLASH_ATTN when ``support_fa`` is True, the device passes
           ``has_device_capability(80)`` and ``is_flash_attn_2_available()``;
         - TORCH_SDPA on CPU and ROCm (ROCm doesn't support xformers);
         - XFORMERS otherwise.

    Args:
        support_fa: Whether the calling ViT implementation can run on the
            flash-attention backend.

    Returns:
        The selected attention backend.
    """
    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
    if selected_backend is None:
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)
    if selected_backend is None:
        device_available = current_platform.has_device_capability(80)
        # CPU and ROCm are routed to TORCH_SDPA (ROCm doesn't support
        # xformers), so they must never take the XFORMERS fallback below.
        xformers_unsupported = (current_platform.is_cpu()
                                or current_platform.is_rocm())
        if device_available and support_fa:
            from transformers.utils import is_flash_attn_2_available
            if is_flash_attn_2_available():
                selected_backend = _Backend.FLASH_ATTN
            elif xformers_unsupported:
                selected_backend = _Backend.TORCH_SDPA
            else:
                print_warning_once(
                    "Current `vllm-flash-attn` has a bug inside vision module, "
                    "so we use xformers backend instead. You can run "
                    "`pip install flash-attn` to use flash-attention backend.")
                selected_backend = _Backend.XFORMERS
        elif xformers_unsupported:
            selected_backend = _Backend.TORCH_SDPA
        else:
            # For Volta and Turing GPUs (or ViTs without FA support),
            # use xformers instead.
            selected_backend = _Backend.XFORMERS
    return selected_backend


def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Given the outputs of a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        feature_sample_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        post_layer_norm: Post norm to apply to the output of the encoder.
            When layers are sampled, it is applied only to the last selected
            hidden state, and only if that is the encoder's final layer.
        max_possible_layers: Total layers in the fully loaded visual encoder.

    Returns:
        The (optionally post-normed) encoder output when
        ``feature_sample_layers`` is None; otherwise the selected hidden
        states concatenated along the last dimension.

    Raises:
        ValueError: If ``feature_sample_layers`` is an empty list.
    """
    if feature_sample_layers is None:
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    if not feature_sample_layers:
        raise ValueError("feature_sample_layers must not be empty")

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs contains a list
    # of hidden states in the same order as the encoder layers
    # that produced them.
    offset = max_possible_layers - len(encoder_outputs)
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it. Both
    # index forms are checked against the *full* encoder's last layer
    # (not against how many layers were sampled).
    last_layer_idx = feature_sample_layers[-1]
    uses_last_layer = last_layer_idx in (max_possible_layers - 1, -1)
    if post_layer_norm is not None and uses_last_layer:
        # Normalize only the selected final hidden state, not the whole list.
        hs_pool[-1] = post_layer_norm(hs_pool[-1])
    return torch.cat(hs_pool, dim=-1)