llama.py 25.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26
27
from collections.abc import Iterable
from typing import Any, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
28
29
30
31
32

import torch
from torch import nn
from transformers import LlamaConfig

33
from vllm.attention import Attention, AttentionType
34
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.activation import SiluAndMul
39
from vllm.model_executor.layers.layernorm import RMSNorm
40
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
41
42
                                               QKVParallelLinear,
                                               RowParallelLinear)
43
from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
from vllm.model_executor.layers.quantization import QuantizationConfig
45
from vllm.model_executor.layers.rotary_embedding import get_rope
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import (
49
    default_weight_loader, maybe_remap_kv_scale_name)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
52

53
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
54
55
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
56
57
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
58

Woosuk Kwon's avatar
Woosuk Kwon committed
59
60

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a Llama decoder layer.

    The gate and up projections are fused into a single column-parallel
    linear layer with two equally sized output shards. ``SiluAndMul`` is
    applied to that fused output, and a row-parallel ``down_proj`` maps the
    result back to ``hidden_size``.

    Args:
        hidden_size: Model width (input and output size of the block).
        intermediate_size: Width of each of the gate / up shards.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.
        bias: Whether the linear layers carry a bias term.
        prefix: Parameter-name prefix used to build sub-module prefixes.
        reduce_results: Forwarded to ``down_proj`` (``RowParallelLinear``).

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Two output shards of identical width: [gate, up].
        shard_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=shard_widths,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply fused gate/up projection, activation, then down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention layer for Llama-style models (MHA or GQA).

    Q, K and V come from one fused ``QKVParallelLinear``. Query heads are
    split evenly across tensor-parallel ranks. KV heads are split when there
    are at least as many of them as ranks; otherwise each rank keeps one
    replicated KV head. Rotary embeddings are applied to Q and K before the
    attention kernel, and ``o_proj`` (row-parallel) maps the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads: always an even split over the TP group.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # KV heads: partitioned if there are enough of them, otherwise
        # replicated so every rank still has (at least) one.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo;
        # when it is missing or None, derive it from hidden_size.
        configured_head_dim = getattr(config, "head_dim", None)
        self.head_dim = (configured_head_dim
                         if configured_head_dim is not None else
                         self.hidden_size // self.total_num_heads)
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        # Only layers tagged "sliding_attention" in config.layer_types get a
        # per-layer sliding window; everything else attends globally.
        sliding_window = None
        layer_types = getattr(config, "layer_types", None)
        if layer_types and layer_types[layer_idx] == "sliding_attention":
            sliding_window = config.sliding_window

        if attn_type == AttentionType.ENCODER_ONLY:
            attn_cls = EncoderOnlyAttention
        else:
            attn_cls = Attention
        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v)
        projected, _ = self.o_proj(context)
        return projected

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Create ``self.rotary_emb`` from the attention geometry.

        GGUF-quantized checkpoints whose ``model_type`` is ``"llama"`` use
        the non-NeoX rotary style; every other case uses NeoX style.
        """
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        use_neox = not (is_gguf and config.model_type == "llama")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=use_neox,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
222
223
224

class LlamaDecoderLayer(nn.Module):
    """One pre-norm Llama transformer block.

    Order of operations: RMSNorm -> self-attention -> RMSNorm -> MLP. The
    running residual is threaded explicitly: the two-argument RMSNorm calls
    take ``(hidden_states, residual)`` and return an updated pair.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            orig_max_pos = getattr(config, "original_max_position_embeddings",
                                   None)
            if orig_max_pos:
                # NOTE: this writes into the config's own rope_scaling dict
                # (it is not copied first).
                rope_scaling["original_max_position_embeddings"] = (
                    orig_max_pos)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # abacusai/Smaug-72B-v0.1 exposes `attention_bias`; internlm/
        # internlm-7b exposes `bias`. Either one enables the attention bias.
        attention_bias = (getattr(config, "attention_bias", False)
                          or getattr(config, "bias", False))
        # o_proj keeps the value above even when internlm/internlm3-8b's
        # `qkv_bias` overrides the bias for the fused QKV projection.
        bias_o_proj = attention_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        # Llama is causal by default. Setting `is_causal=False` in the HF
        # config switches to bidirectional attention, which is used by some
        # embedding models (e.g. parasail-ai/GritLM-7B-vllm).
        attn_type = (AttentionType.DECODER if getattr(config, "is_causal",
                                                      True) else
                     AttentionType.ENCODER_ONLY)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one block and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the very first layer; the incoming
        hidden states then become the residual stream.
        """
        # Self attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
310
311


312
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama decoder stack: token embeddings, decoder layers, final norm.

    With pipeline parallelism each rank only materializes the layers in
    ``[start_layer, end_layer)``. The embedding exists on the first rank
    (and also on the last rank when word embeddings are tied), and the
    final norm only on the last rank. Parts a rank does not own are
    ``PPMissingLayer`` placeholders.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # LoRA can add extra vocabulary entries (scaled by max_loras) on top
        # of the base vocabulary.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose input stream is additionally returned from
        # forward (set via set_aux_hidden_state_layers, e.g. for EAGLE-3).
        # Empty means no auxiliary capture.
        self.aux_hidden_state_layers: tuple[int, ...] = tuple()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the (vocab-parallel) embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; read only on the first PP rank and only
                when ``inputs_embeds`` is not given.
            positions: Position ids passed to every decoder layer.
            intermediate_tensors: ``hidden_states`` / ``residual`` received
                from the previous PP rank; required on non-first ranks.
            inputs_embeds: Precomputed input embeddings (first rank only).

        Returns:
            ``IntermediateTensors`` on non-last PP ranks. On the last rank,
            the final normed hidden states, or the tuple
            ``(hidden_states, aux_hidden_states)`` when auxiliary layers
            were captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # NOTE(review): idx counts from this rank's start_layer, not the
        # global layer index; the two coincide only without PP.
        for idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                # Capture the un-normed stream entering this layer. Before
                # the first layer of the first PP rank there is no residual
                # yet (it is None), so the embeddings alone are the stream.
                aux_hidden_states.append(hidden_states if residual is None
                                         else hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this rank's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with the matching
        shard id. The following are skipped:

        * rotary caches;
        * extra GPTQ biases;
        * tensors whose parameters belong to another PP rank.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales. The owning layer may
                # live on another PP rank, in which case the parameter is not
                # in params_dict; skip it like every other branch does.
                if is_pp_missing_parameter(scale_name, self):
                    continue
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
468

Woosuk Kwon's avatar
Woosuk Kwon committed
469

470
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Llama decoder stack plus LM head for causal language modeling.

    Wraps the model built by ``_init_model`` (``LlamaModel`` by default).
    On the last PP rank it adds the ``lm_head``, optionally tied to the
    input embeddings, and a ``LogitsProcessor``. Checkpoints in Mistral
    format (``consolidated.safetensors``) are renamed and permuted on the
    fly by ``maybe_remap_mistral``.
    """
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which layer indices' input streams the model also returns."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE-3 capture points: layer 2, the middle, and n - 3."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the backbone; subclasses may override to swap the model."""
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; logits are computed separately."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        Only valid on the last PP rank: ``logits_processor`` is created
        only there.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load (possibly Mistral-format) weights; return loaded names.

        ``lm_head.*`` tensors are skipped when word embeddings are tied,
        because the head was tied to the embedding table in ``__init__``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Rename a Mistral-format weight and permute ``wq``/``wk`` weights.

        For ``wq``/``wk`` weights, the rows of each head are reordered from
        pairwise-interleaved order into two contiguous halves. The name is
        then rewritten piece by piece via ``mistral_mapping``, trying each
        dotted pair of components before the single component.

        Returns:
            ``(new_name, possibly_permuted_weight)``.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Same head_dim fallback as LlamaAttention: the config attribute
            # may be missing or None, in which case derive it.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = (self.config.hidden_size //
                            self.config.num_attention_heads)
            attn_in = head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            # Same KV-head fallback as LlamaDecoderLayer (plain MHA configs
            # may omit num_key_value_heads).
            num_kv_heads = getattr(self.config, "num_key_value_heads",
                                   self.config.num_attention_heads)
            loaded_weight = permute(loaded_weight, num_kv_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight