# aimv2.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

# A modified implementation of the AIMv2 Transformer.
# The image tokenizer used by Ovis2 is also included here.
6
from collections.abc import Iterable
7
8
9
from typing import Optional

import torch
10
import torch.nn as nn
11

12
13
14
15
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
16
from vllm.model_executor.layers.layernorm import RMSNorm
17
18
19
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
21
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23
from vllm.transformers_utils.configs.ovis import AIMv2Config
24
25
26
27
28
29
30
31
32
33
34


class AIMv2SwiGLUFFN(nn.Module):
    """Gated feed-forward block built from tensor-parallel linear layers.

    ``fc13`` is a single merged column-parallel projection with two output
    slices of ``config.intermediate_size`` each. The ``SiluAndMul``
    activation is applied to that merged output, and the row-parallel
    ``fc2`` projects the result back to ``config.hidden_size``.

    Args:
        config: Supplies ``intermediate_size``, ``hidden_size`` and
            ``use_bias``.
        quant_config: Forwarded unchanged to both linear layers.
        prefix: Dotted name of this module; the sub-layers are created with
            the prefixes ``{prefix}.fc13`` and ``{prefix}.fc2``.
    """

    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                 prefix: str):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.hidden_size
        bias = config.use_bias

        # Two equally sized output slices fused into one projection.
        self.fc13 = MergedColumnParallelLinear(
            in_features,
            [hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
        )
        self.fc2 = RowParallelLinear(
            input_size=hidden_features,
            output_size=in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc13`` -> ``act_fn`` -> ``fc2`` and return the result.

        Both linear layers return a 2-tuple; only the first element (the
        output tensor) is used here.
        """
        x, _ = self.fc13(x)
        x = self.act_fn(x)
        x, _ = self.fc2(x)
        return x
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98


class AIMv2PatchEmbed(nn.Module):
    """Turn an image batch into a normalized sequence of patch embeddings.

    A strided ``Conv2d`` (kernel and stride both equal to
    ``config.patch_size``) maps ``config.num_channels`` input channels to
    ``config.hidden_size`` features; the result is flattened into a token
    sequence and passed through ``RMSNorm.forward_native``.
    """

    def __init__(self, config: AIMv2Config):
        super().__init__()
        # Kernel and stride share the same square size.
        patch_hw = (config.patch_size, config.patch_size)
        self.proj = nn.Conv2d(config.num_channels,
                              config.hidden_size,
                              kernel_size=patch_hw,
                              stride=patch_hw)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, flatten from dim 2 onward, swap dims 1 and 2, normalize."""
        feature_map = self.proj(x)
        # Collapse everything from dim 2 on, then move that axis to dim 1.
        patch_tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm.forward_native(patch_tokens)


class AIMv2ViTPreprocessor(nn.Module):
    """Patchify the input and add a learned positional embedding.

    The positional table has shape ``(1, num_patches, hidden_size)`` where
    ``num_patches = (image_size // patch_size) ** 2`` and is initialized to
    zeros.
    """

    def __init__(self, config: AIMv2Config):
        super().__init__()
        patches_per_side = config.image_size // config.patch_size
        table_shape = (1, patches_per_side**2, config.hidden_size)

        self.patchifier = AIMv2PatchEmbed(config)
        self.pos_embed = nn.Parameter(torch.zeros(table_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``patchifier(x)`` plus the first ``N`` positional rows.

        ``N`` is the size of dim 1 of the patchifier output, which must be a
        3-D tensor.
        """
        patch_tokens = self.patchifier(x)
        _, num_tokens, _ = patch_tokens.shape
        # Move the table to the tokens' device before slicing and adding.
        positions = self.pos_embed.to(patch_tokens.device)[:, :num_tokens]
        return patch_tokens + positions


class AIMv2Attention(nn.Module):
    """Multi-head self-attention with tensor-parallel projections.

    A fused ``QKVParallelLinear`` produces queries, keys and values, which
    are fed to ``MultiHeadAttention``; a ``RowParallelLinear`` then projects
    the attention output.

    Args:
        config: Supplies ``hidden_size``, ``num_attention_heads``,
            ``qkv_bias`` and ``use_bias``.
        quant_config: Forwarded unchanged to both linear layers.
        prefix: Dotted name of this module; the sub-layers are created with
            the prefixes ``{prefix}.qkv`` and ``{prefix}.proj``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                 prefix: str):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        # Scaling factor 1 / sqrt(head_dim), handed to the attention layer.
        self.scale = self.head_dim**-0.5

        self.qkv = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        # Each tensor-parallel rank gets num_heads / tp_size heads.
        # NOTE(review): `divide` presumably rejects a non-exact division --
        # confirm against vllm.distributed.utils.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the QKV projection, attention and output projection to ``x``.

        The linear layers return a 2-tuple; only the first element (the
        output tensor) is used.
        """
        qkv, _ = self.qkv(x)
        # Split the fused projection into three chunks along the last dim.
        q, k, v = qkv.chunk(3, dim=-1)

        x = self.attn(q, k, v)
        x, _ = self.proj(x)
        return x


class AIMv2Block(nn.Module):
    """One pre-norm transformer block: attention and MLP, each residual.

    Args:
        config: Supplies ``hidden_size`` and ``rms_norm_eps`` for the two
            norms, and is passed on to the attention and MLP sub-modules.
        quant_config: Forwarded unchanged to the attention and MLP
            sub-modules.
        prefix: Dotted name of this module; the sub-modules are created with
            the prefixes ``{prefix}.attn`` and ``{prefix}.mlp``.
    """

    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                 prefix: str):
        super().__init__()
        self.attn = AIMv2Attention(config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.attn")
        self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = AIMv2SwiGLUFFN(config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.mlp")
        self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual branches.

        Each branch normalizes its input via ``RMSNorm.forward_native``
        before the sub-module runs, then adds the result back onto ``x``.
        """
        x = x + self.attn(self.norm_1.forward_native(x))
        x = x + self.mlp(self.norm_2.forward_native(x))
        return x


class AIMv2Transformer(nn.Module):
    """Stack of ``AIMv2Block`` layers with an optional final RMSNorm.

    Args:
        config: Supplies ``num_hidden_layers`` (number of blocks) plus
            ``hidden_size`` and ``rms_norm_eps`` for the optional norm.
        quant_config: Forwarded unchanged to every block.
        require_post_norm: When truthy, a trailing ``RMSNorm`` is created
            and applied after the last block. ``None`` (the default) and
            ``False`` both disable it, leaving ``post_trunk_norm`` as
            ``None``.
        prefix: Dotted name of this module; block ``i`` is created with the
            prefix ``{prefix}.blocks.{i}``.
    """

    def __init__(
        self,
        config: AIMv2Config,
        quant_config: QuantizationConfig,
        *,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.blocks = nn.ModuleList([
            AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
            for i in range(config.num_hidden_layers)
        ])
        if require_post_norm:
            self.post_trunk_norm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Run ``tokens`` through every block, then the optional final norm."""
        # Only the output of the last block is returned.
        # NOTE(review): the earlier comment here likened this to a "clip
        # skip" on layer -1 -- confirm the intent against the reference
        # implementation.
        for block in self.blocks:
            tokens = block(tokens)
        if self.post_trunk_norm is not None:
            tokens = self.post_trunk_norm(tokens)
        return tokens


class AIMv2Model(torch.nn.Module):
    """AIMv2 vision model: a patch preprocessor followed by the trunk.

    Args:
        config: Passed to the preprocessor and the trunk.
        quant_config: Forwarded unchanged to the trunk.
        require_post_norm: Forwarded to ``AIMv2Transformer``; controls
            whether the trunk creates its final ``post_trunk_norm``.
        prefix: Dotted name of this module; the trunk is created with the
            prefix ``{prefix}.trunk``.
    """

    def __init__(self,
                 config: AIMv2Config,
                 quant_config: QuantizationConfig,
                 *,
                 require_post_norm: Optional[bool] = None,
                 prefix: str = ""):
        super().__init__()
        self.preprocessor = AIMv2ViTPreprocessor(config)
        self.trunk = AIMv2Transformer(config,
                                      quant_config=quant_config,
                                      require_post_norm=require_post_norm,
                                      prefix=f"{prefix}.trunk")

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return the trunk output for ``pixel_values``."""
        x = self.preprocessor(pixel_values)
        x = self.trunk(x)
        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``.fc1`` / ``.fc3`` are redirected to
        the merged ``.fc13`` parameter as shards 0 and 1 respectively. All
        other names are loaded through the parameter's own
        ``weight_loader`` if it has one, else ``default_weight_loader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The parameter names that were loaded, after any renaming.

        Raises:
            KeyError: If a (possibly renamed) name matches no parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".fc13", ".fc1", 0),
            (".fc13", ".fc3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # post_trunk_norm is optional in this model: skip its weights
            # when the trunk was built without it.
            if (name.startswith("trunk.post_trunk_norm")
                    and self.trunk.post_trunk_norm is None):
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load as a plain parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params