gemma.py 15.1 KB
Newer Older
Xiang Xu's avatar
Xiang Xu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
17
from functools import lru_cache
18
from typing import Iterable, List, Optional, Set, Tuple
Xiang Xu's avatar
Xiang Xu committed
19
20
21
22
23

import torch
from torch import nn
from transformers import GemmaConfig

24
from vllm.attention import Attention, AttentionMetadata
25
from vllm.config import CacheConfig, LoRAConfig
26
from vllm.distributed import get_tensor_model_parallel_world_size
27
from vllm.logger import init_logger
28
from vllm.model_executor.layers.activation import GeluAndMul
29
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
30
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
Xiang Xu's avatar
Xiang Xu committed
31
32
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
35
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
36
from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
Xiang Xu's avatar
Xiang Xu committed
37
38
39
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Xiang Xu's avatar
Xiang Xu committed
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors, SamplerOutput
Xiang Xu's avatar
Xiang Xu committed
43

44
45
from .interfaces import SupportsLoRA

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Module-level logger, created through vLLM's `init_logger` with this
# module's name.
logger = init_logger(__name__)


@lru_cache(maxsize=None)
def _get_gemma_act_fn(
    hidden_act: Optional[str],
    hidden_activation: Optional[str],
) -> nn.Module:
    if hidden_activation is None:
        if hidden_act is not None:
            logger.warning(
                "Gemma's activation function was incorrectly set to exact GeLU "
                "in the config JSON file when it was initially released. "
                "Changing the activation function to approximate GeLU "
                "(`gelu_pytorch_tanh`). If you want to use the legacy "
61
62
                "`%s`, edit the config JSON to set "
                "`hidden_activation=%s` instead of `hidden_act`. "
63
                "See https://github.com/huggingface/transformers/pull/29402 "
64
                "for more details.", hidden_act, hidden_act)
65
66
67
68
69
70
71
72
73
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu_pytorch_tanh":
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu":
        return GeluAndMul(approximate="none")
    else:
        raise ValueError(f"Activation function {hidden_act} is not "
                         "supported for Gemma models.")

Xiang Xu's avatar
Xiang Xu committed
74
75
76
77
78
79
80

class GemmaMLP(nn.Module):
    """Gemma feed-forward block.

    Applies a merged gate/up projection, the activation chosen by
    `_get_gemma_act_fn`, and a down projection back to `hidden_size`.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: Optional[str] = None,
        hidden_activation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # The merged projection has two outputs of `intermediate_size` each.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        """Run the MLP on `x` and return the down-projected result."""
        # Both linear layers return a pair; only the first element is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
Xiang Xu's avatar
Xiang Xu committed
101
102
103
104
105
106
107
108
109
110
111


class GemmaAttention(nn.Module):
    """Gemma self-attention block.

    Chains a fused QKV projection, rotary position embedding on the query
    and key, the vLLM `Attention` layer, and an output projection.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        tp_world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.rope_theta = rope_theta

        # Query heads are split evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_world_size:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor-parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # TODO(woosuk): Use the `get_rope` interface.
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position_embeddings=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
            dtype=torch.get_default_dtype(),
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to `hidden_states` and return the projection."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Slice the fused projection along the last dim into q, k and v.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class GemmaDecoderLayer(nn.Module):
    """One Gemma decoder layer: norm, self-attention, norm, MLP.

    The layer threads a `residual` tensor alongside `hidden_states`; both are
    returned so the next layer (or the final norm) can consume them.
    """

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            # Not every config carries `hidden_activation`; fall back to None.
            hidden_activation=getattr(config, "hidden_activation", None),
            quant_config=quant_config,
        )
        self.input_layernorm = GemmaRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return the `(hidden_states, residual)` pair."""
        # Self Attention
        if residual is not None:
            # With a residual, the norm returns both the normalized states
            # and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual
Xiang Xu's avatar
Xiang Xu committed
241
242
243
244
245
246
247


class GemmaModel(nn.Module):
    """Gemma transformer backbone: token embedding, decoder layers, final norm."""

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size).
        # The normalizer's data type should be downcast to the model's
        # data type such as bfloat16, not float32. It is registered as a
        # buffer so it follows the module's dtype/device conversions.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (unscaled) token embeddings for `input_ids`."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return the final normalized hidden states.

        Args:
            input_ids: Token ids; only used when `inputs_embeds` is None.
            positions: Positions forwarded to every decoder layer.
            kv_caches: One KV cache per decoder layer, indexed by layer.
            attn_metadata: Attention metadata forwarded to every layer.
            intermediate_tensors: Accepted for interface compatibility; not
                used by this method.
            inputs_embeds: Optional precomputed embeddings used instead of
                looking up `input_ids`. This tensor is not modified.

        Returns:
            The hidden states after the last decoder layer and final norm.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        # Scale out of place: an in-place `*=` would silently rescale a
        # caller-supplied `inputs_embeds` tensor.
        hidden_states = hidden_states * self.normalizer
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


302
class GemmaForCausalLM(nn.Module, SupportsLoRA):
    """Gemma causal language model: backbone plus logits processing,
    sampling, and checkpoint weight loading."""

    # Fused module name -> names of the per-shard modules it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = GemmaModel(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its hidden states.

        `intermediate_tensors` is accepted but not forwarded to the backbone.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from `hidden_states` using the embedding layer."""
        # The embedding module is handed to the logits processor; there is
        # no separate lm_head on this class.
        return self.logits_processor(self.model.embed_tokens, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits` with the model's sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Per-shard checkpoint names (q/k/v and gate/up projections) are
        rewritten to the fused parameter name and loaded with their shard id.
        All other tensors are loaded by name. Parameters that received no
        tensor are reported in a warning at the end.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, tensor in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # From here on `name` is the fused parameter's name.
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                fused_param = params_dict[name]
                fused_param.weight_loader(fused_param, tensor, shard_id)
                break
            else:
                # Reached when no shard mapping loaded the tensor.
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                plain_param = params_dict[name]
                load_fn = getattr(plain_param, "weight_loader",
                                  default_weight_loader)
                load_fn(plain_param, tensor)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                unloaded_params)