"vscode:/vscode.git/clone" did not exist on "73b642e6f341287163c784e1e99a18426ee2ccea"
swin.py 16.7 KB
Newer Older
汪志鹏's avatar
汪志鹏 committed
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from transformers import SwinConfig
10
from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging
汪志鹏's avatar
汪志鹏 committed
11
12
13
14
from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
from transformers.pytorch_utils import meshgrid

from vllm.model_executor.layers.activation import get_act_fn
15
16
17
18
19
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
汪志鹏's avatar
汪志鹏 committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


class SwinSelfAttention(nn.Module):
    """Windowed multi-head self-attention with a learned relative position bias.

    Query/key/value are produced by one fused ``QKVParallelLinear`` (``qkv``)
    and attention is computed with
    ``torch.nn.functional.scaled_dot_product_attention``; the relative
    position bias (plus the optional window mask) is passed as the additive
    ``attn_mask``.

    NOTE(review): ``transpose_for_scores`` and ``relative_position_bias_table``
    use the *total* head count, while ``QKVParallelLinear`` shards heads across
    tensor-parallel ranks, so this presumably only works with a tensor-parallel
    size of 1 — confirm before enabling TP.
    """

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        num_heads: int,
        window_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the bias table, the relative index and the fused QKV projection.

        Args:
            config: Swin config; only ``config.qkv_bias`` is read here.
            dim: Hidden size of the incoming tokens.
            num_heads: Number of attention heads; must divide ``dim``.
            window_size: Window side length, or an iterable ``(height, width)``.
            quant_config: Optional quantization config forwarded to ``qkv``.
            prefix: Parameter-name prefix for this module.

        Raises:
            ValueError: If ``dim`` is not a multiple of ``num_heads``.
        """
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of "
                f"attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        # Exact integer division; divisibility was validated above.
        self.attention_head_size = dim // num_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size
            if isinstance(window_size, Iterable)
            else (window_size, window_size)
        )
        # Not passed to SDPA below: SDPA's default scale is
        # 1/sqrt(head_dim), which is the same value.
        self.scale = self.attention_head_size**-0.5

        # One learned bias per (relative offset, head) pair.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
            )
        )

        # Pair-wise relative position index for each token pair in the window.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift both offsets to start at 0, then fold (dh, dw) into a single
        # row index into ``relative_position_bias_table``.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        # A frozen Parameter rather than a buffer, presumably so that
        # ``named_parameters()``-based weight loading can find it.
        self.relative_position_index = nn.Parameter(
            relative_position_index, requires_grad=False
        )

        self.qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=self.attention_head_size,
            total_num_heads=self.num_attention_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, heads * head_dim)`` to ``(batch, heads, seq, head_dim)``."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.view(new_x_shape).permute(0, 2, 1, 3)

    def _get_rel_pos_bias(self) -> torch.Tensor:
        """Gather the relative position bias as a ``(1, heads, N, N)`` tensor.

        ``N`` is ``window_size[0] * window_size[1]``, the tokens per window.
        """
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        return relative_position_bias.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, ...]:
        """Apply windowed self-attention.

        Args:
            hidden_states: Tensor of shape ``(batch_size, seq_len, dim)``, where
                ``seq_len`` must equal the tokens per window.
            attention_mask: Optional additive mask of shape
                ``(num_windows, seq_len, seq_len)``. ``batch_size`` must be a
                multiple of ``num_windows``; row ``b * num_windows + w`` of the
                batch receives ``attention_mask[w]``.
            head_mask: Accepted for HF signature compatibility; ignored.
            output_attentions: If true, a ``None`` placeholder is appended,
                because SDPA does not expose attention probabilities.

        Returns:
            ``(context,)`` or ``(context, None)``, where ``context`` has shape
            ``(batch_size, seq_len, all_head_size)``.
        """
        batch_size, seq_len, _ = hidden_states.shape

        qkv_output, _ = self.qkv(hidden_states)
        query_layer, key_layer, value_layer = (
            self.transpose_for_scores(part) for part in qkv_output.chunk(3, dim=-1)
        )

        # Additive attention bias, starting as (1, heads, seq_len, seq_len).
        attn_bias = self._get_rel_pos_bias()
        if attention_mask is not None:
            num_windows = attention_mask.shape[0]
            # (1, H, N, N) + (W, 1, N, N) -> (W, H, N, N); then tile along the
            # batch so sample ``b * W + w`` uses the mask of window ``w``.
            attn_bias = attn_bias + attention_mask.view(
                num_windows, 1, seq_len, seq_len
            )
            attn_bias = attn_bias.repeat(batch_size // num_windows, 1, 1, 1)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attn_bias,
            dropout_p=0.0,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        if output_attentions:
            return (context_layer, None)
        return (context_layer,)


class SwinSelfOutput(nn.Module):
    """Output projection applied to the merged attention heads.

    A single ``dim -> dim`` ``RowParallelLinear``. No residual connection is
    added here; ``input_tensor`` is accepted but unused.
    """

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        dense_prefix = f"{prefix}.dense"
        self.dense = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            quant_config=quant_config,
            prefix=dense_prefix,
        )

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Project ``hidden_states``; the linear layer's bias output is dropped."""
        projected, _unused_bias = self.dense(hidden_states)
        return projected


class SwinAttention(nn.Module):
    """Self-attention followed by its output projection.

    Submodules are named ``self`` and ``output`` so parameter names line up
    with the HF Swin checkpoint layout (``...attention.self.*`` and
    ``...attention.output.*``).
    """

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        num_heads: int,
        window_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self_prefix = f"{prefix}.self"
        output_prefix = f"{prefix}.output"
        self.self = SwinSelfAttention(
            config,
            dim,
            num_heads,
            window_size,
            quant_config=quant_config,
            prefix=self_prefix,
        )
        self.output = SwinSelfOutput(
            config, dim, quant_config=quant_config, prefix=output_prefix
        )
        # Kept for HF attribute compatibility; never populated here.
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        """Run attention, project its output and forward any extra outputs."""
        attn_results = self.self(
            hidden_states, attention_mask, head_mask, output_attentions
        )
        projected = self.output(attn_results[0], hidden_states)
        return (projected, *attn_results[1:])


class SwinIntermediate(nn.Module):
    """First MLP layer: expand ``dim`` by ``config.mlp_ratio``, then activate."""

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        expanded_size = int(config.mlp_ratio * dim)
        self.dense = ColumnParallelLinear(
            input_size=dim,
            output_size=expanded_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the expansion projection followed by ``config.hidden_act``."""
        expanded, _unused_bias = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)


class SwinOutput(nn.Module):
    """Second MLP layer: project ``mlp_ratio * dim`` back down to ``dim``."""

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        expanded_size = int(config.mlp_ratio * dim)
        self.dense = RowParallelLinear(
            input_size=expanded_size,
            output_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the contraction projection; its bias output is discarded."""
        contracted, _unused_bias = self.dense(hidden_states)
        return contracted


class SwinLayer(HFSwinLayer):
    """HF Swin block whose attention and MLP use vLLM linear layers.

    ``HFSwinLayer.__init__`` builds the full HF block, including
    ``self.window_size`` used below. The ``attention``, ``intermediate`` and
    ``output`` submodules are then replaced with vLLM equivalents.
    ``forward`` is not overridden, so the inherited ``HFSwinLayer.forward``
    drives these submodules.
    """

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        input_resolution: tuple[int, int],
        num_heads: int,
        drop_path_rate: float = 0.0,
        shift_size: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Swin model config.
            dim: Hidden size of this block.
            input_resolution: ``(height, width)`` token grid of this stage.
            num_heads: Number of attention heads.
            drop_path_rate: Drop-path rate forwarded to ``HFSwinLayer``.
            shift_size: Window shift forwarded to ``HFSwinLayer``.
            quant_config: Optional quantization config for the vLLM layers.
            prefix: Parameter-name prefix for this block.
        """
        super().__init__(
            config=config,
            dim=dim,
            input_resolution=input_resolution,
            num_heads=num_heads,
            drop_path_rate=drop_path_rate,
            shift_size=shift_size,
        )

        # Replace the HF-built submodules with vLLM parallel/quantizable ones.
        self.attention = SwinAttention(
            config,
            dim,
            num_heads,
            window_size=self.window_size,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.intermediate = SwinIntermediate(
            config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate"
        )
        self.output = SwinOutput(
            config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
        )


class SwinStage(nn.Module):
    """One Swin stage: ``depth`` SwinLayer blocks plus optional patch merging.

    Even-indexed blocks use ``shift_size=0``; odd-indexed blocks use
    ``shift_size=config.window_size // 2``.
    """

    def __init__(
        self,
        config: SwinConfig,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        drop_path: list[float],
        downsample: Optional[type[SwinPatchMerging]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Swin model config.
            dim: Hidden size of this stage.
            input_resolution: ``(height, width)`` token grid of this stage.
            depth: Number of SwinLayer blocks.
            num_heads: Attention heads per block.
            drop_path: Per-block drop-path rates; needs at least ``depth`` items.
            downsample: Patch-merging *class*, instantiated here, or ``None``
                for no downsampling.
            quant_config: Optional quantization config for the blocks.
            prefix: Parameter-name prefix for this stage.
        """
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                SwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    drop_path_rate=drop_path[layer_idx],
                    shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(depth)
            ]
        )

        # Patch merging layer (``downsample`` is a class, not an instance).
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=nn.LayerNorm
            )
        else:
            self.downsample = None

        # Kept for parity with the HF SwinStage; not read in this module.
        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> tuple:
        """Run every block, then downsample if configured.

        Args:
            hidden_states: Token features entering the stage.
            input_dimensions: ``(height, width)`` of the token grid.
            head_mask: Optional per-block head masks, indexed by block.
            output_attentions: Forwarded to each block.
            always_partition: Forwarded to each block.

        Returns:
            ``(hidden_states, hidden_states_before_downsampling,
            output_dimensions)``, followed by the last block's extra outputs
            when ``output_attentions`` is true. ``output_dimensions`` is
            ``(height, width, out_height, out_width)``.
        """
        height, width = input_dimensions
        # Fallback so the extra-output slice below is empty for a zero-depth
        # stage instead of raising on an unbound name.
        layer_outputs: tuple = (hidden_states,)
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                input_dimensions,
                layer_head_mask,
                output_attentions,
                always_partition,
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            # Ceil-halve each side to match the merged grid.
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(
                hidden_states_before_downsampling, input_dimensions
            )
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (
            hidden_states,
            hidden_states_before_downsampling,
            output_dimensions,
        )

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class SwinEncoder(nn.Module):
    """Stack of SwinStage modules; every stage except the last downsamples.

    Stage ``i`` has hidden size ``embed_dim * 2**i`` and a token grid of
    ``grid_size // 2**i`` on each side.
    """

    def __init__(
        self,
        config: SwinConfig,
        grid_size: tuple[int, int],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Swin model config (``depths``, ``num_heads``, ``embed_dim``,
                ``drop_path_rate``).
            grid_size: ``(height, width)`` patch grid of the embeddings.
            quant_config: Optional quantization config for the stages.
            prefix: Parameter-name prefix for this encoder.
        """
        super().__init__()
        depths = config.depths
        self.num_layers = len(depths)
        self.config = config
        # Per-block drop-path rates, spaced linearly from 0 to
        # ``config.drop_path_rate`` across all blocks of all stages.
        dpr = torch.linspace(
            0, config.drop_path_rate, sum(depths), device="cpu"
        ).tolist()

        stages = []
        block_offset = 0  # index of this stage's first block in ``dpr``
        for layer_idx in range(self.num_layers):
            stage_depth = depths[layer_idx]
            stages.append(
                SwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**layer_idx),
                    input_resolution=(
                        grid_size[0] // (2**layer_idx),
                        grid_size[1] // (2**layer_idx),
                    ),
                    depth=stage_depth,
                    num_heads=config.num_heads[layer_idx],
                    drop_path=dpr[block_offset : block_offset + stage_depth],
                    downsample=SwinPatchMerging
                    if (layer_idx < self.num_layers - 1)
                    else None,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
            )
            block_offset += stage_depth
        self.layers = nn.ModuleList(stages)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run all stages and return the final hidden states.

        Args:
            hidden_states: Patch embeddings.
            input_dimensions: ``(height, width)`` of the initial token grid.
            head_mask: Optional head masks indexed by stage.
            output_attentions: Forwarded to each stage; any attention outputs
                are discarded.
            always_partition: Forwarded to each stage.

        Returns:
            The hidden states produced by the last stage.
        """
        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                input_dimensions,
                layer_head_mask,
                output_attentions,
                always_partition,
            )

            hidden_states = layer_outputs[0]
            output_dimensions = layer_outputs[2]
            # The next stage sees the (possibly downsampled) grid size.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

        return hidden_states


class SwinModel(nn.Module):
    """Swin vision backbone: patch embeddings followed by a SwinEncoder.

    No final layer norm or pooler is defined here; ``forward`` returns the
    last encoder hidden states directly.
    """

    config_class: SwinConfig

    def __init__(
        self,
        config: SwinConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.num_layers = len(config.depths)
        # Hidden size after the last stage (doubled at each downsampling).
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = SwinEmbeddings(config)
        self.encoder = SwinEncoder(
            config,
            self.embeddings.patch_grid,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
    ) -> torch.Tensor:
        """Embed ``pixel_values`` into patches and run the encoder.

        Args:
            pixel_values: Input images; required despite the ``None`` default.
            head_mask: Optional head masks, indexed by stage.
            output_attentions: Forwarded to the encoder.

        Returns:
            The final encoder hidden states.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(pixel_values)

        return self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing query/key/value into ``qkv``.

        Checkpoint names containing ``query``/``key``/``value`` are renamed to
        the fused ``qkv`` parameter and loaded into its ``q``/``k``/``v`` shard.
        All other names are loaded with the parameter's ``weight_loader``
        (falling back to ``default_weight_loader``).

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of (post-rename) parameter names that were loaded.

        Raises:
            KeyError: If a checkpoint name has no matching parameter.
        """
        # (fused param name, checkpoint name, shard id)
        stacked_params_mapping = [
            ("qkv", "query", "q"),
            ("qkv", "key", "k"),
            ("qkv", "value", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params