# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# A modified implementation of the AIMv2 Transformer.
# The image tokenizer used by Ovis2 was also inserted here.
from collections.abc import Iterable
7
8

import torch
9
import torch.nn as nn
10

11
12
13
14
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
15
from vllm.model_executor.layers.layernorm import RMSNorm
16
17
18
19
20
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23
from vllm.transformers_utils.configs.ovis import AIMv2Config
24
25
26


class AIMv2SwiGLUFFN(nn.Module):
    """Gated feed-forward network used inside each AIMv2 block.

    Two input projections of width ``config.intermediate_size`` are fused
    into a single ``MergedColumnParallelLinear`` (``fc13``), followed by a
    ``SiluAndMul`` activation and a ``RowParallelLinear`` (``fc2``) that maps
    back to ``config.hidden_size``.
    """

    def __init__(
        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
    ):
        """Build the fused up-projection, the activation and the down-projection.

        Args:
            config: Provides ``intermediate_size``, ``hidden_size`` and
                ``use_bias``.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Dotted name of this module; ``.fc13`` / ``.fc2`` are
                appended for the sub-layers.
        """
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.hidden_size
        bias = config.use_bias

        # One merged layer with two output slices of equal width.
        self.fc13 = MergedColumnParallelLinear(
            in_features,
            [hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
        )
        self.fc2 = RowParallelLinear(
            input_size=hidden_features,
            output_size=in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc13`` -> ``act_fn`` -> ``fc2`` and return the result.

        The linear layers return a 2-tuple; only the first element (the
        output tensor) is used, the second is discarded.
        """
        x, _ = self.fc13(x)
        x = self.act_fn(x)
        x, _ = self.fc2(x)
        return x
56
57
58


class AIMv2PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of normalized patch embeddings."""

    def __init__(self, config: AIMv2Config):
        """Create the patch projection and its RMS norm.

        Args:
            config: Provides ``num_channels``, ``hidden_size``,
                ``patch_size`` and ``rms_norm_eps``.
        """
        super().__init__()
        # Kernel size equals stride, so the convolution windows do not overlap.
        self.proj = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, flatten and normalize.

        The convolution output is flattened from dimension 2 onwards and
        dimensions 1 and 2 are swapped, so the projected channels end up in
        the last dimension before normalization.
        """
        x = self.proj(x).flatten(2).transpose(1, 2)
        # forward_native is called explicitly instead of the module itself.
        # NOTE(review): presumably to bypass custom-op dispatch - confirm.
        x = self.norm.forward_native(x)
        return x


class AIMv2ViTPreprocessor(nn.Module):
    """Patchify the input and add a learned positional embedding."""

    def __init__(self, config: AIMv2Config):
        """Create the patch embedder and the positional-embedding parameter.

        Args:
            config: Provides ``image_size``, ``patch_size`` and
                ``hidden_size`` (plus whatever ``AIMv2PatchEmbed`` reads).
        """
        super().__init__()
        # Number of patches in a square image_size x image_size grid.
        num_patches = (config.image_size // config.patch_size) ** 2

        self.patchifier = AIMv2PatchEmbed(config)
        # Initialized to zeros with shape (1, num_patches, hidden_size).
        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens with the positional embedding added.

        Only the first ``N`` positions of ``pos_embed`` are used, where ``N``
        is the size of dimension 1 of the patchifier output.
        """
        tokens = self.patchifier(x)
        _, N, _ = tokens.shape
        # Move the embedding to the tokens' device before adding.
        pos_embed = self.pos_embed.to(tokens.device)
        tokens = tokens + pos_embed[:, :N]
        return tokens


class AIMv2Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection."""

    def __init__(
        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
    ):
        """Build the QKV projection, the attention op and the output projection.

        Args:
            config: Provides ``hidden_size``, ``num_attention_heads``,
                ``qkv_bias`` and ``use_bias``.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Dotted name of this module; ``.qkv`` / ``.proj`` are
                appended for the sub-layers.

        Raises:
            ValueError: If ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Scaling factor 1 / sqrt(head_dim).
        self.scale = self.head_dim**-0.5

        self.qkv = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        # The attention op is built with the per-partition head count.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(
            self.num_heads_per_partition, self.head_dim, self.scale
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection.

        The fused projection output is split into three equal chunks along
        the last dimension. The linear layers return a 2-tuple whose second
        element is discarded.
        """
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        x = self.attn(q, k, v)
        x, _ = self.proj(x)
        return x


class AIMv2Block(nn.Module):
    """One pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
    ):
        """Build the attention, the MLP and their two RMS norms.

        Args:
            config: Provides ``hidden_size`` and ``rms_norm_eps`` (plus
                whatever the sub-modules read).
            quant_config: Forwarded to the attention and MLP sub-modules.
            prefix: Dotted name of this module; ``.attn`` / ``.mlp`` are
                appended for the sub-modules.
        """
        super().__init__()
        self.attn = AIMv2Attention(
            config, quant_config=quant_config, prefix=f"{prefix}.attn"
        )
        self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = AIMv2SwiGLUFFN(
            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )
        self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual branches.

        Each branch normalizes its input first (via ``forward_native``) and
        adds the branch output back onto the un-normalized input.
        """
        x = x + self.attn(self.norm_1.forward_native(x))
        x = x + self.mlp(self.norm_2.forward_native(x))
        return x


class AIMv2Transformer(nn.Module):
    """Stack of ``AIMv2Block`` layers with an optional final RMS norm."""

    def __init__(
        self,
        config: AIMv2Config,
        quant_config: QuantizationConfig,
        *,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ):
        """Build ``config.num_hidden_layers`` blocks and, optionally, the post norm.

        Args:
            config: Provides ``num_hidden_layers``, ``hidden_size`` and
                ``rms_norm_eps`` (plus whatever the blocks read).
            quant_config: Forwarded to every block.
            require_post_norm: When truthy, a final ``RMSNorm`` is created.
                ``None`` (the default) and ``False`` both leave
                ``post_trunk_norm`` set to ``None``.
            prefix: Dotted name of this module; ``.blocks.{i}`` is appended
                for the i-th block.
        """
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
                for i in range(config.num_hidden_layers)
            ]
        )
        if require_post_norm:
            self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Run every block in order, then the post norm when one exists."""
        # Original author's note: they take the -1 as the ref embeddings,
        # like a clip skip.
        for block in self.blocks:
            tokens = block(tokens)
        if self.post_trunk_norm is not None:
            tokens = self.post_trunk_norm(tokens)
        return tokens


class AIMv2Model(torch.nn.Module):
    """AIMv2 vision encoder: preprocessor followed by the transformer trunk."""

    def __init__(
        self,
        config: AIMv2Config,
        quant_config: QuantizationConfig,
        *,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ):
        """Build the preprocessor and the trunk.

        Args:
            config: Model configuration shared by all sub-modules.
            quant_config: Forwarded to the trunk.
            require_post_norm: Forwarded to ``AIMv2Transformer``; controls
                whether ``trunk.post_trunk_norm`` is created.
            prefix: Dotted name of this module; ``.trunk`` is appended for
                the transformer.
        """
        super().__init__()
        self.preprocessor = AIMv2ViTPreprocessor(config)
        self.trunk = AIMv2Transformer(
            config,
            quant_config=quant_config,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.trunk",
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return the trunk output for ``pixel_values``."""
        x = self.preprocessor(pixel_values)
        x = self.trunk(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``.fc1`` / ``.fc3`` are redirected to the
        fused ``.fc13`` parameter as shard 0 / shard 1. Every other name is
        loaded directly into the parameter of the same name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names (after any renaming) that were loaded.

        Raises:
            KeyError: If a (possibly renamed) weight name has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".fc13", ".fc1", 0),
            (".fc13", ".fc3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # post_trunk_norm is optional: skip its weights when the trunk
            # was built without it.
            if (
                name.startswith("trunk.post_trunk_norm")
                and self.trunk.post_trunk_norm is None
            ):
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # ".fc1" is a substring of ".fc13"; a name that already
                # refers to the fused parameter must not be rewritten
                # (it would become ".fc133"), so it falls through to the
                # direct-load path below.
                if weight_name not in name or param_name in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                # Prefer the parameter's own loader; fall back to the default.
                weight_loader = getattr(param, "weight_loader", None)
                if weight_loader is None:
                    weight_loader = default_weight_loader
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params