gemma.py 15.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Xiang Xu's avatar
Xiang Xu committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
19

20
from collections.abc import Iterable
21
from functools import cache
22
from itertools import islice
23
from typing import Any
Xiang Xu's avatar
Xiang Xu committed
24
25
26
27
28

import torch
from torch import nn
from transformers import GemmaConfig

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.activation import GeluAndMul
35
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
43
from vllm.model_executor.layers.rotary_embedding import get_rope
44
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
45
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
46
from vllm.sequence import IntermediateTensors
Xiang Xu's avatar
Xiang Xu committed
47

48
from .interfaces import SupportsLoRA, SupportsPP
49
50
51
52
53
54
55
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
56

57
58
59
# Module-level logger; _get_gemma_act_fn uses it to warn when a config still
# relies on the legacy `hidden_act` field instead of `hidden_activation`.
logger = init_logger(__name__)


60
@cache
61
def _get_gemma_act_fn(
62
63
    hidden_act: str | None,
    hidden_activation: str | None,
64
65
66
67
68
69
70
71
) -> nn.Module:
    if hidden_activation is None:
        if hidden_act is not None:
            logger.warning(
                "Gemma's activation function was incorrectly set to exact GeLU "
                "in the config JSON file when it was initially released. "
                "Changing the activation function to approximate GeLU "
                "(`gelu_pytorch_tanh`). If you want to use the legacy "
72
73
                "`%s`, edit the config JSON to set "
                "`hidden_activation=%s` instead of `hidden_act`. "
74
                "See https://github.com/huggingface/transformers/pull/29402 "
75
76
77
78
                "for more details.",
                hidden_act,
                hidden_act,
            )
79
80
81
82
83
84
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu_pytorch_tanh":
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu":
        return GeluAndMul(approximate="none")
    else:
85
86
87
        raise ValueError(
            f"Activation function {hidden_act} is not supported for Gemma models."
        )
88

Xiang Xu's avatar
Xiang Xu committed
89
90
91
92
93
94

class GemmaMLP(nn.Module):
    """Gated feed-forward block of a Gemma decoder layer.

    A single merged column-parallel projection produces two
    ``intermediate_size`` outputs (gate and up). They are combined by the
    gated GeLU module returned from ``_get_gemma_act_fn`` and then projected
    back to ``hidden_size`` by a row-parallel layer. Neither projection has
    a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str | None = None,
        hidden_activation: str | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one layer with two
        # equally sized outputs.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        # Memoized factory: layers with identical activation settings share
        # one activation module instance.
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        """Apply the fused gate/up projection, gated GeLU, then down projection."""
        # The parallel linear layers return a pair; only the first element
        # (the projected output) is used here.
        fused = self.gate_up_proj(x)[0]
        gated = self.act_fn(fused)
        return self.down_proj(gated)[0]
Xiang Xu's avatar
Xiang Xu committed
122
123
124


class GemmaAttention(nn.Module):
    """Gemma self-attention with rotary position embeddings.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many as ranks, and replicated
    otherwise. Rotary embeddings are NeoX-style over the full ``head_dim``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.head_dim = head_dim

        # Query heads: always partitioned evenly across TP ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        # KV heads: fewer heads than ranks means each head is replicated on
        # several ranks; otherwise they are partitioned. Either way the
        # counts must divide evenly.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < tp_size:
            assert tp_size % num_kv_heads == 0
        else:
            assert num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` using rotary embeddings at ``positions``."""
        # The parallel linear layers return a pair; the output is element 0.
        fused_qkv = self.qkv_proj(hidden_states)[0]
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        return self.o_proj(context)[0]


class GemmaDecoderLayer(nn.Module):
    """One pre-norm Gemma transformer block.

    The layout is ``input_layernorm`` -> self-attention ->
    ``post_attention_layernorm`` -> MLP. The residual stream is threaded
    through the two-argument norm calls, which return an updated
    ``(hidden_states, residual)`` pair.
    """

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.self_attn = GemmaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = GemmaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            # Older Gemma configs may not define hidden_activation at all.
            hidden_activation=getattr(config, "hidden_activation", None),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        eps = config.rms_norm_eps
        self.input_layernorm = GemmaRMSNorm(width, eps=eps)
        self.post_attention_layernorm = GemmaRMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return the MLP output and the updated residual."""
        # Self Attention
        if residual is not None:
            # Two-argument form returns (normalized, updated residual) —
            # presumably a fused residual-add + norm.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # No residual yet: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attended = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attended, residual)
        return self.mlp(normed), residual
Xiang Xu's avatar
Xiang Xu committed
260
261


262
@support_torch_compile
class GemmaModel(nn.Module):
    """Gemma decoder stack: token embedding, decoder layers and final RMSNorm.

    Under pipeline parallelism, each rank holds only the layers in
    ``[start_layer, end_layer)`` as returned by ``make_layers``. The first
    rank embeds and scales the input. Ranks other than the last return an
    ``IntermediateTensors`` carrying ``hidden_states`` and ``residual`` for
    the next rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GemmaDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        # Non-persistent buffer: it moves/casts with the module but is not
        # part of the state dict, so checkpoints neither contain nor need it.
        self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (unscaled; ``forward`` applies the normalizer)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; only read on the first PP rank, and only
                when ``inputs_embeds`` is not given.
            positions: Token positions forwarded to every layer's attention.
            intermediate_tensors: Hidden state and residual from the
                previous PP rank; read on non-first ranks.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids`` on the first rank. They are scaled by the
                normalizer like looked-up embeddings, and are not modified.

        Returns:
            Final-normed hidden states on the last PP rank, otherwise an
            ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Scale out of place: hidden_states may alias the caller's
            # inputs_embeds, and an in-place `*=` would silently rescale it.
            hidden_states = hidden_states * self.normalizer
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint tensors are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id. Extra biases
        that have no parameter here (GPTQ) are skipped, as are parameters
        that belong to another pipeline-parallel rank.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: these `continue`s end the mapping loop without
                # `break`, so the `else` branch below runs with the renamed
                # name and its identical checks skip the weight again.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params

Xiang Xu's avatar
Xiang Xu committed
367

368
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma causal language model.

    Wraps ``GemmaModel``. Logits are computed against the input embedding
    table (``model.embed_tokens``), i.e. the LM head is tied to the
    embeddings.
    """

    # Fused module name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        # currently all existing Gemma models have `tie_word_embeddings` enabled
        assert hf_config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config

        self.model = GemmaModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the inner model's factory for pipeline-parallel callers.
        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the inner ``GemmaModel``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the tied embeddings."""
        return self.logits_processor(self.model.embed_tokens, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, ignoring ``lm_head.*`` when embeddings are tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)