deepseek_vl2.py 12.8 KB
Newer Older
1
from typing import Iterable, List, Optional, Tuple
2
3
4
5
6
7
8
9
10
11

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from sglang.srt.configs.deepseekvl2 import (
    DeepseekVL2Config,
    DeepseekVL2MlpProjectorConfig,
)
12
from sglang.srt.layers.linear import ReplicatedLinear
13
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Mick's avatar
Mick committed
14
from sglang.srt.managers.mm_utils import (
15
    MultiModalityDataPaddingPatternMultimodalTokens,
Mick's avatar
Mick committed
16
17
18
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
19
20
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
21
from sglang.srt.models.deepseek import DeepseekForCausalLM
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM


class DeepseekVL2MlpProjector(nn.Module):
    """Project vision features to the width given by ``config.n_embed``.

    The layer stack stored in ``self.layers`` is chosen by
    ``config.projector_type``:

    * ``"identity"``: a single ``nn.Identity`` (input is returned unchanged).
    * ``"linear"``: one ``input_dim -> n_embed`` linear layer.
    * ``"mlp_gelu"``: one ``input_dim -> n_embed`` linear layer followed by
      ``depth - 1`` repetitions of ``GELU, n_embed -> n_embed`` linear.
    * ``"downsample_mlp_gelu"``: a linear layer that consumes
      ``downsample_ratio x downsample_ratio`` neighbouring tokens concatenated
      along the channel axis, hidden layers of width ``n_embed * mlp_ratio``
      separated by GELU, and a final projection to ``n_embed``.

    When ``config.token_pooling`` is truthy an extra
    ``input_dim * 4 -> input_dim`` linear layer is created and applied to
    2x2 token patches before the main stack.

    Args:
        config: Projector configuration; the attributes read are
            ``projector_type``, ``token_pooling``, ``input_dim``, ``n_embed``
            and, depending on the type, ``depth``, ``mlp_ratio`` and
            ``downsample_ratio``.
        quant_config: Forwarded to every ``ReplicatedLinear`` built here.

    Raises:
        ValueError: If ``config.projector_type`` is not one of the values
            listed above.
    """

    def __init__(
        self,
        config: DeepseekVL2MlpProjectorConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config

        def linear(in_features, out_features):
            # Every linear layer of the projector shares the same quant config.
            return ReplicatedLinear(
                in_features, out_features, quant_config=quant_config
            )

        projector_type = config.projector_type
        if projector_type == "identity":
            # The pass-through has to live in ``self.layers`` as well;
            # forward() iterates that list unconditionally and would fail
            # with AttributeError if it were missing.
            layers = [nn.Identity()]

        elif projector_type == "linear":
            layers = [linear(config.input_dim, config.n_embed)]

        elif projector_type == "mlp_gelu":
            layers = [linear(config.input_dim, config.n_embed)]
            for _ in range(1, config.depth):
                layers.append(nn.GELU())
                layers.append(linear(config.n_embed, config.n_embed))

        elif projector_type == "downsample_mlp_gelu":
            hidden_dim = config.n_embed * config.mlp_ratio
            # Each output token is built from a ratio x ratio window of inputs.
            window_dim = (
                config.input_dim * config.downsample_ratio * config.downsample_ratio
            )
            layers = [linear(window_dim, hidden_dim)]
            for _ in range(1, config.depth - 1):
                layers.append(nn.GELU())
                layers.append(linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
            layers.append(linear(hidden_dim, config.n_embed))

        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

        # Indices inside the ModuleList are part of the parameter names, so
        # the order of ``layers`` must not change.
        self.layers = nn.ModuleList(layers)

        if config.token_pooling:
            self.token_pooling_layer = ReplicatedLinear(
                config.input_dim * 4, config.input_dim, quant_config=quant_config
            )

    def forward(self, x):
        """Apply optional token pooling / downsampling, then the layer stack.

        Args:
            x: Input features. The pooling and downsampling branches unpack
                ``x.shape`` into three values and reshape the middle axis into
                a square grid, so there ``x`` must be rank 3 with a
                perfect-square middle dimension.

        Returns:
            The output of the last layer in ``self.layers``. When a layer
            returns a tuple only its first element is kept.
        """
        if self.config.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)

            # Cut the grid into non-overlapping 2x2 patches and flatten each
            # patch into the channel axis (channels * 4).
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)[0]

        elif self.config.projector_type == "downsample_mlp_gelu":
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            # Compute the padding needed to make the grid side a multiple of
            # the downsample ratio.
            ratio = self.config.downsample_ratio
            if h % ratio:
                pad = ratio - h % ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                # Zero-pad the end of both spatial axes; channels untouched.
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Concatenate each ratio x ratio window into one token.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=ratio,
                stride=ratio,
                padding=0,
            )  # B, C * ratio * ratio, number of windows
            x = x.permute(0, 2, 1)

        for layer in self.layers:
            x = layer(x)
            if isinstance(x, tuple):
                x = x[0]
        return x


class DeepseekVL2ForCausalLM(nn.Module):
    """DeepSeek-VL2 model: vision encoder, MLP projector and language model.

    Sub-modules created in ``__init__``:

    * ``self.vision``: a timm ViT built by ``_init_vision_module``.
    * ``self.projector``: a ``DeepseekVL2MlpProjector``.
    * ``self.language_model``: ``DeepseekV2ForCausalLM`` when
      ``config.language_config.use_mla`` is truthy, else ``DeepseekForCausalLM``.

    Two learned vectors of size ``projector_config.n_embed`` are also
    registered: ``image_newline`` (appended to each row of an image grid) and
    ``view_seperator`` (inserted between the global and the local view). The
    attribute names are part of the parameter names and must not be changed.
    """

    def __init__(
        self,
        config: DeepseekVL2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build all sub-modules from ``config``.

        Args:
            config: Top-level config; ``vision_config``, ``projector_config``,
                ``language_config``, ``tile_tag`` and ``global_view_pos`` are
                read from it.
            quant_config: Passed to the vision-module builder and to the
                projector.

        Raises:
            ValueError: If ``config.tile_tag`` is not ``"2D"``.
            ImportError: If ``timm`` is not installed.
        """
        super().__init__()

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = self._init_vision_module(vision_config, quant_config)

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = DeepseekVL2MlpProjector(projector_config, quant_config)

        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # Scale the random init of the special embeddings by 1/sqrt(n_embed).
        embed_std = 1 / torch.sqrt(
            torch.tensor(projector_config.n_embed, dtype=torch.float32)
        )
        if self.tile_tag == "2D":
            self.image_newline = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
            self.view_seperator = nn.Parameter(
                torch.randn(projector_config.n_embed) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        # NOTE(review): quant_config is not forwarded to the language model
        # here - confirm whether that is intentional.
        language_config = config.language_config
        if language_config.use_mla:
            self.language_model = DeepseekV2ForCausalLM(language_config)
        else:
            # deepseek-vl2-tiny forbids mla
            self.language_model = DeepseekForCausalLM(language_config)

    def _init_vision_module(
        self, vision_config, quant_config: Optional[QuantizationConfig]
    ) -> nn.Module:
        """Create the timm vision encoder, cast to torch's default dtype.

        Both arguments are currently unused: the architecture name and its
        options are hard-coded below and no pretrained weights are requested.

        Raises:
            ImportError: If ``timm`` cannot be imported; the original import
                error is attached as ``__cause__``.
        """
        # TODO: refactor vision model through timm wrapper from transformers
        try:
            import timm
        except ImportError as err:
            # Chain from the caught exception (not the ImportError class) so
            # the real reason for the failed import stays in the traceback.
            raise ImportError("Please install timm") from err

        model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=0,
            dynamic_img_size=True,
            dynamic_img_pad=True,
        )

        model = model.to(dtype=torch.get_default_dtype())
        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ):
        """Run the multimodal forward pass.

        Delegates to ``general_mm_embed_routine`` with ``get_image_feature``
        as the image embedding function and ``self.language_model`` as the
        language model, and returns its result. Extra ``kwargs`` are accepted
        but not used.
        """
        hs = general_mm_embed_routine(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            image_data_embedding_func=self.get_image_feature,
            language_model=self.language_model,
        )

        return hs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model.

        Entries whose name contains ``"language"`` have the substring
        ``"language."`` removed and are handed, one pair at a time, to
        ``self.language_model.load_weights``. Every other entry is looked up
        by name among this module's parameters and loaded with the
        parameter's ``weight_loader`` attribute, or ``default_weight_loader``
        when the parameter has none.

        Raises:
            KeyError: If a non-language name matches no parameter.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "language" in name:
                name = name.replace("language.", "")
                self.language_model.load_weights([(name, loaded_weight)])
            else:
                param = params_dict[name]
                weights_loader = getattr(param, "weight_loader", default_weight_loader)
                weights_loader(param, loaded_weight)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad ``input_ids`` via ``MultiModalityDataPaddingPatternMultimodalTokens``."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]):
        """Encode image items into one flat sequence of embeddings.

        For each item the pixel values go through the vision encoder and the
        projector. The resulting tiles are then consumed image by image, as
        described by ``item.image_spatial_crop[0, j]`` (a width/height tile
        count pair): one global tile followed by ``width * height`` local
        tiles. ``image_newline`` is appended to every row of both views and
        ``view_seperator`` is placed between them; ``self.global_view_pos``
        decides whether the global view comes first (``"head"``) or last.

        Args:
            items: Items exposing ``pixel_values`` (asserted rank 4) and
                ``image_spatial_crop`` (rank 3 after concatenation over items).

        Returns:
            All per-image embedding sequences concatenated along dim 0.
        """
        images_spatial_crop = torch.cat(
            [item.image_spatial_crop for item in items], dim=0
        )

        assert images_spatial_crop.dim() == 3

        # Loop-invariant: every item is cast and moved to the dtype/device of
        # the vision encoder's first parameter.
        vision_param = next(self.vision.parameters())

        # TODO: can it be batched ?
        images_in_this_batch = []
        for item in items:
            assert item.pixel_values.dim() == 4
            image_feature = self.vision.forward_features(
                item.pixel_values.type(vision_param.dtype).to(
                    device=vision_param.device
                )
            )
            images_embeds = self.projector(image_feature)
            _, hw, n_dim = images_embeds.shape
            h = w = int(hw**0.5)
            tile_index = 0
            for jdx in range(item.image_spatial_crop.shape[1]):
                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
                # A zero tile count ends the list of images for this item.
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break
                num_tiles_in_image = num_width_tiles * num_height_tiles

                # [hw, D]
                global_features = images_embeds[tile_index]

                # [num_height_tiles * num_width_tiles, hw, D]
                local_features = images_embeds[
                    tile_index + 1 : tile_index + 1 + num_tiles_in_image
                ]
                tile_index += num_tiles_in_image + 1

                # format global and local features
                # ----------------- global view add newline -----------------
                # [hw, D] -> [h, w, D]
                global_features = global_features.view(h, w, n_dim)

                # [D]     -> [h, 1, D]
                new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

                # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
                global_features = torch.cat(
                    [global_features, new_lines_in_global], dim=1
                )

                # [h, w + 1, D] -> [h * (w + 1), D]
                global_features = global_features.view(-1, n_dim)

                # ----------------- local view add newline -----------------
                # [num_height_tiles * num_width_tiles, h * w, D] ->
                # [num_height_tiles * h, num_width_tiles * w, D]
                local_features = rearrange(
                    local_features,
                    "(th tw) (h w) d -> (th h) (tw w) d",
                    th=num_height_tiles,
                    tw=num_width_tiles,
                    h=h,
                    w=w,
                )

                # [D] -> [num_height_tiles * h, 1, D]
                new_lines_in_local = repeat(
                    self.image_newline,
                    "d -> (th h) 1 d",
                    th=num_height_tiles,
                    h=h,
                )

                # [num_height_tiles * h, num_width_tiles * w + 1, D]
                local_features = torch.cat([local_features, new_lines_in_local], dim=1)

                # [num_height_tiles * h, num_width_tiles * w + 1, D]
                #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
                local_features = local_features.view(-1, n_dim)

                # merge global and local tiles
                if self.global_view_pos == "head":
                    global_local_features = torch.cat(
                        [
                            global_features,
                            self.view_seperator[None, :],
                            local_features,
                        ]
                    )
                else:
                    global_local_features = torch.cat(
                        [
                            local_features,
                            self.view_seperator[None, :],
                            global_features,
                        ]
                    )

                images_in_this_batch.append(global_local_features)

        return torch.cat(images_in_this_batch, dim=0)
360
361
362


# Module-level alias for the model class defined in this file.
# NOTE(review): presumably the name the surrounding framework looks up to
# discover the model implementation - confirm against the model loader.
EntryClass = DeepseekVL2ForCausalLM