aria.py 22.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal
5
6
7

import torch
import torch.nn as nn
8
9
10
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
11

12
from vllm.config import VllmConfig
13
from vllm.config.multimodal import BaseDummyOptions
14
15
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
16
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
17
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.model_loader.weight_utils import (
20
21
22
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
25
26
27
28
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
29
from vllm.multimodal.parse import MultiModalDataItems
30
from vllm.multimodal.processing import (
31
    BaseDummyInputsBuilder,
32
33
34
35
36
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils.tensor_schema import TensorSchema, TensorShape
39

40
from .idefics2_vision_model import Idefics2VisionConfig
41
from .idefics2_vision_model import (
42
43
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
44
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
45
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
46
47
48
49
50
51
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    is_pp_missing_parameter,
    maybe_prefix,
)
52
53


54
class AriaImagePixelInputs(TensorSchema):
    """
    Schema describing the image tensors consumed by the Aria vision tower.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - bn: Leading axis of both tensors (presumably b * n, i.e. all images
          of the batch flattened together)
        - c: Number of channels (fixed to 3 in `pixel_values`)
        - h: Height of each image
        - w: Width of each image
    """

    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]

    # Image pixels, one entry per image along the leading axis.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]

    # Optional per-pixel mask aligned with `pixel_values` (no channel axis).
    pixel_mask: Annotated[
        torch.Tensor | None,
        TensorShape("bn", "h", "w"),
    ]

76

77
78
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """
    Vision tower for Aria, built on the Idefics2 vision transformer with the
    final layer norm replaced by an identity.
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """
        Build the underlying vision transformer and disable its final norm.

        Args:
            config: Vision configuration forwarded to the parent class.
            quant_config: Optional quantization configuration forwarded to the
                parent class.
            prefix: Module name prefix forwarded to the parent class.
        """
        super().__init__(config, quant_config=quant_config, prefix=prefix)
        # Unlike Idefics3VisionTransformer which uses LayerNorm after the
        # final layer, Aria omits this normalization, so we replace it with an
        # Identity layer
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint weights into this module.

        Separate q/k/v projection weights are routed into the fused
        `qkv_proj` parameter; `post_layernorm` weights are skipped.

        Args:
            weights: Iterable of `(name, tensor)` pairs from the checkpoint.

        Returns:
            The set of parameter names (after remapping) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # NOTE: post_layernorm is not used in Aria
            if "post_layernorm" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Redirect the per-projection weight to the fused parameter
                # and let its loader place the shard.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: load it as-is.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


122
class AriaProjectorMLP(nn.Module):
    """
    Two-layer bias-free feed-forward network used by the Aria projector:
    linear -> `gelu_new` activation -> linear.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
        prefix: str = "",
    ) -> None:
        """
        Args:
            in_features: Input size of the first linear layer.
            hidden_features: Output size of the first linear layer and input
                size of the second.
            output_dim: Output size of the second linear layer.
            prefix: Module name prefix used to name the two linear layers.
        """
        super().__init__()

        self.linear_in = ColumnParallelLinear(
            in_features, hidden_features, bias=False, prefix=f"{prefix}.linear_in"
        )
        self.linear_out = RowParallelLinear(
            hidden_features, output_dim, bias=False, prefix=f"{prefix}.linear_out"
        )
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `linear_in`, the activation, then `linear_out`."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used here.
        hidden_states, _ = self.linear_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_out(hidden_states)
        return hidden_states


class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer, which
    projects ViT's outputs into MoE's inputs.

    Args:
        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
            containing projector configuration parameters.
        prefix: Module name prefix used to name the feed-forward sub-module.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig, prefix: str = "") -> None:
        super().__init__()

        # Maps a number of input patches to the number of queries to use.
        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = config.vision_config.hidden_size
        self.num_heads = config.vision_config.num_attention_heads
        self.kv_dim = config.vision_config.hidden_size
        self.hidden_features = config.text_config.hidden_size
        self.output_dim = config.text_config.hidden_size

        # Learnable query bank sized for the largest supported query count;
        # `forward` slices the leading rows it needs.
        self.query = nn.Parameter(
            torch.empty(
                config.max_value_projector_patch_to_query_dict, self.in_features
            )
        )

        self.cross_attn = AriaCrossAttention(config)

        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(
            self.in_features,
            self.hidden_features,
            self.output_dim,
            prefix=f"{prefix}.feed_forward",
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Project vision features onto a fixed number of query tokens.

        Args:
            x: Vision features; axis 0 is the batch and axis 1 the patches.
            attn_mask: Optional mask with the batch on axis 0. It is repeated
                once per attention head and broadcast over the queries before
                being handed to the cross attention.

        Returns:
            The feed-forward output computed from the cross-attended queries.

        Raises:
            KeyError: If the number of patches in `x` has no entry in
                `patch_to_query_dict`.
        """
        batch_size, num_patches = x.shape[0], x.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(
                f"Number of patches {num_patches} not found in "
                "patch_to_query_dict amongst possible values "
                f"{self.patch_to_query_dict.keys()}."
            )

        query_num = self.patch_to_query_dict[num_patches]

        # Same leading `query_num` queries for every batch element.
        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            # (batch, ...) -> (batch * num_heads, query_num, ...)
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

        out = self.feed_forward(self.layer_norm(attention_out))

        return out


215
class AriaFusedMoE(SharedFusedMoE):
    """Fused MoE layer with a weight loader for Aria's packed expert weights."""

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
    ) -> None:
        """
        Copy a packed expert weight from the checkpoint into `param`.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor holding all experts at once.
            shard_id: `"w13"` for the merged up/gate projection or `"w2"` for
                the down projection.

        Raises:
            ValueError: If `shard_id` is neither `"w13"` nor `"w2"`.
        """
        # Override the weight_loader to handle the expert weights in the Aria
        # model, which are already packed with experts, and merge the gate and
        # up weights for each expert.
        # Note: Loading expert weights with quantization is not supported
        tp_rank = get_tensor_model_parallel_rank()
        if shard_id == "w13":
            # the shape of loaded_weight is
            # (num_experts, hidden_size, 2 * moe_intermediate_size)
            if self.tp_size > 1:
                # Split the two halves apart, take this rank's slice of each,
                # then re-merge so every rank holds matching halves.
                up, gate = loaded_weight.chunk(2, dim=-1)
                up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank]
                gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank]
                up_and_gate = torch.cat(
                    [up_current_rank, gate_current_rank], dim=-1
                ).transpose(1, 2)
                param.data.copy_(up_and_gate)
            else:
                param.data.copy_(loaded_weight.transpose(1, 2))
        elif shard_id == "w2":
            # the shape of loaded_weight is
            # (num_experts, moe_intermediate_size, hidden_size)
            if self.tp_size > 1:
                down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[tp_rank]
                param.data.copy_(down_current_rank.transpose(1, 2))
            else:
                param.data.copy_(loaded_weight.transpose(1, 2))
        else:
            # Previously an unknown shard was ignored, which would leave the
            # parameter uninitialised without any diagnostic.
            raise ValueError(
                f"Unsupported shard_id {shard_id!r}; expected 'w13' or 'w2'."
            )


247
class AriaTextMoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    This layer implements the MoE mechanism, which routes input tokens to
    different experts based on a routing algorithm, processes them through the
    experts, and then combines the outputs.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Text configuration providing the MoE sizes.
            quant_config: Optional quantization configuration passed to the
                shared and routed experts.
            prefix: Module name prefix used to name the routed experts.
        """
        super().__init__()
        self.config = config

        # Router projection: one row of logits weights per expert.
        self.router_weight = nn.Parameter(
            torch.empty((self.config.moe_num_experts, self.config.hidden_size))
        )

        # All shared experts are folded into one MLP whose intermediate size
        # is scaled by the number of shared experts.
        self.shared_experts = LlamaMLP(
            config.hidden_size,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
        )

        self.experts = AriaFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            reduce_results=True,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states: Input tensor of shape
                (batch_size, sequence_length, hidden_size).

        Returns:
            torch.Tensor: Output tensor after passing through the MoE layer.
        """
        router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

        sparse_expert_output = self.experts(hidden_states, router_output)

        if self.shared_experts is not None:
            # With shared experts the fused layer yields two outputs, which
            # are summed here.
            return sparse_expert_output[0] + sparse_expert_output[1]
        else:
            return sparse_expert_output
308
309


310
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Custom Decoder Layer for the AriaMoE model which modifies the standard
    `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
    Experts (MoE) Layer.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """
        Args:
            vllm_config: Engine configuration; its HF config and quantization
                config are used to build the MoE layer.
            prefix: Module name prefix for this layer.
        """
        super().__init__(vllm_config, prefix)

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        # Replace the MLP created by the parent constructor.
        self.mlp = AriaTextMoELayer(
            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )
326
327


328
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Custom LlamaModel for the AriaMoE model which modifies the standard
    LlamaModel by replacing the `LlamaDecoderLayer` with
    `AriaTextDecoderLayer`.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Args:
            vllm_config: Engine configuration forwarded to `LlamaModel`.
            prefix: Module name prefix forwarded to `LlamaModel`.
        """
        super().__init__(
            vllm_config=vllm_config, prefix=prefix, layer_type=AriaTextDecoderLayer
        )

    # Adapted from LlamaModel.load_weights with the modification of adding
    # the expert weights mapping to `stacked_params_mapping`
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint weights into the text model.

        Args:
            weights: Iterable of `(name, tensor)` pairs from the checkpoint.

        Returns:
            The set of parameter names (after remapping) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            ("experts.w13_weight", "experts.fc1.weight", "w13"),
            ("experts.w2_weight", "experts.fc2.weight", "w2"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scales are loaded as scalars; take the first element when
                # the checkpoint stores them with extra dimensions.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


414
415
class AriaProcessingInfo(BaseProcessingInfo):
    """Accessors for Aria's HF config/processor and multimodal limits."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `AriaConfig`."""
        return self.ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        """Return the `vision_config` section of the HF config."""
        return self.get_hf_config().vision_config

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `AriaProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(AriaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the supported modalities: images, with a `None` limit."""
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        """Return the largest query count in `projector_patch_to_query_dict`."""
        hf_config = self.get_hf_config()
        return max(hf_config.projector_patch_to_query_dict.values())


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds placeholder text and images for Aria."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Return the image token repeated once per requested image.

        Args:
            mm_counts: Number of items per modality; a missing `"image"` key
                counts as zero.
        """
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token: str = processor.tokenizer.image_token  # type: ignore

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Return dummy square images sized by the vision config's `image_size`.

        Args:
            seq_len: Unused by this implementation.
            mm_counts: Number of items per modality; a missing `"image"` key
                counts as zero.
            mm_options: Optional per-modality overrides; the `"image"` entry,
                if any, is forwarded to the image generator.
        """
        vision_config = self.info.get_vision_config()

        max_image_size = vision_config.image_size
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=max_image_size,
                height=max_image_size,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


464
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor hooks for Aria image inputs."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `pixel_mask` as batched image fields."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            pixel_mask=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Replace each image token with a fixed-length run of image tokens.

        The run length is `get_num_image_tokens()`, i.e. the same for every
        image regardless of its size.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        num_image_tokens = self.info.get_num_image_tokens()

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=[image_token_id] * num_image_tokens,
            )
        ]
493
494


495
496
497
498
499
@MULTIMODAL_REGISTRY.register_processor(
    AriaMultiModalProcessor,
    info=AriaProcessingInfo,
    dummy_inputs=AriaDummyInputsBuilder,
)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            # mapping for original checkpoint
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """
        Return the prompt placeholder for one multimodal item.

        Args:
            modality: Modality name; only names starting with `"image"` are
                accepted.
            i: Item index (unused).

        Raises:
            ValueError: If `modality` is not an image modality.
        """
        if modality.startswith("image"):
            return "<|fim_prefix|><|img|><|fim_suffix|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """
        Args:
            vllm_config: Engine configuration; its HF config supplies the
                vision and text sub-configs.
            prefix: Module name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = AriaVisionTransformer(
                config.vision_config,
                quant_config=quant_config,
                # maybe_prefix avoids a leading "." when `prefix` is empty and
                # matches how the sibling modules below are named.
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = AriaProjector(
                config, prefix=maybe_prefix(prefix, "multi_modal_projector")
            )

        with self._mark_language_model(vllm_config):
            self.language_model = AriaTextModel(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "language_model.model"),
            )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> AriaImagePixelInputs | None:
        """
        Extract image inputs from `kwargs`.

        Returns:
            `None` when no `pixel_values` were supplied, otherwise an
            `AriaImagePixelInputs` built from `pixel_values` and the optional
            `pixel_mask`.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        return AriaImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        )

    def _create_patch_attention_mask(
        self,
        pixel_mask: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Reduce a per-pixel mask to a per-patch boolean mask.

        Dimensions 1 and 2 of `pixel_mask` are tiled into non-overlapping
        `patch_size` x `patch_size` windows; a patch is marked `True` when the
        sum over its window is positive.

        Returns:
            `None` if `pixel_mask` is `None`, otherwise the boolean patch mask.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        patches_subgrid = pixel_mask.unfold(
            dimension=1,
            size=patch_size,
            step=patch_size,
        ).unfold(
            dimension=2,
            size=patch_size,
            step=patch_size,
        )
        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

    def _process_image_input(
        self, image_input: AriaImagePixelInputs
    ) -> torch.Tensor:
        """
        Run the vision tower and projector on validated image inputs.

        Returns:
            The output of `multi_modal_projector` for the given images.
        """
        pixel_values = image_input["pixel_values"]
        pixel_mask = image_input["pixel_mask"]

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )
        image_attn_mask = None
        if patch_attention_mask is not None:
            # The projector receives the inverted mask, flattened over the
            # patch grid.
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """
        Run the language model.

        Args:
            input_ids: Token ids forwarded to the language model.
            positions: Position ids forwarded to the language model.
            intermediate_tensors: Optional intermediate tensors; when given,
                `inputs_embeds` is ignored.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Ignored by this method.

        Returns:
            Whatever the language model returns for these inputs.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """
        Load checkpoint weights through `AutoWeightsLoader`, renaming them
        with `hf_to_vllm_mapper`. Returns `None`.
        """
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)