gemma.py 14.7 KB
Newer Older
Xiang Xu's avatar
Xiang Xu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
17
from functools import lru_cache
18
from typing import Iterable, List, Optional, Tuple
Xiang Xu's avatar
Xiang Xu committed
19
20
21
22
23

import torch
from torch import nn
from transformers import GemmaConfig

24
from vllm.attention import Attention, AttentionMetadata
25
from vllm.config import CacheConfig, LoRAConfig
26
from vllm.distributed import get_tensor_model_parallel_world_size
27
from vllm.logger import init_logger
28
from vllm.model_executor.layers.activation import GeluAndMul
29
from vllm.model_executor.layers.layernorm import RMSNorm
30
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
Xiang Xu's avatar
Xiang Xu committed
31
32
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
35
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
36
from vllm.model_executor.layers.rotary_embedding import get_rope
Xiang Xu's avatar
Xiang Xu committed
37
38
39
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Xiang Xu's avatar
Xiang Xu committed
41
42
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Module-level logger; `_get_gemma_act_fn` uses it to warn when a config
# only sets the legacy `hidden_act` field.
logger = init_logger(__name__)


@lru_cache(maxsize=None)
def _get_gemma_act_fn(
    hidden_act: Optional[str],
    hidden_activation: Optional[str],
) -> nn.Module:
    if hidden_activation is None:
        if hidden_act is not None:
            logger.warning(
                "Gemma's activation function was incorrectly set to exact GeLU "
                "in the config JSON file when it was initially released. "
                "Changing the activation function to approximate GeLU "
                "(`gelu_pytorch_tanh`). If you want to use the legacy "
59
60
                "`%s`, edit the config JSON to set "
                "`hidden_activation=%s` instead of `hidden_act`. "
61
                "See https://github.com/huggingface/transformers/pull/29402 "
62
                "for more details.", hidden_act, hidden_act)
63
64
65
66
67
68
69
70
71
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu_pytorch_tanh":
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu":
        return GeluAndMul(approximate="none")
    else:
        raise ValueError(f"Activation function {hidden_act} is not "
                         "supported for Gemma models.")

Xiang Xu's avatar
Xiang Xu committed
72
73
74
75
76
77
78

class GemmaMLP(nn.Module):
    """Gemma feed-forward block.

    Applies a merged gate/up projection, the gated activation chosen by
    `_get_gemma_act_fn`, and a down projection back to `hidden_size`.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: Optional[str] = None,
        hidden_activation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # The gate and up projections are merged into one layer with two
        # output partitions of `intermediate_size` each.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        """Run gate/up projection, activation and down projection on `x`."""
        # Both linear layers return a pair; only the first item is used.
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
Xiang Xu's avatar
Xiang Xu committed
99
100
101
102
103
104
105
106
107
108
109


class GemmaAttention(nn.Module):
    """Gemma self-attention with a fused QKV projection and rotary embedding.

    Head counts are divided by the tensor-parallel world size; the
    per-rank counts are stored in `num_heads` / `num_kv_heads`.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        kv_heads_are_sharded = self.total_num_kv_heads >= world_size
        if kv_heads_are_sharded:
            # At least one KV head per rank: the KV heads are partitioned
            # across the tensor-parallel ranks.
            assert self.total_num_kv_heads % world_size == 0
        else:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across the tensor-parallel ranks.
            assert world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = head_dim
        # Per-rank widths of the query and key/value slices of the fused
        # QKV output; used as split sizes in `forward`.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(context)
        return output


class GemmaDecoderLayer(nn.Module):
    """One Gemma decoder layer: norm, self-attention, norm, MLP.

    The residual stream is passed through the `RMSNorm` calls together
    with the hidden states and returned alongside the MLP output.
    """

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        # `hidden_activation` may be absent from the config object, in
        # which case None is passed on.
        hidden_activation = getattr(config, "hidden_activation", None)
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=hidden_activation,
            quant_config=quant_config,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return `(hidden_states, residual)`."""
        # Self-attention sub-block.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Feed-forward sub-block.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        mlp_output = self.mlp(hidden_states)
        return mlp_output, residual
Xiang Xu's avatar
Xiang Xu committed
236
237
238
239
240
241
242


class GemmaModel(nn.Module):
    """Gemma backbone: scaled token embedding, decoder layers, final norm."""

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            GemmaDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Embeddings are scaled by sqrt(hidden_size).
        # The scale is kept as a buffer so that its data type is downcast
        # to the model's data type (e.g. bfloat16) instead of staying
        # float32.
        # See https://github.com/huggingface/transformers/pull/29402
        embed_scale = torch.tensor(self.config.hidden_size**0.5)
        self.register_buffer("normalizer", embed_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed `input_ids`, run every decoder layer, apply the final norm."""
        hidden_states = self.embed_tokens(input_ids)
        hidden_states *= self.normalizer

        # The first layer receives no residual and creates it.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class GemmaForCausalLM(nn.Module):
    """Gemma causal language model: backbone, logits, sampling, weight loading.

    Logits are computed against `self.model.embed_tokens.weight`; this
    module defines no separate `lm_head` parameter.
    """

    # Fused parameter name -> names of the checkpoint sub-modules it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = GemmaModel(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states (not logits) for the inputs."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states using the embedding weights."""
        logits = self.logits_processor(self.model.embed_tokens.weight,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on `logits` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint entries for `q_proj`/`k_proj`/`v_proj` and
        `gate_proj`/`up_proj` are routed into the fused `qkv_proj` and
        `gate_up_proj` parameters. Tensors passed in are never modified.

        Args:
            weights: Iterable of `(checkpoint_name, tensor)` pairs.

        Raises:
            KeyError: If a non-bias checkpoint name (other than
                `lm_head.weight`) matches no parameter of this module.
            RuntimeError: If some parameter of this module received no
                checkpoint tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # `name` keeps its rewritten value, so when no entry reaches
                # `break` the same bias check in the `else` branch below
                # skips this tensor.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                if "norm.weight" in name:
                    # Out-of-place add: leaves the caller's checkpoint
                    # tensor untouched, so a tensor that is re-used or
                    # read-only is neither offset twice nor written to.
                    loaded_weight = loaded_weight + 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")