# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Yi-VL model."""

from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.llava import LlavaLlamaForCausalLM


class YiVLForCausalLM(LlavaLlamaForCausalLM):
    """Yi-VL multimodal causal LM built on the LLaVA-Llama implementation.

    Differences from the parent class visible in this file:
      * the projector is ``YiVLMultiModalProjector`` (Linear/LayerNorm/GELU/
        Linear/LayerNorm);
      * the CLIP vision tower is loaded from a subfolder of the main checkpoint
        directory during ``load_weights``.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        """Build the model.

        Args:
            config: LLaVA-style config; ``mm_vision_tower`` must be a string
                path (relative to the checkpoint directory).
            quant_config: forwarded unchanged to the parent constructor.
            cache_config: forwarded unchanged to the parent constructor.
        """
        super().__init__(config, quant_config, cache_config)

        # Install the Yi-VL projector (overrides any projector presumably
        # created by the parent constructor).
        self.multi_modal_projector = YiVLMultiModalProjector(self.config)

        # Keep everything after a leading "./" so the path can be passed as a
        # `subfolder` to `from_pretrained`. Only the prefix is stripped;
        # "./" sequences elsewhere in the path are left untouched.
        vision_tower_path = self.config.mm_vision_tower
        if vision_tower_path.startswith("./"):
            vision_tower_path = vision_tower_path[len("./") :]
        self.vision_tower_subfolder = vision_tower_path

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Create the vision tower, then load projector/vision/LM weights.

        The CLIP vision tower is instantiated here from the
        ``vision_tower_subfolder`` of ``config._name_or_path`` (float16, moved
        to "cuda", set to eval mode). Image-geometry attributes are derived
        from it. Checkpoint tensors whose names contain ``projector`` or
        ``vision_tower`` are renamed to this model's parameter names and loaded.
        Finally the full weight list is forwarded to
        ``self.language_model.load_weights``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint. It may be a
                one-shot iterator; it is materialized into a list because it is
                traversed twice.

        Raises:
            ValueError: if ``config.mm_vision_select_feature`` is neither
                ``"patch"`` nor ``"cls_patch"``.
            KeyError: if a projector/vision-tower tensor name does not map onto
                a parameter of this model.
        """
        # We have to use the subfolder of the main model directory
        # (e.g. 01-ai/Yi-VL-6B).
        self.vision_tower = CLIPVisionModel.from_pretrained(
            self.config._name_or_path,
            torch_dtype=torch.float16,
            subfolder=self.vision_tower_subfolder,
        ).to("cuda")
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        # Patch tokens per image. Integer floor division mirrors how HF CLIP
        # computes its patch count (verify against transformers'
        # CLIPVisionEmbeddings) and avoids float rounding.
        self.image_feature_len = (self.image_size // self.patch_size) ** 2
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            # "cls_patch" presumably keeps the CLS token alongside the patch
            # tokens, hence one extra feature.
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        # TODO: support TP?
        # Checkpoint-name fragment -> parameter-name fragment in this model.
        # mm_projector index 2 has no entry: in YiVLMultiModalProjector the
        # third layer is the parameter-free GELU.
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.1": "multi_modal_projector.ln_1",
            "model.mm_projector.3": "multi_modal_projector.linear_2",
            "model.mm_projector.4": "multi_modal_projector.ln_2",
            # Update the vision tower weights if we find them in the
            # checkpoint (it may be finetuned).
            "model.vision_tower.vision_tower": "vision_tower",
        }
        # Collected after the vision tower is attached, so its parameters are
        # valid targets for the renamed checkpoint tensors.
        params_dict = dict(self.named_parameters())

        # Materialize: consumed here and again by the language model below.
        weights = list(weights)
        for name, loaded_weight in weights:
            if "projector" not in name and "vision_tower" not in name:
                continue
            for weight_name, param_name in projector_weights.items():
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        # Load the language model. All tensors, including those handled above,
        # are forwarded. NOTE(review): this relies on the language model's
        # loader skipping names it does not own; confirm.
        self.language_model.load_weights(weights)


class YiVLMultiModalProjector(nn.Module):
    """Map vision-tower features into the language model's hidden space.

    Layer order: Linear -> LayerNorm -> GELU -> Linear -> LayerNorm.
    The attribute names ``linear_1``, ``ln_1``, ``linear_2`` and ``ln_2`` are
    the targets that ``YiVLForCausalLM.load_weights`` maps the checkpoint's
    ``model.mm_projector.{0,1,3,4}`` tensors onto, so they must not change.
    """

    def __init__(self, config: "LlavaConfig"):
        """Build the projector layers.

        Args:
            config: only ``config.vision_config.hidden_size`` (input width) and
                ``config.text_config.hidden_size`` (output width) are read.
        """
        super().__init__()

        text_hidden_size = config.text_config.hidden_size
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, text_hidden_size)
        self.ln_1 = nn.LayerNorm(text_hidden_size)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size)
        self.ln_2 = nn.LayerNorm(text_hidden_size)

    def forward(self, image_features):
        """Project ``image_features`` through the projector layers.

        The last dimension of ``image_features`` must equal the vision hidden
        size. The result keeps the leading dimensions, and its last dimension
        is the text hidden size.
        """
        hidden_states = self.linear_1(image_features)
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_2(hidden_states)
        return hidden_states


# Model class exported by this module under the name `EntryClass`.
# NOTE(review): presumably discovered by SGLang's model registry to resolve
# the Yi-VL architecture; confirm against the model loader.
EntryClass = YiVLForCausalLM