vision.py 19.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import itertools
import math
6
from abc import ABC, abstractmethod
7
8
from collections.abc import Callable
from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
9

10
import torch
11
12
from transformers import PretrainedConfig

13
from vllm.attention.backends.registry import _Backend
14
15
16
17
18
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
22

# Logger for this module.
logger = init_logger(__name__)
23

24
25
26
# Type variable for the vision sub-config type; bounded by ``PretrainedConfig``.
# Used to parameterize ``_RootConfig`` and ``VisionEncoderInfo`` below.
_C = TypeVar("_C", bound=PretrainedConfig)


27
28
29
30
class _RootConfig(Protocol[_C]):
    """Structural type for a config object exposing a ``vision_config``.

    Any object with a ``vision_config`` attribute of type ``_C`` satisfies it.
    """

    vision_config: _C


31
class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface describing a vision encoder's image geometry.

    The constructor stores the root config as ``hf_config`` and its
    ``vision_config`` attribute as ``vision_config``; concrete subclasses
    implement the four ``get_*`` accessors.
    """

    def __init__(self, hf_config: _RootConfig[_C]) -> None:
        super().__init__()

        # Root config plus a direct handle on its vision sub-config.
        self.hf_config = hf_config
        self.vision_config = hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for the given image width/height.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size. Must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size. Must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the patch grid length. Must be implemented by subclasses."""
        raise NotImplementedError


60
61
62
63
class VisionLanguageConfig(Protocol):
    """Structural type for a vision-language config.

    Matches any object exposing a ``vision_config`` attribute that holds a
    ``PretrainedConfig``.
    """

    vision_config: Final[PretrainedConfig]


64
def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Build the ``VisionEncoderInfo`` matching ``hf_config.vision_config``.

    The vision config is checked, in order, against the CLIP, Pixtral and
    SigLIP vision config classes; the first match decides the info class.

    Raises:
        NotImplementedError: if the vision config matches none of them.
    """
    # Imported lazily to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    vision_config = hf_config.vision_config
    config_to_info = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    for config_cls, info_cls in config_to_info:
        if isinstance(vision_config, config_cls):
            return info_cls(hf_config)

    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")
79
80


81
82
83
84
85
86
def get_vit_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    *,
    attn_backend_override: _Backend | None = None,
) -> _Backend:
    """
    Get the available attention backend for Vision Transformer.

    Resolution order: ``attn_backend_override`` if given, otherwise the
    backend returned by ``get_env_variable_attn_backend()`` if not None,
    otherwise ``current_platform.get_vit_attn_backend(head_size, dtype)``.
    """
    if attn_backend_override is not None:
        return attn_backend_override

    # Imported here rather than at module scope to avoid a circular dependency.
    from vllm.attention.selector import get_env_variable_attn_backend

    env_backend: _Backend | None = get_env_variable_attn_backend()
    if env_backend is None:
        return current_platform.get_vit_attn_backend(head_size, dtype)
    return env_backend
101
102


103
104
# Names of the built-in feature selection strategies handled by
# `_get_vision_feature_selector` and `get_num_selected_vision_tokens`.
VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]

105
106
107
# A strategy is either one of the built-in names above or a callable that maps
# encoder hidden states to the selected hidden states.
VisionFeatureSelectStrategy: TypeAlias = (
    VisionFeatureSelectStrategyStr | Callable[[torch.Tensor], torch.Tensor]
)
108
109
110


def _get_vision_feature_selector(
111
    strategy: VisionFeatureSelectStrategy | str,
112
113
114
115
116
117
) -> Callable[[torch.Tensor], torch.Tensor]:
    if callable(strategy):
        return strategy

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":
118
        return lambda feats: feats[:, :1, :]
119
120
121
122
123
124
125
126

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":
        return lambda feats: feats[:, 1:, :]

    if strategy == "full":
        return lambda feats: feats

127
128
129
130
131
    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")


def get_num_selected_vision_tokens(
    num_vision_tokens: int,
    strategy: VisionFeatureSelectStrategy | str,
) -> int:
    """Return how many of ``num_vision_tokens`` tokens ``strategy`` keeps.

    Agrees with ``_get_vision_feature_selector``: a callable strategy is probed
    on a dummy ``[B, L, D]`` tensor, while built-in names are resolved
    arithmetically to the length that slicing a sequence of
    ``num_vision_tokens`` tokens would produce (so e.g. ``"default"`` on an
    empty sequence gives 0, not -1).

    Raises:
        ValueError: if ``strategy`` is neither callable nor a known name.
    """
    if callable(strategy):
        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
        dummy_selected_features = strategy(dummy_features)
        return dummy_selected_features.shape[1]

    if strategy == "class":
        # feats[:, :1, :] keeps at most one token.
        return min(num_vision_tokens, 1)

    if strategy == "default":
        # feats[:, 1:, :] drops the first token only if there is one.
        return max(num_vision_tokens - 1, 0)

    if strategy == "full":
        return num_vision_tokens

    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
149
150


151
def resolve_visual_encoder_outputs(
    encoder_outputs: torch.Tensor | list[torch.Tensor],
    post_layer_norm: torch.nn.LayerNorm | None,
    *,
    select_layers: list[int] | None = None,
    max_possible_layers: int | None = None,
    feature_select_strategy: VisionFeatureSelectStrategy | None = None,
) -> torch.Tensor:
    """Given the outputs of a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        post_layer_norm: Post norm to apply to the output of the encoder.
        select_layers: Optional, non-empty layer indices to grab from the
            encoder outputs; if provided, encoder outputs must be a list.
            Negative indices are relative to the full visual encoder.
        max_possible_layers: Total layers in the fully loaded visual encoder.
            Required whenever ``select_layers`` is provided.
        feature_select_strategy: Defines how to select the hidden states
            from each layer.

    Returns:
        The (optionally feature-selected and post-normed) encoder output; when
        ``select_layers`` is given, the selected layers concatenated along the
        last dimension.

    Raises:
        ValueError: if ``select_layers`` is omitted but ``encoder_outputs`` is
            not a single tensor, if ``max_possible_layers`` is missing while
            ``select_layers`` is given, or if ``select_layers`` is empty.
    """
    if select_layers is None:
        if not isinstance(encoder_outputs, torch.Tensor):
            raise ValueError(
                "Expected only a single encoder output when "
                "`select_layers` is not provided"
            )

        if feature_select_strategy is not None:
            select_features = _get_vision_feature_selector(feature_select_strategy)
            encoder_outputs = select_features(encoder_outputs)

        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)

        return encoder_outputs

    if max_possible_layers is None:
        raise ValueError(
            "`max_possible_layers` must be provided alongside `select_layers`"
        )

    # Without this guard an empty list fails later with an opaque IndexError
    # at `select_layers[-1]`.
    if not select_layers:
        raise ValueError("`select_layers` must contain at least one layer index")

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0
        else encoder_outputs[layer_idx + offset]
        for layer_idx in select_layers
    ]

    if feature_select_strategy is not None:
        select_features = _get_vision_feature_selector(feature_select_strategy)
        hs_pool = [select_features(hs) for hs in hs_pool]

    # Apply post-norm on the final hidden state if we are using it
    uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
    if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(hs_pool[-1])

    return torch.cat(hs_pool, dim=-1)
218
219


220
221
222
223
def run_dp_sharded_vision_model(
    image_input: torch.Tensor, vision_model: torch.nn.Module
) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding.

    The input is split along its first dimension across the tensor-parallel
    ranks. It is zero-padded at the end so every rank gets the same number of
    rows. Each rank runs ``vision_model`` on its slice, and the results are
    all-gathered along dim 0 with the padded rows trimmed off.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.
    Returns:
        torch.Tensor: Output image embeddings
    """
    total_rows = image_input.shape[0]
    world_size = get_tensor_model_parallel_world_size()
    rows_per_rank = -(-total_rows // world_size)  # ceil division
    missing_rows = rows_per_rank * world_size - total_rows

    # F.pad lists (left, right) pairs from the last dim backwards, so the final
    # pair appends `missing_rows` zero rows to dim 0 and leaves other dims as-is.
    pad_spec = (0, 0) * (image_input.dim() - 1) + (0, missing_rows)
    padded_input = torch.nn.functional.pad(image_input, pad_spec)

    rank = get_tensor_model_parallel_rank()
    shard_start = rank * rows_per_rank
    shard = padded_input[shard_start : shard_start + rows_per_rank]

    # all_gather requires a contiguous tensor.
    local_embeddings = vision_model(shard).contiguous()
    gathered = tensor_model_parallel_all_gather(local_embeddings, dim=0)
    return gathered[:total_rows]


def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Greedily assign samples to GPUs so the total size per GPU is balanced.

    Samples are visited from largest to smallest size and each one goes to
    the GPU with the smallest running total (the lowest GPU id wins ties).
    The load is determined by the total image sizes, not the number of images.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across

    Returns:
        shuffle_indices:
            Sample indices grouped by GPU (GPU 0's samples first), i.e. the
            order in which to reorder data for balanced loading
        gpu_sample_counts:
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu:
            Total size assigned to each GPU

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus = 2
        # -> ([0, 2, 1, 3], [1, 3], [1000, 350])
        ```
    """
    if len(sizes) == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    buckets: list[list[int]] = [[] for _ in range(num_gpus)]
    totals = [0] * num_gpus

    # Largest first; sorted() is stable, so equal sizes keep their input order.
    by_size_desc = sorted(range(len(sizes)), key=sizes.__getitem__, reverse=True)
    for sample_idx in by_size_desc:
        # First occurrence of the minimum == least-loaded GPU, lowest id on ties.
        target_gpu = totals.index(min(totals))
        buckets[target_gpu].append(sample_idx)
        totals[target_gpu] += sizes[sample_idx]

    shuffle_indices = list(itertools.chain.from_iterable(buckets))
    gpu_sample_counts = [len(bucket) for bucket in buckets]
    return (shuffle_indices, gpu_sample_counts, totals)


def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
    """Run a vision model with data parallelism (DP) sharding.

    Images/videos are distributed across the tensor-parallel ranks by
    ``get_load_balance_assignment``, which balances by number of patches rather
    than number of images. Each rank runs the vision model on its share. The
    per-rank outputs are padded to a common length and all-gathered, then split
    back into per-image embeddings in the original order.
    This function is used to run the vision model with mrope.

    Args:
        vision_model (torch.nn.Module): Vision model.
        pixel_values (torch.Tensor): Image/Video input tensor; the patches of
            all items concatenated along the first dimension.
        grid_thw_list: List of grid dimensions for each image
        rope_type: Type of rope used in the vision model.
                   Different rope types have different dimension to do ViT.
                   "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
                   "rope_2d" for 2D rope (e.g., Kimi-VL)
    Returns:
        tuple[torch.Tensor, ...]: Output embeddings, one tensor per entry of
        ``grid_thw_list``, in the original order.

    Raises:
        ValueError: if ``rope_type == "rope_2d"``, this rank was assigned no
            images, and ``vision_model.config`` has no ``hidden_size`` with
            which to build an empty output.

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
        tp_size = 2
        ```
        (patch counts per image: [1000, 100, 200, 50]; the per-step comments
        below trace this example.)
    """
    tp_size = get_tensor_model_parallel_world_size()

    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata.
    # Despite its name, `image_to_tp_rank` lists image indices grouped by rank
    # (rank 0's images first, then rank 1's, ...).
    # image_to_tp_rank = [0, 2, 1, 3]
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = (
        get_load_balance_assignment(patches_per_image, tp_size)
    )

    # cum_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[
        cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1]
    ]

    # Get the pixel values for the local images based on the image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat(
            [
                pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]]
                for i in image_idxs_local
            ]
        )
    else:
        # Handle case where this rank has no images
        pixel_values_local = torch.empty(
            (0, pixel_values.shape[1]),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

    # embed_dim_reduction_factor = 2 * 2
    if rope_type == "rope_2d":
        embed_dim_reduction_factor = (
            vision_model.merge_kernel_size[0] * vision_model.merge_kernel_size[1]
        )
    else:
        embed_dim_reduction_factor = (
            vision_model.spatial_merge_size * vision_model.spatial_merge_size
        )

    # Find the max length across all ranks
    # The output embedding of every DP rank has to be
    # padded to this length for tensor_model_parallel_all_gather
    # to work
    max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if rope_type == "rope_2d":
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(
                pixel_values_local, torch.tensor(local_grid_thw_list)
            )
            if isinstance(image_embeds_local, list):
                image_embeds_local = torch.cat(image_embeds_local, dim=0)
        else:
            out_dim = getattr(vision_model.config, "hidden_size", None)
            if out_dim is None:
                # Previously this surfaced as an opaque torch.empty TypeError.
                raise ValueError(
                    "vision_model.config.hidden_size is required to build an "
                    "empty output for a rank without images"
                )
            image_embeds_local = torch.empty(
                (0, embed_dim_reduction_factor, out_dim),
                device=pixel_values.device,
                dtype=pixel_values.dtype,
            )
    else:
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list)
        else:
            # Handle empty case
            image_embeds_local = torch.empty(
                (0, vision_model.out_hidden_size),
                device=pixel_values.device,
                dtype=pixel_values.dtype,
            )

    # Pad the output based on max_len_per_rank
    # for tensor_model_parallel_all_gather to work.
    # The padding is left uninitialized: it is sliced away after the gather.
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding_size = max_len_per_rank - current_len
        if rope_type == "rope_2d":
            padding = torch.empty(
                (
                    padding_size,
                    image_embeds_local.shape[1],
                    image_embeds_local.shape[2],
                ),
                dtype=image_embeds_local.dtype,
                device=image_embeds_local.device,
            )
        else:
            padding = torch.empty(
                (padding_size, image_embeds_local.shape[1]),
                dtype=image_embeds_local.dtype,
                device=image_embeds_local.device,
            )
        image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0)
    else:
        image_embeds_local_padded = image_embeds_local

    # Do all_gather to collect embeddings from all ranks
    gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0)

    # Remove padding and reconstruct per-rank embeddings
    rank_embeddings = list[torch.Tensor]()
    for rank in range(tp_size):
        start_idx = rank * max_len_per_rank
        end_idx = start_idx + (
            grouped_pixel_values_len[rank] // embed_dim_reduction_factor
        )
        rank_embeddings.append(gathered_embeds[start_idx:end_idx])

    patches_per_output_image = [
        (patch_size // embed_dim_reduction_factor) for patch_size in patches_per_image
    ]

    # Reconstruct embeddings in the original order
    original_order_embeddings: list[torch.Tensor | None] = [None] * len(grid_thw_list)
    current_idx = 0
    for rank in range(tp_size):
        count = gpu_sample_counts[rank]
        if count > 0:
            # Get images assigned to this rank in shuffled order
            # GPU_0 = image_idxs_local  [0]
            # GPU_1 = image_idxs_local  [2, 1, 3]
            rank_images = image_to_tp_rank[current_idx : current_idx + count]

            rank_embed = rank_embeddings[rank]
            # Split rank embeddings back to individual images
            embed_start = 0
            for img_idx in rank_images:
                img_patches = patches_per_output_image[img_idx]
                original_order_embeddings[img_idx] = rank_embed[
                    embed_start : embed_start + img_patches
                ]
                embed_start += img_patches
            current_idx += count

    out_embeddings = tuple(
        embed for embed in original_order_embeddings if embed is not None
    )
    assert len(out_embeddings) == len(original_order_embeddings), (
        "Found unassigned embeddings"
    )
    return out_embeddings
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546


def get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
    """Build the (t, h, w) position ids of one vision item.

    The item's grid height/width are divided by ``spatial_merge_size``. Each
    resulting cell of each frame gets one id per axis, enumerated frame-major,
    then row, then column, and every id is offset by ``start_idx``.

    Args:
        start_idx: Offset added to every position id.
        vision_idx: Index of the item in ``grid_hs`` / ``grid_ws``.
        spatial_merge_size: Divisor applied to the grid height and width.
        t_index: Temporal position id for each frame.
        grid_hs: Per-item grid heights.
        grid_ws: Per-item grid widths.

    Returns:
        An int64 tensor of shape ``(3, len(t_index) * llm_h * llm_w)`` whose
        rows are the t, h and w ids.
    """
    llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
    llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
    # Build every index on the grid's device so torch.stack never mixes devices.
    device = llm_grid_h.device
    num_frames = len(t_index)
    h_index = (
        torch.arange(llm_grid_h, device=device)
        .view(1, -1, 1)
        .expand(num_frames, -1, llm_grid_w)
        .flatten()
    )
    w_index = (
        torch.arange(llm_grid_w, device=device)
        .view(1, 1, -1)
        .expand(num_frames, llm_grid_h, -1)
        .flatten()
    )
    # Create the t ids directly as int64: routing through torch.Tensor(...)
    # (float32) silently rounds ids above 2**24.
    t_index_tensor = (
        torch.tensor(t_index, dtype=torch.long, device=device)
        .view(-1, 1)
        .expand(-1, llm_grid_h * llm_grid_w)
        .flatten()
    )
    return torch.stack([t_index_tensor, h_index, w_index]) + start_idx
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562


# PyTorch 2.9 has a Conv3D performance regression, so Conv3D weights are
# reshaped into Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR that introduced this workaround
# (https://github.com/vllm-project/vllm/pull/27418)
# once the performance issue is resolved in PyTorch.
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
    """
    Flatten a Conv3D weight into a Linear weight.

    Maps shape ``(out_channels, in_channels, kt, kh, kw)`` to
    ``(out_channels, in_channels * kt * kh * kw)``. Only valid when
    kernel_size == stride.
    """
    # Unpacking into exactly five names also rejects non-5-D weights.
    out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
    fan_in = math.prod((in_channels, kt, kh, kw))
    return conv3d_weight.reshape(out_channels, fan_in)