model.py 3.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
17
from typing import Any, Dict, List, Optional, Tuple
18
19

import torch
20
from transformers import AutoConfig, AutoModel
21

22
logger = logging.getLogger(__name__)
23
24


25
26
27
28
29
30
class SupportedModels:
    """Namespace of model ID strings for the multimodal models handled here.

    Each attribute is a plain ``str``, so it can be compared directly against
    a ``model_id`` argument with ``==``.
    """

    # LLaVA 1.5, 7B checkpoint.
    LLAVA_1_5_7B: str = "llava-hf/llava-1.5-7b-hf"
    # Qwen2.5-VL, 7B instruct checkpoint.
    QWEN_2_5_VL_7B: str = "Qwen/Qwen2.5-VL-7B-Instruct"
    # LLaVA-NeXT-Video, 7B checkpoint.
    LLAVA_NEXT_VIDEO_7B: str = "llava-hf/LLaVA-NeXT-Video-7B-hf"
31
32
33
34
35
36


def load_vision_model(model_id: str) -> torch.nn.Module:
    """Instantiate a pretrained model for the given HuggingFace model ID.

    Args:
        model_id: Identifier passed straight to ``AutoModel.from_pretrained``.

    Returns:
        The object returned by ``AutoModel.from_pretrained``, requested in
        float16 with ``device_map="auto"``.
    """
    # NOTE(review): trust_remote_code=True lets the model repository supply
    # its own code -- confirm that only trusted model IDs reach this function.
    load_options = {
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
    }
    return AutoModel.from_pretrained(model_id, **load_options)
41
42
43


def get_vision_embeddings_info(
    model_id: str,
) -> Tuple[Tuple[int, int, int], torch.dtype]:
    """Calculate vision embeddings size and dtype using model config.

    The sequence length is a fixed per-model constant (577 for LLaVA 1.5 7B,
    345 for Qwen2.5-VL 7B); every other model ID yields a sequence length of 0.

    Args:
        model_id: HuggingFace model ID whose config is loaded via
            ``AutoConfig.from_pretrained``.

    Returns:
        A tuple ``((batch_size, seq_len, hidden_dim), dtype)``. ``batch_size``
        is always 1.

    Raises:
        ValueError: If the config has no usable ``torch_dtype`` (the attribute
            is absent or set to ``None``).
    """
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    if model_id == SupportedModels.LLAVA_1_5_7B:
        seq_len = 577
    elif model_id == SupportedModels.QWEN_2_5_VL_7B:
        seq_len = 345
    else:
        seq_len = 0

    # A bare hasattr() check is not enough: the attribute may exist but hold
    # None, which would otherwise be returned as the "dtype".
    dtype = getattr(config, "torch_dtype", None)
    if dtype is None:
        raise ValueError("Model config missing required 'torch_dtype' attribute")

    hidden_size = getattr(config, "hidden_size", None)
    if hidden_size is None:
        # Composite multimodal configs may keep the language-model width on a
        # nested ``text_config``; prefer that over a hard-coded guess.
        text_config = getattr(config, "text_config", None)
        hidden_size = getattr(text_config, "hidden_size", None)
    if hidden_size is None:
        logger.warning(
            "Model config missing required 'hidden_size' attribute, using 4096"
        )
        hidden_size = 4096

    return (1, seq_len, hidden_size), dtype
68
69
70
71
72
73


def construct_mm_data(
    model: str,
    image_embeds: torch.Tensor,
    embeddings_dtype: torch.dtype,
    image_grid_thw: Optional[List[Any]],
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
    """Build the multimodal payload of a vLLM request from image embeddings.

    Args:
        model: Model ID; only the Qwen2.5-VL 7B ID gets the extended payload.
        image_embeds: Embeddings tensor, converted to ``embeddings_dtype``.
        embeddings_dtype: Target dtype for the embeddings.
        image_grid_thw: Grid description, required (non-empty) for
            Qwen2.5-VL and ignored for every other model.

    Returns:
        ``{"image": <embeddings>}`` for most models. For Qwen2.5-VL, the
        ``"image"`` entry is a dict holding the embeddings with dimension 0
        squeezed plus the grid as a tensor.

    Raises:
        ValueError: If the model is Qwen2.5-VL and ``image_grid_thw`` is
            ``None`` or empty.
    """
    converted = image_embeds.to(embeddings_dtype)

    # Every model except Qwen2.5-VL takes the embeddings tensor as-is.
    if model != SupportedModels.QWEN_2_5_VL_7B:
        return {"image": converted}

    if image_grid_thw is None or len(image_grid_thw) == 0:
        raise ValueError("No image grid provided.")

    grid = torch.tensor(image_grid_thw)
    payload = {
        "image_embeds": converted.squeeze(0),
        "image_grid_thw": grid,
    }
    return {"image": payload}