"vscode:/vscode.git/clone" did not exist on "9869453c42b8295a84a5a4513b6b3683dde110b7"
gemma.py 12.9 KB
Newer Older
Xiang Xu's avatar
Xiang Xu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GemmaConfig

23
from vllm.attention import Attention, AttentionMetadata
24
from vllm.config import LoRAConfig
25
from vllm.model_executor.layers.activation import GeluAndMul
26
from vllm.model_executor.layers.layernorm import RMSNorm
27
28
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
Xiang Xu's avatar
Xiang Xu committed
29
30
31
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Xiang Xu's avatar
Xiang Xu committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput


class GemmaMLP(nn.Module):
    """Gated feed-forward sublayer used by every Gemma decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` (two output slices of
    ``intermediate_size`` each). Its output goes through ``GeluAndMul``, and
    ``down_proj`` maps the result back to ``hidden_size``.

    Args:
        hidden_size: Width of the residual stream (input and output size).
        intermediate_size: Width of each of the gate / up projections.
        linear_method: Optional linear/quantization backend forwarded to the
            parallel linear layers.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # Fused [gate_proj; up_proj]: one column-parallel matmul whose output
        # holds both halves consumed by the gated activation below.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        # NOTE(review): GeluAndMul() is built with its default GELU variant.
        # Gemma reference implementations use the tanh approximation, so
        # confirm which variant GeluAndMul defaults to.
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, gated GELU, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class GemmaAttention(nn.Module):
    """Multi-head / grouped-query self-attention with rotary embeddings.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned across ranks when there are at least as many KV heads as
    ranks. Otherwise they are replicated, and each rank holds exactly one KV
    head.

    Args:
        hidden_size: Width of the residual stream.
        num_heads: Total number of query heads (across all TP ranks).
        num_kv_heads: Total number of key/value heads (across all TP ranks).
        head_dim: Per-head dimension. It is taken explicitly rather than
            derived from ``hidden_size // num_heads``.
        max_position_embeddings: Maximum position supported by the rotary
            embedding.
        rope_theta: Rotary embedding base.
        linear_method: Optional linear/quantization backend forwarded to the
            QKV and output projections.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Floor at 1: in the replicated case the integer division yields 0.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project back.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Normalized input activations.
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The output of ``o_proj`` (width ``hidden_size``).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection's last dim is laid out as [q | k | v].
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class GemmaDecoderLayer(nn.Module):
    """One Gemma transformer block: pre-norm attention, then pre-norm MLP.

    The residual stream is threaded through explicitly. ``forward`` returns
    the sublayer output and the residual separately rather than summing
    them, and the next RMSNorm call receives both.

    Args:
        config: HuggingFace ``GemmaConfig`` supplying sizes, head counts,
            RoPE parameters and the RMSNorm epsilon.
        linear_method: Optional linear/quantization backend forwarded to the
            attention and MLP projections.
    """

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the attention and MLP sublayers.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Output of the previous layer, or the scaled
                embeddings for the first layer.
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Attention metadata for the current batch.
            residual: Running residual. ``None`` for the first layer, in
                which case ``hidden_states`` itself becomes the residual.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual,
            returned separately rather than summed.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, RMSNorm returns a (normed, residual)
            # pair. Presumably it fuses the residual add; see RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class GemmaModel(nn.Module):
    """Gemma decoder stack: token embedding, decoder layers, final RMSNorm.

    Args:
        config: HuggingFace ``GemmaConfig`` describing the model.
        linear_method: Optional linear/quantization backend forwarded to
            every decoder layer.
    """

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Embeddings are scaled by sqrt(hidden_size). The reference Gemma
        # implementation materializes this factor in the activation dtype
        # (e.g. bfloat16), so the value actually applied is the *rounded*
        # constant, not the exact float. See
        # https://github.com/huggingface/transformers/pull/29402.
        # Registered as a non-persistent buffer so it stays out of
        # checkpoints/state_dict and out of named_parameters().
        self.register_buffer("normalizer",
                             torch.tensor(config.hidden_size**0.5),
                             persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through all decoder layers.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions consumed by the rotary embeddings.
            kv_caches: One KV cache tensor per decoder layer, indexed by
                layer number.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        # Normalize the embedding by sqrt(hidden_size). The factor is
        # rounded to the activation dtype first; a Python float would be
        # applied at full precision and diverge from the reference model.
        hidden_states *= self.normalizer.to(dtype=hidden_states.dtype)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        # The last layer returns its output and residual separately; the
        # final norm receives both.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class GemmaForCausalLM(nn.Module):
    """Gemma causal language model for vLLM.

    Wraps ``GemmaModel``. The output projection is tied to the input
    embedding: logits are computed from ``model.embed_tokens.weight``, and
    any ``lm_head.weight`` found in a checkpoint is skipped during loading.
    """

    # Fused vLLM modules -> the checkpoint sub-modules packed into each.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states (not logits) for ``input_ids``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits.

        The tied input-embedding matrix is used as the output projection.
        """
        logits = self.logits_processor(self.model.embed_tokens.weight,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace checkpoint tensors into this model's parameters.

        ``q/k/v_proj`` and ``gate/up_proj`` tensors are written as shards
        of the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Other
        tensors are loaded by name, with these exceptions:

        * ``lm_head.weight`` is skipped (the head is tied to the embedding).
        * Bias tensors with no matching parameter are skipped (GPTQ).
        * ``1.0`` is added to RMSNorm weights (see the comment below).

        Args:
            model_name_or_path: Checkpoint name or local path.
            cache_dir: Optional download cache directory.
            load_format: Checkpoint format passed to the weight iterator.
            revision: Optional model revision.

        Raises:
            RuntimeError: If any model parameter was not loaded from the
                checkpoint.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models. This `continue`
                # lets the mapping loop run to exhaustion, so the `else`
                # branch below runs and its identical check skips the tensor.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                # Done out of place so the tensor handed out by the weight
                # iterator (possibly shared or memory-mapped) is not mutated.
                if "norm.weight" in name:
                    loaded_weight = loaded_weight + 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")