gemma2.py 16 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2024 The vLLM team.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
19
from collections.abc import Iterable
20
from itertools import islice
21
from typing import Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
24
25
26

import torch
from torch import nn
from transformers import Gemma2Config

27
from vllm.attention import Attention
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, VllmConfig
30
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
31
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
32
33
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
34
35
36
37
38
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
43
from vllm.model_executor.model_loader.weight_utils import (
44
45
46
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
47
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
48

49
from .interfaces import SupportsLoRA, SupportsPP
50
51
52
53
54
55
56
57
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
58

59
60
# Module-level logger, named after this module.
logger = init_logger(__name__)

Woosuk Kwon's avatar
Woosuk Kwon committed
61
62
63
64
65
66
67
68
69
70
71
72

class Gemma2MLP(nn.Module):
    """Gated feed-forward block used inside a Gemma2 decoder layer.

    The input goes through a merged gate/up projection, the ``GeluAndMul``
    activation (tanh approximation) and a down projection back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the projections and validate the configured activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the two merged projections.
            hidden_act: Activation name from the model config.
            hidden_activation: Second activation name from the model config;
                it must agree with ``hidden_act``.
            quant_config: Optional quantization settings forwarded to both
                linear layers.

        Raises:
            ValueError: If either activation name is not
                ``"gelu_pytorch_tanh"``.
        """
        super().__init__()
        # Gate and up projections are merged into a single bias-free layer
        # with two output partitions of equal width.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # Only the tanh-approximated GELU is supported, and both config
        # fields have to name it.
        required_act = "gelu_pytorch_tanh"
        if hidden_act != required_act or hidden_activation != required_act:
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Each linear layer returns a pair; only the first element is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class Gemma2Attention(nn.Module):
    """Self-attention block of a Gemma2 decoder layer.

    Hidden states are projected to query/key/value by one fused projection,
    rotary embeddings are applied to query and key, the attention layer is
    run, and the result goes through the output projection.
    """

    def __init__(
        self,
        config: Gemma2Config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        rope_theta: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attn_logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        """Set up head partitioning, projections, rotary embedding and attention.

        Args:
            config: Model config; ``query_pre_attn_scalar``,
                ``attention_bias``, ``layer_types`` and ``sliding_window``
                are read from it.
            hidden_size: Width of the incoming hidden states.
            num_heads: Total number of query heads across all
                tensor-parallel ranks.
            num_kv_heads: Total number of key/value heads across all
                tensor-parallel ranks.
            head_dim: Size of a single attention head.
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            rope_theta: Base passed to the rotary embedding.
            cache_config: Optional cache settings forwarded to the attention
                layer.
            quant_config: Optional quantization settings forwarded to the
                projections and the attention layer.
            attn_logits_soft_cap: Optional value forwarded to the attention
                layer as ``logits_soft_cap``.
            prefix: Module name prefix; the layer index is extracted from it
                and it names the inner attention layer.
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size

        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly over the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        # At least one KV head is kept per rank.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # The scale comes from `query_pre_attn_scalar`, not from `head_dim`.
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        use_bias = config.attention_bias
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )

        # A sliding window is only used by layers whose configured type is
        # "sliding_attention"; all other layers pass None.
        layer_index = extract_layer_index(prefix)
        if config.layer_types[layer_index] == "sliding_attention":
            window = config.sliding_window
        else:
            window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=window,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``.

        Args:
            positions: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dimension into Q, K, V.
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected


class Gemma2DecoderLayer(nn.Module):
    """One Gemma2 decoder layer: self-attention followed by the MLP.

    Four RMS norms are used: one before and one after attention, and one
    before and one after the MLP.
    """

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and the four norms.

        Args:
            config: Model config supplying sizes, head counts, RoPE settings,
                activation names and the norm epsilon.
            cache_config: Optional cache settings forwarded to attention.
            quant_config: Optional quantization settings forwarded to
                attention and the MLP.
            prefix: Module name prefix of this layer; the attention block is
                named ``f"{prefix}.self_attn"``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Gemma2Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=config.attn_logit_softcapping,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )

        def _new_norm() -> GemmaRMSNorm:
            # All four norms share the same width and epsilon.
            return GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.input_layernorm = _new_norm()
        self.post_attention_layernorm = _new_norm()
        self.pre_feedforward_layernorm = _new_norm()
        self.post_feedforward_layernorm = _new_norm()

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Input activations of the layer.
            residual: Residual carried over from the previous layer, or
                ``None`` when there is none yet.

        Returns:
            The MLP output after the post-feedforward norm, and the residual
            produced by the pre-feedforward norm.
        """
        if residual is not None:
            # With a residual the norm returns a (normed, residual) pair.
            # NOTE(review): presumably it adds the residual before
            # normalizing; confirm against GemmaRMSNorm.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        attn_out = self.post_attention_layernorm(attn_out)

        mlp_in, residual = self.pre_feedforward_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)
        mlp_out = self.post_feedforward_layernorm(mlp_out)
        return mlp_out, residual


249
@support_torch_compile
class Gemma2Model(nn.Module):
    """Gemma2 backbone: token embedding, decoder layers and final norm.

    Only the layers in ``[start_layer, end_layer)`` are run by ``forward``.
    Embedding is done on the first pipeline-parallel rank and the final norm
    on the last one.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers, final norm and normalizer.

        Args:
            vllm_config: Engine configuration; the HF model config, cache
                config and quantization config are read from it.
            prefix: Module name prefix; the layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # The factory's parameter keeps the name `prefix`, as in the original
        # lambda, because how `make_layers` passes it is not visible here.
        def _build_layer(prefix: str) -> Gemma2DecoderLayer:
            return Gemma2DecoderLayer(
                hf_config, cache_config, quant_config, prefix=prefix
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            _build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = GemmaRMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        embed_scale = hf_config.hidden_size**0.5
        # Registered as a non-persistent buffer.
        self.register_buffer("normalizer", torch.tensor(embed_scale), persistent=False)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (unscaled) token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every decoder layer.
            intermediate_tensors: Output of the previous pipeline rank;
                must not be ``None`` on ranks other than the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            On the last pipeline rank, the hidden states after the final
            norm. On other ranks, an ``IntermediateTensors`` holding
            ``"hidden_states"`` and ``"residual"``.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            # In-place scaling: a caller-supplied `inputs_embeds` is
            # modified as well.
            hidden_states *= self.normalizer
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                residual,
            )

        if pp_group.is_last_rank:
            normed, _ = self.norm(hidden_states, residual)
            return normed

        # Not the last rank: pass both tensors on to the next one.
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were recorded as loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_by_name = dict(self.named_parameters())
        loaded: set[str] = set()

        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache scales for compressed-tensors quantization
                scale_param = params_by_name[scale_name]
                load_scale = getattr(
                    scale_param, "weight_loader", default_weight_loader
                )
                # Only the first element of the checkpoint tensor is loaded.
                load_scale(scale_param, loaded_weight[0])
                loaded.add(scale_name)
                continue

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # Rewrite the per-shard name to the fused parameter name.
                # The `continue`s below only advance this inner loop, so the
                # rewritten name is what reaches the `else` branch.
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_by_name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_by_name[name]
                load_shard = param.weight_loader
                load_shard(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no mapping entry ended with `break`.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_by_name:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_by_name)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_by_name[name]
                load_plain = getattr(param, "weight_loader", default_weight_loader)
                load_plain(param, loaded_weight)
            loaded.add(name)

        return loaded
367

Woosuk Kwon's avatar
Woosuk Kwon committed
368

369
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma2 causal language model: the backbone plus logits computation.

    Logits are computed against the backbone's token embedding
    (``self.model.embed_tokens``); there is no separate output head module.
    """

    # Fused module name -> names of the per-shard modules it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and the logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config and the
                quantization config are read from it.
            prefix: Module name prefix; the backbone is created under the
                name produced by ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config

        self.model = Gemma2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping
        )
        # Expose the backbone's factory directly on this module.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        Args:
            input_ids: Token ids; typed as optional to match
                ``Gemma2Model.forward``.
            positions: Positions forwarded to the backbone.
            intermediate_tensors: Output of the previous pipeline rank, if
                any.
            inputs_embeds: Optional precomputed embeddings.

        Returns:
            Whatever ``Gemma2Model.forward`` returns: hidden states, or an
            ``IntermediateTensors`` on non-final pipeline ranks.
        """
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the token embedding."""
        return self.logits_processor(self.model.embed_tokens, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names reported as loaded by the loader.
        """
        # With tied word embeddings, checkpoint entries under "lm_head." are
        # skipped.
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skipped,
        )
        return loader.load_weights(weights)