gemma.py 11.4 KB
Newer Older
Liangsheng Yin's avatar
Liangsheng Yin committed
1
2
3
4
5
6
7
# Adapted from:
# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import Optional, Tuple

import torch
from torch import nn
8
from transformers import PretrainedConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
9
10
11
12
13
14
15
16
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
Liangsheng Yin's avatar
Liangsheng Yin committed
19
20
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
21
from vllm.distributed import (
Liangsheng Yin's avatar
Liangsheng Yin committed
22
23
    get_tensor_model_parallel_world_size,
)
24
from sglang.srt.weight_utils import (
Liangsheng Yin's avatar
Liangsheng Yin committed
25
26
27
28
    default_weight_loader,
    hf_model_weights_iterator,
)

Liangsheng Yin's avatar
Liangsheng Yin committed
29
30
31
32
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata

Liangsheng Yin's avatar
Liangsheng Yin committed
33
34
35
36
37
38

class GemmaMLP(nn.Module):
    """Gated feed-forward sub-block of a Gemma decoder layer.

    The input is projected once into two ``intermediate_size`` halves
    (gate and up) by a fused column-parallel matmul. ``GeluAndMul``
    combines the halves, and a row-parallel matmul projects the result
    back to ``hidden_size``. Neither projection has a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # One fused weight whose output holds the gate half and the up half.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # GeluAndMul is constructed with its default settings.
        # NOTE(review): HF transformers later switched Gemma to the
        # tanh-approximate GELU (`hidden_activation`); confirm which variant
        # this default computes and that it matches the checkpoint.
        self.act_fn = GeluAndMul()

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``.

        The parallel linear layers return ``(output, bias)`` pairs. Only the
        output is used.
        """
        fused, _unused_bias = self.gate_up_proj(x)
        gated = self.act_fn(fused)
        out, _unused_bias = self.down_proj(gated)
        return out


class GemmaAttention(nn.Module):
    """Self-attention for one Gemma layer, sharded over tensor-parallel ranks.

    Query heads are split evenly across the TP group. If there are at least
    as many KV heads as ranks, the KV heads are split too. Otherwise they
    are replicated so that each rank holds at least one.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        layer_id: int = 0,
        max_position_embeddings: int = 8192,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Query heads: always an even split across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # KV heads: replicate when scarce, partition when plentiful.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world:
            assert world % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        # Per-rank sizes used to split the fused QKV output.
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        # head_dim is passed in independently of hidden_size, so the input
        # width here (heads * head_dim) need not equal hidden_size.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, and project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, input_metadata)
        projected, _ = self.o_proj(context)
        return projected


class GemmaDecoderLayer(nn.Module):
    """One pre-norm Gemma block: RMSNorm -> attention -> RMSNorm -> MLP.

    This layer never adds the residual explicitly. It relies on the
    two-argument ``RMSNorm(x, residual)`` call, which returns
    ``(normed, residual)``, to carry the residual stream forward.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.self_attn = GemmaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            layer_id=layer_id,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            quant_config=quant_config,
        )
        self.mlp = GemmaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(width, eps=eps)
        self.post_attention_layernorm = RMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        Returns ``(mlp_output, residual)``. ``residual`` is ``None`` only
        for the first layer; in that case the raw input seeds the residual
        stream.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            input_metadata=input_metadata,
        )

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class GemmaModel(nn.Module):
    """Gemma decoder stack without the logits head.

    The pipeline is: token embedding (or caller-supplied embeddings),
    scaling by sqrt(hidden_size), ``num_hidden_layers`` decoder layers, and
    a final RMSNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                GemmaDecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return the final-normalized hidden states.

        Args:
            input_ids: token ids. These are embedded only when
                ``input_embeds`` is None.
            positions: positions passed to every layer's rotary embedding.
            input_metadata: attention/batch metadata forwarded to each layer.
            input_embeds: optional precomputed embeddings that replace the
                embedding lookup. This tensor is not modified.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Gemma scales embeddings by sqrt(hidden_size). The constant is first
        # rounded to the activation dtype (e.g. bfloat16), as the reference
        # implementations do (HF transformers #29402; JAX casts with
        # `.astype(x.dtype)`). The multiply is out-of-place: an in-place `*=`
        # would silently rescale a caller-supplied `input_embeds` tensor.
        normalizer = torch.tensor(
            self.config.hidden_size**0.5, dtype=hidden_states.dtype
        )
        hidden_states = hidden_states * normalizer

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        # The two-argument norm call also folds in the last layer's residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class GemmaForCausalLM(nn.Module):
    """Gemma causal LM: ``GemmaModel`` followed by a logits processor.

    Logits are produced against ``model.embed_tokens.weight`` (tied
    embeddings); this class defines no separate output-projection parameter.
    """

    # Fused parameter name -> checkpoint shard names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # lora_config is accepted but never used by this implementation.
        del lora_config
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = GemmaModel(config, quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the logits processor's output."""
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        # Tied embeddings: the embedding matrix doubles as the output projection.
        tied_weight = self.model.embed_tokens.weight
        return self.logits_processor(
            input_ids, hidden_states, tied_weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Copy checkpoint tensors into this model's parameters.

        Behavior:
            - Separate ``q/k/v_proj`` and ``gate/up_proj`` checkpoint tensors
              are routed into the fused ``qkv_proj`` / ``gate_up_proj``
              parameters through the parameter's ``weight_loader`` with a
              shard id.
            - Bias tensors that have no matching parameter are skipped
              (extra biases in GPTQ checkpoints).
            - Any tensor whose name contains ``norm.weight`` gets ``+1.0``
              (applied in place) before loading.

        Raises:
            RuntimeError: if some model parameter received no checkpoint
                tensor.
        """
        # (fused param name, checkpoint shard name, shard id)
        fused_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        )
        params = dict(self.named_parameters())
        seen = set()

        checkpoint = hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        )
        for key, tensor in checkpoint:
            for fused_name, shard_name, shard_id in fused_shards:
                if shard_name not in key:
                    continue
                key = key.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if key.endswith(".bias") and key not in params:
                    continue
                target = params[key]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if key.endswith(".bias") and key not in params:
                    continue
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                if "norm.weight" in key:
                    tensor += 1.0
                target = params[key]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            seen.add(key)

        missing = params.keys() - seen
        if missing:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{missing}"
            )


# Module-level export naming this file's model class. Presumably the model
# loader discovers models through this name; confirm against the registry.
EntryClass = GemmaForCausalLM