"vllm/vscode:/vscode.git/clone" did not exist on "8c7075d1484de8aa6bb9b3210a8dfd4e381bc1e3"
vision.py 20.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import itertools
import math
6
from abc import ABC, abstractmethod
7
8
from collections.abc import Callable
from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
9

10
import torch
11
12
from transformers import PretrainedConfig

13
from vllm.config import MultiModalConfig, VllmConfig, get_current_vllm_config
14
15
16
17
18
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
from vllm.v1.attention.backends.registry import AttentionBackendEnum
22
23

logger = init_logger(__name__)

# Type variable standing for the concrete HF vision sub-config class.
_C = TypeVar("_C", bound=PretrainedConfig)


class _RootConfig(Protocol[_C]):
    """Structural type: any root config object exposing a `vision_config`."""

    # The vision sub-config, typed by `_C`.
    vision_config: _C


32
class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface describing geometry of a vision encoder.

    Concrete subclasses report image size, patch size, patch grid length and
    the number of image tokens for a particular vision encoder family.
    """

    def __init__(self, hf_config: _RootConfig[_C]) -> None:
        super().__init__()

        # Keep the root config and, for convenience, its vision sub-config.
        self.hf_config = hf_config
        self.vision_config = hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for the given image dimensions."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size used by the encoder."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the patch size used by the encoder."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the length of the patch grid."""
        raise NotImplementedError


61
62
63
64
class VisionLanguageConfig(Protocol):
    """Structural type for multimodal configs with a fixed `vision_config`."""

    # Vision sub-config; declared Final so it is treated as read-only.
    vision_config: Final[PretrainedConfig]


65
def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Build the encoder-info object matching ``hf_config.vision_config``.

    The vision config is checked against CLIP, Pixtral and SigLIP configs,
    in that order; the first match decides the returned info class.

    Raises:
        NotImplementedError: if the vision config is none of the supported
            types.
    """
    # Imported lazily to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    dispatch = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    vision_config = hf_config.vision_config
    for config_cls, info_cls in dispatch:
        if isinstance(vision_config, config_cls):
            return info_cls(hf_config)

    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")
80
81


82
def _get_vit_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    *,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> AttentionBackendEnum:
    """
    Ask the current platform which attention backend the ViT should use.

    Args:
        head_size: Attention head size, forwarded to the platform.
        dtype: Compute dtype, forwarded to the platform.
        attn_backend_override: Optional requested backend, forwarded to the
            platform as ``backend``.
    """
    selected_backend = current_platform.get_vit_attn_backend(
        head_size, dtype, backend=attn_backend_override
    )
    return selected_backend
96
97


98
99
100
101
102
103
104
105
106
def get_vit_attn_backend(
    head_size: int,
    dtype: torch.dtype,
) -> AttentionBackendEnum:
    """
    Get the attention backend for Vision Transformer.

    The multimodal config's ``mm_encoder_attn_backend`` (if a current vLLM
    config with a multimodal config is available) is used as an override;
    otherwise the platform chooses without an override.
    """
    try:
        model_config = get_current_vllm_config().model_config
        multimodal_config = (
            None if model_config is None else model_config.multimodal_config
        )
    except (AssertionError, AttributeError):
        # No usable current config: fall back to no override.
        multimodal_config = None

    if multimodal_config is None:
        override = None
    else:
        override = multimodal_config.mm_encoder_attn_backend

    return _get_vit_attn_backend(head_size, dtype, attn_backend_override=override)


def is_vit_use_data_parallel() -> bool:
    """
    Return whether the Vision Transformer should run with data parallelism.

    Looks up ``mm_encoder_tp_mode`` on the multimodal config of the current
    vLLM config and returns True only when it equals ``"data"``. Returns False
    when there is no model config or multimodal config, or when looking up the
    current config raises ``AssertionError``/``AttributeError``.
    """
    try:
        vllm_config: VllmConfig = get_current_vllm_config()
        model_config = vllm_config.model_config
        multimodal_config: MultiModalConfig | None = (
            model_config.multimodal_config if model_config is not None else None
        )
    except (AssertionError, AttributeError):
        # No usable current config; behave as if no TP mode was configured.
        multimodal_config = None

    mm_encoder_tp_mode = (
        multimodal_config.mm_encoder_tp_mode if multimodal_config is not None else None
    )
    return mm_encoder_tp_mode == "data"


Harry Mellor's avatar
Harry Mellor committed
146
147
148
149
150
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
    """Decide whether the multimodal ViT encoder should be torch-compiled.

    Meant to be passed as the ``enable_if`` argument of
    ``@support_torch_compile``; it simply reports
    ``compilation_config.compile_mm_encoder``.
    """
    compilation_config = vllm_config.compilation_config
    return compilation_config.compile_mm_encoder


151
152
VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]

153
154
155
VisionFeatureSelectStrategy: TypeAlias = (
    VisionFeatureSelectStrategyStr | Callable[[torch.Tensor], torch.Tensor]
)
156
157
158


def _get_vision_feature_selector(
159
    strategy: VisionFeatureSelectStrategy | str,
160
161
162
163
164
165
) -> Callable[[torch.Tensor], torch.Tensor]:
    if callable(strategy):
        return strategy

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":
166
        return lambda feats: feats[:, :1, :]
167
168
169
170
171
172
173
174

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":
        return lambda feats: feats[:, 1:, :]

    if strategy == "full":
        return lambda feats: feats

175
176
177
178
179
    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")


def get_num_selected_vision_tokens(
    num_vision_tokens: int,
180
    strategy: VisionFeatureSelectStrategy | str,
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
) -> int:
    if callable(strategy):
        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
        dummy_selected_features = strategy(dummy_features)
        return dummy_selected_features.shape[1]

    if strategy == "class":
        return 1

    if strategy == "default":
        return num_vision_tokens - 1

    if strategy == "full":
        return num_vision_tokens

    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
197
198


199
def resolve_visual_encoder_outputs(
    encoder_outputs: torch.Tensor | list[torch.Tensor],
    post_layer_norm: torch.nn.LayerNorm | None,
    *,
    select_layers: list[int] | None = None,
    max_possible_layers: int | None = None,
    last_hs_proc: Callable[[torch.Tensor], torch.Tensor] | None = None,
    feature_select_strategy: VisionFeatureSelectStrategy | None = None,
) -> torch.Tensor:
    """Given the outputs a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
            When `select_layers` is given, this must be a list whose element 0
            is the input to the encoder, followed by one hidden state per
            loaded layer.
        post_layer_norm: Post norm to apply to the output of the encoder.
        select_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list. Positive
            indices address `encoder_outputs` directly (so the final layer of
            the full encoder is `max_possible_layers`); negative indices are
            relative to the full visual encoder (so `-1` is the final layer).
        max_possible_layers: Total layers in the fully loaded visual encoder.
        last_hs_proc: Optional callable to be applied to the last layer if it
            is used, e.g., pooling head for Siglip. This is done prior to
            feature selection and layer normalization. If select_layers are
            provided, the output of last_hs_proc must be able to be
            concatenated with the other select_layers along the last dimension.
        feature_select_strategy: Defines how to select the hidden states
            from each layer.

    Returns:
        A single tensor; with `select_layers`, the selected hidden states are
        concatenated along the last dimension.

    Raises:
        ValueError: if `encoder_outputs` is not a single tensor while
            `select_layers` is None, if `select_layers` is given without
            `max_possible_layers`, or if `select_layers` is empty.
    """
    if select_layers is None:
        if not isinstance(encoder_outputs, torch.Tensor):
            raise ValueError(
                "Expected only a single encoder output when "
                "`select_layers` is not provided"
            )

        # Preprocess the encoder outputs as needed, e.g., map head
        # and layer norm for siglip, which runs before feature selection
        if last_hs_proc is not None:
            encoder_outputs = last_hs_proc(encoder_outputs)

        if feature_select_strategy is not None:
            select_features = _get_vision_feature_selector(feature_select_strategy)
            encoder_outputs = select_features(encoder_outputs)

        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)

        return encoder_outputs

    if max_possible_layers is None:
        raise ValueError(
            "`max_possible_layers` must be provided alongside `select_layers`"
        )
    if not select_layers:
        # Previously this surfaced as an opaque IndexError below.
        raise ValueError("`select_layers` must contain at least one layer index")

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0
        else encoder_outputs[layer_idx + offset]
        for layer_idx in select_layers
    ]

    # Because index 0 holds the encoder inputs, the final layer of the full
    # encoder is index `max_possible_layers` (not `max_possible_layers - 1`)
    # when counted from the front, and -1 when counted from the back.
    # Only the last requested layer is checked, matching where post-norm and
    # `last_hs_proc` are applied (hs_pool[-1]).
    uses_last_layer = select_layers[-1] in (max_possible_layers, -1)
    if last_hs_proc is not None and uses_last_layer:
        hs_pool[-1] = last_hs_proc(hs_pool[-1])

    if feature_select_strategy is not None:
        select_features = _get_vision_feature_selector(feature_select_strategy)
        hs_pool = [select_features(hs) for hs in hs_pool]

    # Apply post-norm on the final hidden state if we are using it
    if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(hs_pool[-1])

    return torch.cat(hs_pool, dim=-1)
280
281


282
283
284
285
def run_dp_sharded_vision_model(
    image_input: torch.Tensor, vision_model: torch.nn.Module
) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding.

    The first dimension of `image_input` is zero-padded to a multiple of the
    tensor-parallel world size and split into equal contiguous shards. Each
    rank runs `vision_model` on its shard, the results are all-gathered along
    dim 0, and the rows belonging to padding are dropped.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.
    Returns:
        torch.Tensor: Output image embeddings, one row per row of
        `image_input` (in input order).
    """
    world_size = get_tensor_model_parallel_world_size()
    total_chunks = image_input.shape[0]
    # Ceiling division: shard size so that all chunks fit across ranks.
    shard_size = -(-total_chunks // world_size)
    pad_rows = shard_size * world_size - total_chunks

    # F.pad pads from the last dim backwards, so the trailing (0, pad_rows)
    # pair applies to dim 0 and all other dims get (0, 0).
    pad_spec = (0, 0) * (image_input.dim() - 1) + (0, pad_rows)
    padded_input = torch.nn.functional.pad(image_input, pad_spec)

    shard_start = get_tensor_model_parallel_rank() * shard_size
    local_shard = padded_input[shard_start : shard_start + shard_size, ...]

    # all_gather requires a contiguous tensor.
    local_embeddings = vision_model(local_shard).contiguous()
    gathered = tensor_model_parallel_all_gather(local_embeddings, dim=0)
    return gathered[:total_chunks, ...]


def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Generate load balancing assignment and metadata for distributing data
    across GPUs.

    The load is determined by the total image sizes, not the number of
    images. Samples are visited from largest to smallest (ties keep their
    input order), and each one is greedily assigned to the GPU with the
    smallest total size so far; ties go to the lowest GPU index.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across; must be at least 1.

    Returns:
        shuffle_indices:
            Indices to reorder data for balanced loading: all samples of
            GPU 0 first, then those of GPU 1, and so on.
        gpu_sample_counts:
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu:
            Total size assigned to each GPU

    Raises:
        ValueError: If `num_gpus` is less than 1.

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus = 2
        get_load_balance_assignment(sizes, num_gpus)
        # -> ([0, 2, 1, 3], [1, 3], [1000, 350])
        # GPU_0 gets image 0 (1000); GPU_1 gets images 2, 1, 3 (350).
        ```
    """
    if num_gpus < 1:
        # Without this, min() below fails with an unhelpful error.
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")

    n_samples = len(sizes)

    # Handle edge cases
    if n_samples == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    # Use greedy algorithm - balance by total size, not sample count
    gpu_assignments: list[list[int]] = [[] for _ in range(num_gpus)]
    gpu_loads = [0] * num_gpus  # This tracks total SIZE, not sample count

    # Sort indices by size (largest first for better load balancing); the sort
    # is stable, so equal sizes keep their original order.
    # sizes = [1000, 100, 200, 50] -> large_to_small_indices = [0, 2, 1, 3]
    large_to_small_indices = sorted(
        range(n_samples), key=lambda i: sizes[i], reverse=True
    )

    for idx in large_to_small_indices:
        # Least-loaded GPU; min() returns the first minimum, i.e. the lowest
        # GPU index among ties.
        min_gpu = min(range(num_gpus), key=gpu_loads.__getitem__)
        gpu_assignments[min_gpu].append(idx)
        gpu_loads[min_gpu] += sizes[idx]

    # GPU_0 = [0], GPU_1 = [2, 1, 3] -> shuffle_indices = [0, 2, 1, 3]
    shuffle_indices = [idx for assigned in gpu_assignments for idx in assigned]
    # -> gpu_sample_counts = [1, 3]
    gpu_sample_counts = [len(assigned) for assigned in gpu_assignments]

    return shuffle_indices, gpu_sample_counts, gpu_loads


def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
    """Run a vision model with data parallelism (DP) sharding.

    Whole images are distributed across the tensor-parallel ranks so that each
    rank receives a similar number of patches (see
    `get_load_balance_assignment`). Each rank runs the vision model on its own
    images. The per-rank outputs are padded to a common length and
    all-gathered, then split back into per-image embeddings in the original
    order. This function is used to run the vision model with mrope.

    Args:
        vision_model (torch.nn.Module): Vision model. For "rope_2d" it must
            expose `merge_kernel_size` (and `config.hidden_size` for ranks
            without images); otherwise `spatial_merge_size` and
            `out_hidden_size`.
        pixel_values (torch.Tensor): Image/Video patches of all images,
            concatenated along dim 0 in the order of `grid_thw_list`.
        grid_thw_list: List of grid dimensions for each image; the product of
            an entry is the number of patches (rows of `pixel_values`) that
            image contributes.
        rope_type: Type of rope used in the vision model.
                   Different rope types have different dimension to do ViT.
                   "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
                   "rope_2d" for 2D rope (e.g., Kimi-VL)
    Returns:
        tuple[torch.Tensor, ...]: One embedding tensor per image, in the same
        order as `grid_thw_list`. Image i has
        `prod(grid_thw_list[i]) // embed_dim_reduction_factor` rows, where the
        factor is the number of patches merged into one output row.

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
        tp_size = 2
        # patches_per_image = [1000, 100, 200, 50]
        # GPU_0 runs image 0; GPU_1 runs images 2, 1, 3.
        # Result: 4 tensors, image i having patches_per_image[i] // 4 rows.
        ```
    """
    tp_size = get_tensor_model_parallel_world_size()

    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # Row offsets of each image inside pixel_values:
    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata
    # image_to_tp_rank = [0, 2, 1, 3]  (image indices grouped by rank)
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = (
        get_load_balance_assignment(patches_per_image, tp_size)
    )

    # cum_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[
        cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1]
    ]

    # Get the pixel values for the local images based on the image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat(
            [
                pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]]
                for i in image_idxs_local
            ]
        )
    else:
        # Handle case where this rank has no images
        pixel_values_local = torch.empty(
            (0, pixel_values.shape[1]),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

    # Number of input patches merged into one output row, e.g. 2 * 2.
    if rope_type == "rope_2d":
        embed_dim_reduction_factor = (
            vision_model.merge_kernel_size[0] * vision_model.merge_kernel_size[1]
        )
    else:
        embed_dim_reduction_factor = (
            vision_model.spatial_merge_size * vision_model.spatial_merge_size
        )

    # Find the max length across all ranks
    # The output embedding of every DP rank has to be
    # padded to this length for tensor_model_parallel_all_gather
    # to work
    max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if rope_type == "rope_2d":
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(
                pixel_values_local, torch.tensor(local_grid_thw_list)
            )
            if isinstance(image_embeds_local, list):
                image_embeds_local = torch.cat(image_embeds_local, dim=0)
        else:
            # NOTE(review): if `vision_model.config` lacks `hidden_size`,
            # out_dim is None and torch.empty below raises.
            out_dim = getattr(vision_model.config, "hidden_size", None)
            image_embeds_local = torch.empty(
                (0, embed_dim_reduction_factor, out_dim),
                device=pixel_values.device,
                dtype=pixel_values.dtype,
            )
    elif pixel_values_local.shape[0] > 0:
        image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list)
    else:
        # Handle empty case
        image_embeds_local = torch.empty(
            (0, vision_model.out_hidden_size),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

    # Pad the output along dim 0 up to max_len_per_rank for
    # tensor_model_parallel_all_gather to work. Padding rows are left
    # uninitialized and discarded after the gather. Trailing dims are copied
    # from the local output, so both the 2D ("rope_3d") and 3D ("rope_2d")
    # layouts are handled.
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding = image_embeds_local.new_empty(
            (max_len_per_rank - current_len, *image_embeds_local.shape[1:])
        )
        image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0)
    else:
        image_embeds_local_padded = image_embeds_local

    # Do all_gather to collect embeddings from all ranks; rank r occupies rows
    # [r * max_len_per_rank, (r + 1) * max_len_per_rank).
    gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0)

    patches_per_output_image = [
        patch_size // embed_dim_reduction_factor for patch_size in patches_per_image
    ]

    # Remove each rank's padding and split its rows back into per-image
    # embeddings, stored at each image's original position.
    original_order_embeddings: list[torch.Tensor | None] = [None] * len(
        grid_thw_list
    )
    for rank in range(tp_size):
        rank_start = rank * max_len_per_rank
        rank_end = rank_start + (
            grouped_pixel_values_len[rank] // embed_dim_reduction_factor
        )
        rank_embed = gathered_embeds[rank_start:rank_end]

        # Images of this rank in shuffled order,
        # e.g. GPU_0 -> [0], GPU_1 -> [2, 1, 3]
        rank_images = image_to_tp_rank[
            cum_gpu_sample_counts[rank] : cum_gpu_sample_counts[rank + 1]
        ]
        embed_start = 0
        for img_idx in rank_images:
            img_patches = patches_per_output_image[img_idx]
            original_order_embeddings[img_idx] = rank_embed[
                embed_start : embed_start + img_patches
            ]
            embed_start += img_patches

    out_embeddings = tuple(
        embed for embed in original_order_embeddings if embed is not None
    )
    assert len(out_embeddings) == len(original_order_embeddings), (
        "Found unassigned embeddings"
    )
    return out_embeddings
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608


def get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
    """Build (t, h, w) position ids for one visual item.

    The item's grid is reduced by `spatial_merge_size` along height and width;
    for every entry of `t_index` and every merged (h, w) cell, one position
    triple is emitted, in (t, h, w) row-major order, all offset by
    `start_idx`.

    Args:
        start_idx: Offset added to every position id.
        vision_idx: Index of the item in `grid_hs` / `grid_ws`.
        spatial_merge_size: Factor dividing the grid height and width.
        t_index: Temporal position of each frame (values are cast to int64).
        grid_hs: Grid heights of all items.
        grid_ws: Grid widths of all items.

    Returns:
        An int64 tensor of shape ``(3, len(t_index) * llm_grid_h * llm_grid_w)``
        holding the t, h and w ids in rows 0, 1 and 2.
    """
    # All index tensors are built on the grid's device so torch.stack does not
    # mix devices (previously h/w used the default device while t used the
    # grid's device).
    device = grid_hs.device
    llm_grid_h = int(grid_hs[vision_idx]) // spatial_merge_size
    llm_grid_w = int(grid_ws[vision_idx]) // spatial_merge_size
    num_frames = len(t_index)

    h_index = (
        torch.arange(llm_grid_h, device=device)
        .view(1, -1, 1)
        .expand(num_frames, -1, llm_grid_w)
        .flatten()
    )
    w_index = (
        torch.arange(llm_grid_w, device=device)
        .view(1, 1, -1)
        .expand(num_frames, llm_grid_h, -1)
        .flatten()
    )
    # as_tensor keeps integer inputs as int64; the old torch.Tensor(...) path
    # went through float32, corrupting temporal ids above 2**24. Float inputs
    # still truncate via .long() as before.
    t_index_tensor = (
        torch.as_tensor(t_index, device=device)
        .long()
        .view(-1, 1)
        .expand(-1, llm_grid_h * llm_grid_w)
        .flatten()
    )
    return torch.stack([t_index_tensor, h_index, w_index]) + start_idx