encode_utils.py 4.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict, Optional

import torch

from .model import SupportedModels

logger = logging.getLogger(__name__)


def get_qwen_image_features(
    vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
    """
    Extract image features using Qwen-style vision encoder.

    Args:
        vision_encoder: The vision encoder model
        image_embeds: Dictionary containing pixel values and grid information

    Returns:
        Processed image features tensor

    Raises:
        ValueError: If grid_thw is not provided for Qwen model
    """
    pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)

    grid_thw = image_embeds.get("image_grid_thw", None)
    if grid_thw is not None:
        grid_thw = grid_thw.to(vision_encoder.device)
        logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
    else:
        raise ValueError("grid_thw is not provided")

    return (
        vision_encoder.get_image_features(pixel_values, grid_thw)  # type: ignore
        if grid_thw is not None
        else vision_encoder.get_image_features(pixel_values)  # type: ignore
    )


def encode_image_embeddings(
    model_name: str,
    image_embeds: Dict[str, Any],
    vision_encoder: torch.nn.Module,
    projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    """
    Encode image embeddings using the appropriate model-specific encoder.

    Args:
        model_name: The model identifier
        image_embeds: Dictionary containing processed image data
        vision_encoder: The vision encoder module
        projector: The multimodal projector (required for LLaVA-style models)

    Returns:
        Encoded embeddings tensor with normalized shape

    Raises:
        ValueError: If projector is missing for LLaVA models
        NotImplementedError: If model is not supported
    """
    with torch.no_grad():
        # Route through the correct encoder based on model
        if model_name == SupportedModels.LLAVA_1_5_7B:
            pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
            vision_outputs = vision_encoder(pixel_values)

            if projector is None:
                raise ValueError(f"Projector not found for LLaVA model: {model_name}")

            embeddings = projector(vision_outputs.last_hidden_state)

        elif model_name == SupportedModels.QWEN_2_5_VL_7B:
            embeddings = get_qwen_image_features(vision_encoder, image_embeds)

        else:
            raise NotImplementedError(f"Model not supported: {model_name}")

        # Normalize output shape
        if isinstance(embeddings, (tuple, list)):
            embeddings = embeddings[0]
        embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings

        return embeddings


def get_encoder_components(
    model_name: str, vision_model: torch.nn.Module
) -> tuple[Any, Optional[Any]]:
    """
    Get the appropriate vision encoder and projector components for a given model.

    Args:
        model_name: The model identifier
        vision_model: The loaded vision model

    Returns:
        Tuple of (vision_encoder, projector) where types depend on the model

    Raises:
        NotImplementedError: If model is not supported
    """
    if model_name == SupportedModels.LLAVA_1_5_7B:
        vision_encoder = vision_model.vision_tower
        projector = getattr(vision_model, "multi_modal_projector", None)
        return vision_encoder, projector

    elif model_name == SupportedModels.QWEN_2_5_VL_7B:
        vision_encoder = vision_model
        projector = None
        return vision_encoder, projector

    else:
        raise NotImplementedError(f"Model not supported: {model_name}")