gemma.py 16.3 KB
Newer Older
Xiang Xu's avatar
Xiang Xu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
17
from functools import lru_cache
18
from typing import Iterable, List, Optional, Set, Tuple, Union
Xiang Xu's avatar
Xiang Xu committed
19
20
21
22
23

import torch
from torch import nn
from transformers import GemmaConfig

24
from vllm.attention import Attention, AttentionMetadata
25
from vllm.config import CacheConfig, LoRAConfig
26
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
27
from vllm.logger import init_logger
28
from vllm.model_executor.layers.activation import GeluAndMul
29
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
30
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
Xiang Xu's avatar
Xiang Xu committed
31
32
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
from vllm.model_executor.layers.quantization import QuantizationConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.rotary_embedding import get_rope
36
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
Xiang Xu's avatar
Xiang Xu committed
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Xiang Xu's avatar
Xiang Xu committed
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Xiang Xu's avatar
Xiang Xu committed
42

43
44
45
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
46

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Module-level logger used by the activation-function and unloaded-weight
# warnings emitted in this file.
logger = init_logger(__name__)


@lru_cache(maxsize=None)
def _get_gemma_act_fn(
    hidden_act: Optional[str],
    hidden_activation: Optional[str],
) -> nn.Module:
    if hidden_activation is None:
        if hidden_act is not None:
            logger.warning(
                "Gemma's activation function was incorrectly set to exact GeLU "
                "in the config JSON file when it was initially released. "
                "Changing the activation function to approximate GeLU "
                "(`gelu_pytorch_tanh`). If you want to use the legacy "
62
63
                "`%s`, edit the config JSON to set "
                "`hidden_activation=%s` instead of `hidden_act`. "
64
                "See https://github.com/huggingface/transformers/pull/29402 "
65
                "for more details.", hidden_act, hidden_act)
66
67
68
69
70
71
72
73
74
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu_pytorch_tanh":
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu":
        return GeluAndMul(approximate="none")
    else:
        raise ValueError(f"Activation function {hidden_act} is not "
                         "supported for Gemma models.")

Xiang Xu's avatar
Xiang Xu committed
75
76
77
78
79
80
81

class GemmaMLP(nn.Module):
    """Feed-forward block of a Gemma decoder layer.

    A merged gate/up column-parallel projection feeds a GeLU-and-multiply
    activation (see ``_get_gemma_act_fn``). A row-parallel down projection
    then maps the result back to ``hidden_size``. None of the projections
    use a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: Optional[str] = None,
        hidden_activation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Gate and up projections share one fused matmul; each of the two
        # output slices is ``intermediate_size`` wide.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        """Apply the fused gate/up projection, activation, then down projection."""
        # Both parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        out, _ = self.down_proj(activated)
        return out
Xiang Xu's avatar
Xiang Xu committed
102
103
104
105
106
107
108
109
110
111
112


class GemmaAttention(nn.Module):
    """Self-attention with optional grouped KV heads, as used by Gemma.

    Query heads are partitioned across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, and replicated
    otherwise, in which case each rank holds a single KV head. Rotary
    embeddings (neox style, over the full head dimension) are applied to
    queries and keys before attention.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        world = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Query heads must split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % world == 0
        self.num_heads = num_heads // world

        # With fewer KV heads than ranks the heads are replicated, so the
        # rank count must be a multiple of the KV head count. Otherwise
        # they are partitioned and must divide evenly.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world:
            assert world % num_kv_heads == 0
        else:
            assert num_kv_heads % world == 0
        self.num_kv_heads = max(1, num_kv_heads // world)

        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, rotate Q and K, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        out, _ = self.o_proj(context)
        return out


class GemmaDecoderLayer(nn.Module):
    """One Gemma transformer block.

    Consists of pre-norm self-attention followed by a pre-norm MLP, with the
    residual stream threaded through the two RMSNorm layers.
    """

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.self_attn = GemmaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        # ``hidden_activation`` may be absent from the config, hence the
        # ``getattr`` with a ``None`` default.
        self.mlp = GemmaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=getattr(config, "hidden_activation", None),
            quant_config=quant_config,
        )
        eps = config.rms_norm_eps
        self.input_layernorm = GemmaRMSNorm(width, eps=eps)
        self.post_attention_layernorm = GemmaRMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the attention and MLP sub-blocks.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            stream to hand to the next layer (or to the final norm).
        """
        if residual is not None:
            # Called with a residual, the norm returns a (normed, residual)
            # pair. It presumably fuses the residual add; see GemmaRMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet (first layer on the first pipeline rank): the
            # un-normalized input starts the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
Xiang Xu's avatar
Xiang Xu committed
240
241
242
243
244
245
246


class GemmaModel(nn.Module):
    """Gemma transformer backbone.

    Consists of scaled token embeddings, the decoder layers run on this
    pipeline-parallel rank, and a final RMSNorm.
    """

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # ``make_layers`` returns the index range [start_layer, end_layer)
        # of layers executed on this rank (see ``forward``) together with
        # the layer container.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GemmaDecoderLayer(
                config, cache_config, quant_config),
            prefix=f"{prefix}.layers")
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up (unscaled) token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the backbone.

        On the first pipeline rank, the input is ``inputs_embeds`` when
        given, otherwise the embeddings of ``input_ids``. Either way it is
        scaled by ``sqrt(hidden_size)``. The caller's ``inputs_embeds``
        tensor is never modified. On later ranks, ``hidden_states`` and
        ``residual`` are read from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` on every rank except the last. The last
            rank returns the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                # Out of place: ``inputs_embeds`` belongs to the caller, and
                # scaling it in place would silently alter their tensor.
                hidden_states = inputs_embeds * self.normalizer
            else:
                # Embeddings computed by this call have no outside owner, so
                # they are scaled in place.
                hidden_states = self.get_input_embeddings(input_ids)
                hidden_states *= self.normalizer
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # ``kv_caches`` holds only this rank's layers.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


315
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Gemma causal language model.

    Wraps ``GemmaModel`` with a logits processor and sampler. The output
    projection reuses ``embed_tokens`` (tied word embeddings).
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        # currently all existing Gemma models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = GemmaModel(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        The backbone returns final hidden states on the last pipeline rank
        and ``IntermediateTensors`` on the other ranks.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits using the tied ``embed_tokens`` as the LM head."""
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/
        ``up_proj`` checkpoint tensors are loaded as shards of the fused
        ``qkv_proj`` and ``gate_up_proj`` parameters. The following are
        skipped: ``lm_head.weight`` (the LM head is tied to
        ``embed_tokens``), bias tensors with no matching parameter (extra
        GPTQ biases), and names that ``is_pp_missing_parameter`` reports as
        absent on this pipeline rank. A warning lists every parameter that
        received no checkpoint tensor.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Rewrite the name at most once. The fused names themselves
            # contain shard names ("qkv_proj" contains "v_proj",
            # "gate_up_proj" contains "up_proj"), so rescanning a rewritten
            # name would mangle it. ``shard_id`` stays None for tensors that
            # are not shards of a fused parameter. Shard ids may be 0, so
            # always test with ``is None``.
            shard_id = None
            for param_name, shard_name, stacked_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, param_name)
                    shard_id = stacked_id
                    break

            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if shard_id is None and "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters always carry a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params)