# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors

from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Module-level logger, named after this module.
logger = init_logger(__name__)


class Gemma3MLP(nn.Module):
    """Gated feed-forward block used inside a Gemma3 decoder layer.

    The input is projected by a merged gate/up linear layer (two outputs of
    ``intermediate_size`` each), passed through ``GeluAndMul`` with the tanh
    approximation, and projected back to ``hidden_size`` by ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Size of each of the gate and up projections.
            hidden_activation: Must be ``"gelu_pytorch_tanh"``.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted module prefix; ``.gate_up_proj`` and ``.down_proj``
                are appended for the two linear layers.

        Raises:
            ValueError: If ``hidden_activation`` is not ``"gelu_pytorch_tanh"``.
        """
        super().__init__()
        # Validate before building any layer so an unsupported configuration
        # fails immediately instead of after the projections were allocated.
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the gated activation, and down projection.

        Both linear layers return a 2-tuple; only the first element (the
        projected tensor) is used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Gemma3Attention(nn.Module):
103
104
105
106
107
108
109
110
111
112
113
114
115
    def __init__(
        self,
        config: Gemma3TextConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attn_logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ) -> None:
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
145
            prefix=f"{prefix}.qkv_proj",
146
147
148
149
150
151
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
152
            prefix=f"{prefix}.o_proj",
153
154
155
156
157
158
        )

        self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        layer_idx = extract_layer_index(prefix)
159
160
161
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        sliding_window = config.sliding_window if self.is_sliding else None

162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
        # Initialize the rotary embedding.
        if self.is_sliding:
            # Local attention. Override the values in config.json.
            self.rope_theta = config.rope_local_base_freq
            self.rope_scaling = {"rope_type": "default"}
        else:
            # Global attention. Use the values in config.json.
            self.rope_theta = config.rope_theta
            self.rope_scaling = config.rope_scaling
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
            rope_scaling=self.rope_scaling,
        )

180
181
182
183
184
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

185
186
187
188
189
        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )
190

191
192
193
194
195
196
197
198
199
200
201
202
        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)
        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        k = self.k_norm(k)
        k = k.flatten(-2, -1)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        if not kwargs.get("has_images", False):
            # Fast path for text-only inputs. The performance for the text-only
            # inputs are not affected by the naive attention below.
            output, _ = self.o_proj(attn_output)
            return output

        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
        # that correspond to the same image while using causal attention
        # otherwise. Current attention backends cannot handle this pattern, so
        # we temporarily use a naive attention implementation with mask tensors.

        # We intentionally keep the attention backend as-is and only override
        # `attn_output` with the naive implementation's output. This minimizes
        # changes to existing model runners and attention backends. The call to
        # `self.attn(q, k, v)` is only used to populate the KV cache - its
        # output is discarded and overwritten below. While this duplicates
        # computation, it maintains compatibility.
        # TODO(woosuk): Optimize by implementing custom attention kernels.
241
        attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
        output, _ = self.o_proj(attn_output)
        return output

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # NOTE(woosuk): As described in the comment above, this code is not
        # meant to be performant. It is only meant to be correct.
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out


class Gemma3DecoderLayer(nn.Module):
    """One Gemma3 decoder layer: attention and MLP, each wrapped by two norms.

    The layer holds four ``GemmaRMSNorm`` modules: one before and one after
    the attention block, and one before and one after the MLP.
    """

    def __init__(
        self,
        config: Gemma3TextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and the four norms.

        Args:
            config: Text config supplying sizes, head counts, activation name
                and ``rms_norm_eps``.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Dotted module prefix; ``.self_attn`` and ``.mlp`` are
                appended for the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma3Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=None,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Gemma3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = GemmaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Position ids forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer, in which case the input itself becomes
                the residual.
            **kwargs: Forwarded unchanged to the attention block.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalized
            # states and the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Gemma3Model(nn.Module):
    """Gemma3 decoder stack: token embedding, decoder layers and final norm.

    Supports pipeline parallelism: only the layers in
    ``[start_layer, end_layer)`` are run on this rank, and non-final ranks
    return ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Dotted module prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma3DecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings scaled by the ``normalizer`` buffer."""
        # NOTE(woosuk): Only apply the normalizer to the output of
        # vocab embedding. Don't apply it to the vision embedding.
        return self.embed_tokens(input_ids) * self.normalizer

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is ``None``.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` from the
                previous pipeline rank; required on non-first ranks.
            inputs_embeds: Precomputed embeddings that, when given, are used
                as-is instead of embedding ``input_ids``.
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                **kwargs,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded, after any renaming
            applied here (stacked-shard renaming, scale-name remapping).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # The quantization method does not change while loading, so resolve
        # the GGUF check once instead of once per weight.
        is_gguf = bool(self.quant_config and self.quant_config.get_name() == "gguf")
        for name, loaded_weight in weights:
            # Revert +1 during llama.cpp conversion
            # see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400
            if is_gguf and name.endswith("norm.weight"):
                # Out-of-place so the tensor handed in by the caller is not
                # mutated.
                loaded_weight = loaded_weight - 1

            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Check if this is a scale parameter that needs remapping first
            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                # Try to remap the scale name first
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    # Successfully remapped, use the remapped name
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue
                # If remapping failed, continue with normal processing

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked shard was loaded for this weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params


class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma3 text model with a logits head tied to the input embedding.

    Logits are computed by ``LogitsProcessor`` against ``model.embed_tokens``;
    there is no separate ``lm_head`` module.
    """

    # Checkpoint sub-module names that are fused into one packed module here.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack and the logits processor.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also forwarded to ``Gemma3Model``.
            prefix: Dotted module prefix; the decoder stack is registered
                under ``model`` below it.
        """
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        super().__init__()
        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings, (
            "Gemma3ForCausalLM requires `tie_word_embeddings` to be enabled"
        )
        self.quant_config = quant_config
        self.model = Gemma3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the input embeddings produced by the wrapped model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``Gemma3Model`` and return its result."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the tied embedding."""
        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Weights whose names start with ``lm_head.`` are skipped when the
        config ties the word embeddings.

        Returns:
            The set of names reported by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)