vision.py 18.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import itertools
import math
6
from abc import ABC, abstractmethod
7
8
from collections.abc import Callable
from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
9

10
import torch
11
12
from transformers import PretrainedConfig

13
from vllm.attention.backends.registry import AttentionBackendEnum
14
from vllm.config import VllmConfig, get_current_vllm_config
15
16
17
18
19
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
20
from vllm.logger import init_logger
21
from vllm.platforms import current_platform
22
23

logger = init_logger(__name__)
24

25
26
27
# Type variable for the vision config type; bounded by `PretrainedConfig`.
_C = TypeVar("_C", bound=PretrainedConfig)


28
29
30
31
class _RootConfig(Protocol[_C]):
    """Structural type for a root config exposing a `vision_config`."""

    # Nested vision config, typed by the `_C` type variable.
    vision_config: _C


32
class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface for querying size-related facts of a vision encoder.

    Instances keep both the root HF config and its nested `vision_config`.
    Every query below is abstract and must be supplied by a subclass.
    """

    def __init__(self, hf_config: _RootConfig[_C]) -> None:
        """Store the root config and a shortcut to its `vision_config`.

        Args:
            hf_config: Root config object that exposes `vision_config`.
        """
        super().__init__()

        self.hf_config = hf_config
        self.vision_config = hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for the given image size.

        Args:
            image_width: Width of the image (keyword-only).
            image_height: Height of the image (keyword-only).
        """
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size (exact meaning is subclass-defined)."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size (exact meaning is subclass-defined)."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the patch grid length.

        NOTE(review): presumably the number of patches along one image side;
        confirm against the concrete subclasses.
        """
        raise NotImplementedError


61
62
63
64
class VisionLanguageConfig(Protocol):
    """Structural type for configs accepted by `get_vision_encoder_info`."""

    # Vision sub-config; `get_vision_encoder_info` dispatches on its class.
    vision_config: Final[PretrainedConfig]


65
def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Build the `VisionEncoderInfo` matching `hf_config.vision_config`.

    Args:
        hf_config: Config object exposing a `vision_config` attribute.

    Returns:
        The encoder info for the first supported vision config class that
        `hf_config.vision_config` is an instance of.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, Pixtral or
            SigLIP vision config.
    """
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    # Checked in order; the first matching config class wins.
    info_cls_by_config_cls = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )

    vision_config = hf_config.vision_config
    for config_cls, info_cls in info_cls_by_config_cls:
        if isinstance(vision_config, config_cls):
            return info_cls(hf_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
80
81


82
83
84
85
def get_vit_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    *,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> AttentionBackendEnum:
    """
    Pick the attention backend to use for a Vision Transformer.

    Resolution order:
        1. `attn_backend_override`, when given.
        2. The backend set in the current vLLM config's attention config.
        3. Whatever the current platform returns for `head_size` and `dtype`.

    Args:
        head_size: Attention head size, forwarded to the platform query.
        dtype: Tensor dtype, forwarded to the platform query.
        attn_backend_override: Backend that takes priority over everything
            else when it is not None.

    Returns:
        The selected attention backend.
    """
    if attn_backend_override is None:
        configured_backend = get_current_vllm_config().attention_config.backend
        if configured_backend is None:
            # Nothing was requested explicitly: defer to the platform.
            return current_platform.get_vit_attn_backend(head_size, dtype)
        return configured_backend

    return attn_backend_override
99
100


Harry Mellor's avatar
Harry Mellor committed
101
102
103
104
105
def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
    """Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Args:
        vllm_config: The vLLM config to inspect.

    Returns:
        The value of `vllm_config.compilation_config.compile_mm_encoder`.
    """
    return vllm_config.compilation_config.compile_mm_encoder


106
107
# Names of the built-in vision feature selection strategies.
VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]

108
109
110
# A strategy is either one of the built-in names, or a custom callable that
# maps a feature tensor to the selected feature tensor.
VisionFeatureSelectStrategy: TypeAlias = (
    VisionFeatureSelectStrategyStr | Callable[[torch.Tensor], torch.Tensor]
)
111
112
113


def _get_vision_feature_selector(
114
    strategy: VisionFeatureSelectStrategy | str,
115
116
117
118
119
120
) -> Callable[[torch.Tensor], torch.Tensor]:
    if callable(strategy):
        return strategy

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":
121
        return lambda feats: feats[:, :1, :]
122
123
124
125
126
127
128
129

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":
        return lambda feats: feats[:, 1:, :]

    if strategy == "full":
        return lambda feats: feats

130
131
132
133
134
    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")


def get_num_selected_vision_tokens(
    num_vision_tokens: int,
135
    strategy: VisionFeatureSelectStrategy | str,
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
) -> int:
    if callable(strategy):
        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
        dummy_selected_features = strategy(dummy_features)
        return dummy_selected_features.shape[1]

    if strategy == "class":
        return 1

    if strategy == "default":
        return num_vision_tokens - 1

    if strategy == "full":
        return num_vision_tokens

    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
152
153


154
def resolve_visual_encoder_outputs(
    encoder_outputs: torch.Tensor | list[torch.Tensor],
    post_layer_norm: torch.nn.LayerNorm | None,
    *,
    select_layers: list[int] | None = None,
    max_possible_layers: int | None = None,
    feature_select_strategy: VisionFeatureSelectStrategy | None = None,
) -> torch.Tensor:
    """Collapse the outputs of a visual encoder into a single tensor.

    The encoder output is either the last layer's output (a tensor) or a
    list of hidden states. Feature selection and the post layer norm are
    applied as needed, and selected hidden states are concatenated along
    the last dimension.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        post_layer_norm: Post norm to apply to the output of the encoder.
        select_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        max_possible_layers: Total layers in the fully loaded visual encoder.
        feature_select_strategy: Defines how to select the hidden states
            from each layer.

    Returns:
        The resolved encoder output tensor.

    Raises:
        ValueError: If `select_layers` is None but `encoder_outputs` is not
            a tensor, or if `select_layers` is given without
            `max_possible_layers`.
    """
    if select_layers is not None:
        if max_possible_layers is None:
            raise ValueError(
                "`max_possible_layers` must be provided alongside `select_layers`"
            )

        # Negative indices are relative to the FULL visual encoder, so they
        # are shifted by the number of layers that were not loaded.
        # NOTE: this assumes that encoder_outputs is a list containing
        # the inputs to the visual encoder, followed by the hidden states
        # of each layer.
        loaded_layer_count = len(encoder_outputs) - 1
        negative_shift = max_possible_layers - loaded_layer_count

        picked_states = []
        for layer_idx in select_layers:
            position = layer_idx if layer_idx >= 0 else layer_idx + negative_shift
            picked_states.append(encoder_outputs[position])

        if feature_select_strategy is not None:
            selector = _get_vision_feature_selector(feature_select_strategy)
            picked_states = [selector(state) for state in picked_states]

        # The post-norm only applies when the last requested entry is the
        # final hidden state.
        final_layer_selected = select_layers[-1] in (max_possible_layers - 1, -1)
        if final_layer_selected and post_layer_norm is not None:
            picked_states[-1] = post_layer_norm(picked_states[-1])

        return torch.cat(picked_states, dim=-1)

    # No layer selection: a single output tensor is required.
    if not isinstance(encoder_outputs, torch.Tensor):
        raise ValueError(
            "Expected only a single encoder output when "
            "`select_layers` is not provided"
        )

    resolved = encoder_outputs
    if feature_select_strategy is not None:
        selector = _get_vision_feature_selector(feature_select_strategy)
        resolved = selector(resolved)

    return resolved if post_layer_norm is None else post_layer_norm(resolved)
221
222


223
224
225
226
def run_dp_sharded_vision_model(
    image_input: torch.Tensor, vision_model: torch.nn.Module
) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding.

    The input is padded along dimension 0 so that it splits evenly across
    the tensor-parallel ranks. Each rank runs the model on its own slice,
    the results are all-gathered along dimension 0, and the rows that
    correspond to padding are dropped.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.
    Returns:
        torch.Tensor: Output image embeddings
    """
    world_size = get_tensor_model_parallel_world_size()
    total_chunks = image_input.shape[0]

    # Ceil-divide so that every rank processes the same number of chunks.
    chunks_per_rank = (total_chunks + world_size - 1) // world_size
    extra_chunks = chunks_per_rank * world_size - total_chunks

    # `pad` lists (before, after) pairs starting from the LAST dimension, so
    # every trailing dimension gets (0, 0) and only the end of dim 0 grows.
    trailing_dims = image_input.dim() - 1
    pad_spec = (0,) * (2 * trailing_dims) + (0, extra_chunks)
    padded_input = torch.nn.functional.pad(image_input, pad_spec)

    rank = get_tensor_model_parallel_rank()
    first_chunk = rank * chunks_per_rank
    last_chunk = (rank + 1) * chunks_per_rank
    local_input = padded_input[first_chunk:last_chunk, ...]

    # Ensure tensor is contiguous before all_gather
    local_embeddings = vision_model(local_input).contiguous()
    gathered = tensor_model_parallel_all_gather(local_embeddings, dim=0)

    # Drop the rows that were produced from padding.
    return gathered[:total_chunks, ...]


def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Greedily assign samples to GPUs so that the total size per GPU is
    balanced. The load is the sum of the sample sizes, not the number of
    samples.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across

    Returns:
        shuffle_indices:
            Sample indices grouped by GPU (GPU 0's samples first, then
            GPU 1's, ...)
        gpu_sample_counts:
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu:
            Total size assigned to each GPU

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus = 2
        # -> shuffle_indices = [0, 2, 1, 3]
        # -> gpu_sample_counts = [1, 3]
        # -> grouped_sizes_per_gpu = [1000, 350]
        ```

    """
    # Nothing to distribute: every GPU gets zero samples and zero load.
    if len(sizes) == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    per_gpu_indices: list[list[int]] = [[] for _ in range(num_gpus)]
    per_gpu_load = [0] * num_gpus  # total SIZE per GPU, not sample count

    # Visit the largest samples first; the sort is stable, so equally sized
    # samples keep their original relative order.
    descending = sorted(range(len(sizes)), key=sizes.__getitem__, reverse=True)

    for sample_idx in descending:
        # Lowest-numbered GPU among those with the smallest current load.
        target_gpu = per_gpu_load.index(min(per_gpu_load))
        per_gpu_indices[target_gpu].append(sample_idx)
        per_gpu_load[target_gpu] += sizes[sample_idx]

    # Flatten the per-GPU buckets in GPU order and record their lengths.
    shuffle_indices = [idx for bucket in per_gpu_indices for idx in bucket]
    gpu_sample_counts = [len(bucket) for bucket in per_gpu_indices]

    return (shuffle_indices, gpu_sample_counts, per_gpu_load)


def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
    """Run an mrope vision model with data parallelism (DP) sharding.

    Images are distributed across the tensor-parallel ranks with
    `get_load_balance_assignment`. Each rank runs the vision model on the
    rows of `pixel_values` that belong to its images. The per-rank outputs
    are padded to a common length, all-gathered, and then split back into
    one embedding per image in the original image order.

    Args:
        vision_model (torch.nn.Module): Vision model.
        pixel_values (torch.Tensor): Image/Video input tensor.
        grid_thw_list: List of grid dimensions for each image
        rope_type: Type of rope used in the vision model.
                   Different rope types have different dimension to do ViT.
                   "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
                   "rope_2d" for 2D rope (e.g., Kimi-VL)
    Returns:
        tuple[torch.Tensor, ...]: One embedding tensor per entry of
        `grid_thw_list`, in the same order.

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
        tp_size = 2
        # rows per image     = [1000, 100, 200, 50]
        # images on GPU_0    = [0]
        # images on GPU_1    = [2, 1, 3]
        # rows per rank      = [1000, 350]
        ```

    """
    is_rope_2d = rope_type == "rope_2d"
    world_size = get_tensor_model_parallel_world_size()
    my_rank = get_tensor_model_parallel_rank()

    # Number of `pixel_values` rows owned by each image, plus prefix sums
    # giving each image's [start, end) row range.
    rows_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    row_bounds = [0, *itertools.accumulate(rows_per_image)]

    # Balance images across ranks by row count.
    assignment_order, images_per_rank, rows_per_rank = get_load_balance_assignment(
        rows_per_image, world_size
    )
    # Prefix sums over the per-rank image counts: rank r owns
    # assignment_order[rank_bounds[r] : rank_bounds[r + 1]].
    rank_bounds = [0, *itertools.accumulate(images_per_rank)]
    my_images = assignment_order[rank_bounds[my_rank] : rank_bounds[my_rank + 1]]

    # Gather this rank's rows of pixel_values.
    if len(my_images) > 0:
        local_pixels = torch.cat(
            [pixel_values[row_bounds[i] : row_bounds[i + 1]] for i in my_images]
        )
    else:
        # This rank has no images.
        local_pixels = torch.empty(
            (0, pixel_values.shape[1]),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

    # Factor by which the model reduces the number of rows.
    if is_rope_2d:
        merge_kernel = vision_model.merge_kernel_size
        merge_factor = merge_kernel[0] * merge_kernel[1]
    else:
        merge_size = vision_model.spatial_merge_size
        merge_factor = merge_size * merge_size

    # Every rank must contribute the same number of rows to the all-gather,
    # so all outputs are padded up to the largest per-rank output length.
    max_rows_out = max(rows_per_rank) // merge_factor

    my_grids = [grid_thw_list[i] for i in my_images]
    has_local_rows = local_pixels.shape[0] > 0

    # Run the vision model locally, or build an empty placeholder.
    if has_local_rows and is_rope_2d:
        local_embeds = vision_model(local_pixels, torch.tensor(my_grids))
        if isinstance(local_embeds, list):
            local_embeds = torch.cat(local_embeds, dim=0)
    elif has_local_rows:
        local_embeds = vision_model(local_pixels, my_grids)
    elif is_rope_2d:
        hidden_size = getattr(vision_model.config, "hidden_size", None)
        local_embeds = torch.empty(
            (0, merge_factor, hidden_size),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )
    else:
        local_embeds = torch.empty(
            (0, vision_model.out_hidden_size),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )

    # Pad along dim 0 up to max_rows_out; the filler is uninitialized
    # because it is sliced away again after the gather.
    missing_rows = max_rows_out - local_embeds.shape[0]
    if missing_rows > 0:
        if is_rope_2d:
            filler_shape = (
                missing_rows,
                local_embeds.shape[1],
                local_embeds.shape[2],
            )
        else:
            filler_shape = (missing_rows, local_embeds.shape[1])
        filler = torch.empty(
            filler_shape,
            dtype=local_embeds.dtype,
            device=local_embeds.device,
        )
        local_embeds = torch.cat([local_embeds, filler], dim=0)

    # Collect the padded embeddings from all ranks.
    gathered = tensor_model_parallel_all_gather(local_embeds, dim=0)

    # Undo the padding and the load-balancing shuffle: walk each rank's
    # block and hand its rows back to the images in their assignment order.
    restored: list[torch.Tensor | None] = [None] * len(grid_thw_list)
    for rank in range(world_size):
        if not images_per_rank[rank] > 0:
            continue

        block_start = rank * max_rows_out
        block_end = block_start + rows_per_rank[rank] // merge_factor
        rank_block = gathered[block_start:block_end]

        cursor = 0
        for image_idx in assignment_order[rank_bounds[rank] : rank_bounds[rank + 1]]:
            rows_out = rows_per_image[image_idx] // merge_factor
            restored[image_idx] = rank_block[cursor : cursor + rows_out]
            cursor += rows_out

    out_embeddings = tuple(embed for embed in restored if embed is not None)
    assert len(out_embeddings) == len(restored), (
        "Found unassigned embeddings"
    )
    return out_embeddings
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549


def get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
    """Build (t, h, w) position ids for one vision item.

    The height and width grids of item `vision_idx` are divided by
    `spatial_merge_size`. For every temporal index in `t_index`, one
    position is emitted per (row, column) cell of that merged grid, in
    row-major order, and `start_idx` is added to all three coordinates.

    Args:
        start_idx: Offset added to every position id.
        vision_idx: Index of the item to read from `grid_hs` / `grid_ws`.
        spatial_merge_size: Divisor applied to the grid height and width.
        t_index: Temporal index for each frame.
        grid_hs: Per-item grid heights.
        grid_ws: Per-item grid widths.

    Returns:
        An int64 tensor of shape `(3, len(t_index) * h * w)` whose rows are
        the temporal, height and width position ids.
    """
    grid_h = grid_hs[vision_idx] // spatial_merge_size
    grid_w = grid_ws[vision_idx] // spatial_merge_size
    # Build all three index tensors on the grid's device so they can be
    # stacked without a device mismatch.
    device = grid_h.device
    num_h, num_w = int(grid_h), int(grid_w)
    num_t = len(t_index)

    # Convert directly to int64: going through a float32 tensor first would
    # silently corrupt indices above 2**24.
    t_pos = (
        torch.as_tensor(t_index, dtype=torch.long, device=device)
        .view(-1, 1)
        .expand(-1, num_h * num_w)
        .flatten()
    )
    h_pos = (
        torch.arange(num_h, device=device)
        .view(1, -1, 1)
        .expand(num_t, -1, num_w)
        .flatten()
    )
    w_pos = (
        torch.arange(num_w, device=device)
        .view(1, 1, -1)
        .expand(num_t, num_h, -1)
        .flatten()
    )

    return torch.stack([t_pos, h_pos, w_pos]) + start_idx