# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26
from collections.abc import Iterable
27
from itertools import islice
28
from typing import Any, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
from vllm.attention import Attention, AttentionType
35
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
42
43
                                               QKVParallelLinear,
                                               RowParallelLinear)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.rotary_embedding import get_rope
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import (
50
    default_weight_loader, maybe_remap_kv_scale_name)
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
55
56
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

Woosuk Kwon's avatar
Woosuk Kwon committed
60
61

class LlamaMLP(nn.Module):
    """Feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are held in one merged column-parallel
    layer (``gate_up_proj``) with two outputs of ``intermediate_size``
    each; the result is passed through ``SiluAndMul`` and then projected
    back to ``hidden_size`` by the row-parallel ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Create the projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged outputs of
                ``gate_up_proj`` and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            bias: Forwarded to both linear layers.
            prefix: Module path prefix; the layers are registered as
                ``{prefix}.gate_up_proj`` and ``{prefix}.down_proj``.
            reduce_results: Forwarded to ``down_proj``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before building any layer so that an unsupported
        # activation fails fast instead of after weights were allocated.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply ``gate_up_proj``, the activation and ``down_proj`` to ``x``.

        Both linear layers return a 2-tuple; only the first element is
        used here.
        """
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block for Llama-style decoder layers.

    ``forward`` runs a fused QKV projection, applies the rotary embedding
    to the query and key, calls the attention backend and finally applies
    the output projection.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Set up projections, rotary embedding and the attention layer.

        Args:
            config: Model config; read for the optional ``head_dim``,
                ``partial_rotary_factor``, ``layer_types`` and
                ``sliding_window`` attributes.
            hidden_size: Input size of ``qkv_proj`` and output size of
                ``o_proj``.
            num_heads: Total number of query heads over all TP ranks.
            num_kv_heads: Total number of key/value heads over all TP ranks.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to the linear and attention layers.
            bias: Bias flag for ``qkv_proj``.
            bias_o_proj: Bias flag for ``o_proj``.
            cache_config: Forwarded to the attention layer.
            prefix: Module path prefix; the layer index is extracted from
                it.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; anything else selects
                ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo;
        # fall back to the usual hidden_size / num_heads when it is unset.
        configured_head_dim = getattr(config, "head_dim", None)
        if configured_head_dim is None:
            self.head_dim = self.hidden_size // self.total_num_heads
        else:
            self.head_dim = configured_head_dim

        # Phi models introduced a partial_rotary_factor parameter in the
        # config.
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        # A sliding window is only used when the config lists per-layer
        # types and this layer is marked "sliding_attention".
        layer_types = getattr(config, "layer_types", None)
        if layer_types and layer_types[layer_idx] == "sliding_attention":
            sliding_window = config.sliding_window
        else:
            sliding_window = None

        if attn_type == AttentionType.ENCODER_ONLY:
            attn_cls = EncoderOnlyAttention
        else:
            attn_cls = Attention

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``.

        Returns the first element of the ``o_proj`` result.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Create ``self.rotary_emb`` through ``get_rope``.

        ``is_neox_style`` is True except when the quantization config
        reports the name "gguf" and ``config.model_type`` is "llama".
        """
        uses_gguf = quant_config and quant_config.get_name() == "gguf"
        is_neox_style = not (uses_gguf and config.model_type == "llama")

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
223
224
225

class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, MLP.

    The residual stream is threaded through the two norm layers, which are
    called with ``(hidden_states, residual)`` and return the updated pair.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and norm submodules from ``config``.

        Args:
            config: Model config supplying sizes, rope settings and the
                optional bias / causality switches read below.
            cache_config: Forwarded to ``LlamaAttention``.
            quant_config: Forwarded to ``LlamaAttention`` and ``LlamaMLP``.
            prefix: Module path prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            # Copy a truthy top-level original_max_position_embeddings into
            # the rope_scaling dict (this mutates the config's dict).
            original_max_pos = getattr(config,
                                       "original_max_position_embeddings",
                                       None)
            if original_max_pos:
                rope_scaling["original_max_position_embeddings"] = (
                    original_max_pos)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        o_proj_bias = (getattr(config, "attention_bias", False)
                       or getattr(config, "bias", False))
        # support internlm/internlm3-8b with qkv_bias: when present it
        # overrides the bias flag of the QKV projection only.
        if hasattr(config, "qkv_bias"):
            qkv_bias = config.qkv_bias
        else:
            qkv_bias = o_proj_bias

        # By default, Llama uses causal attention as it is a decoder-only
        # model. You can override the HF config with `is_causal=False` to
        # enable bidirectional attention, which is used in some embedding
        # models (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (AttentionType.DECODER if getattr(
            config, "is_causal", True) else AttentionType.ENCODER_ONLY)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` when no residual stream exists yet; in that
        case the incoming ``hidden_states`` becomes the residual.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
311
312


313
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, decoder layers and final RMSNorm.

    The embedding and the final norm are replaced by ``PPMissingLayer`` on
    pipeline-parallel ranks that do not need them (see ``__init__``).
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the backbone from ``vllm_config``.

        Args:
            vllm_config: Source of the HF model config and of the cache,
                quantization and LoRA configs.
            prefix: Module path prefix; layers live under
                ``{prefix}.layers``.
            layer_type: Class instantiated for every decoder layer; called
                with ``config``, ``cache_config``, ``quant_config`` and
                ``prefix`` keyword arguments.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters (0 without LoRA).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding is built on the first rank, and also on the last
        # rank when word embeddings are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose inputs are captured in forward(); empty by
        # default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids; only used on the first rank when
                ``inputs_embeds`` is None.
            positions: Passed to every decoder layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` with "hidden_states" and "residual" on
            every rank but the last. On the last rank, the normed hidden
            states, or ``(hidden_states, aux_hidden_states)`` when any
            auxiliary hidden state was captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
                islice(self.layers, self.start_layer, self.end_layer)):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer of the first rank there is no
                # residual yet; the layer input is hidden_states alone.
                aux_hidden_states.append(hidden_states if residual is
                                         None else hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Per-projection checkpoint weights (q/k/v and gate/up) are routed
        into the fused ``qkv_proj`` / ``gate_up_proj`` parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The names of the parameters that were handled.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Use the tensor itself when 0-dim, else its first element.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the two `continue`s below only advance this inner
                # loop. The renamed weight matches no further mapping
                # entry, so the loop finishes without `break` and the
                # `else` branch repeats the same check, where `continue`
                # moves on to the next checkpoint tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a shard that was loaded into a fused parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
469

Woosuk Kwon's avatar
Woosuk Kwon committed
470

471
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Llama backbone plus LM head and logits processor.

    The LM head and the logits processor are only created on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` as
    ``lm_head``.
    """

    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Source of the HF model config and of the
                quantization and LoRA configs.
            prefix: Module path prefix of this model.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set the layer indices whose inputs the backbone captures."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return ``(2, num_layers // 2, num_layers - 3)``."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Create the ``LlamaModel`` backbone."""
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return ``logits_processor(lm_head, hidden_states, metadata)``.

        ``logits_processor`` only exists on the last pipeline rank.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``. Weights under ``lm_head.`` are skipped
        when the config ties word embeddings.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight name (and layout) to HF format.

        Name components found in ``mistral_mapping`` are replaced, and
        ``wq`` / ``wk`` weight matrices are permuted.

        Args:
            name: Dotted checkpoint weight name.
            loaded_weight: The corresponding tensor.

        Returns:
            The remapped name and the (possibly permuted) tensor.
        """
        config = self.config

        def permute(w: torch.Tensor, n_heads: int):
            # Same fallback as LlamaAttention: configs without a head_dim
            # use hidden_size // num_attention_heads.
            head_dim = getattr(config, "head_dim", None)
            if head_dim is None:
                head_dim = config.hidden_size // config.num_attention_heads
            attn_in = head_dim * n_heads
            attn_out = config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            # Same fallback as LlamaDecoderLayer when the config has no
            # separate KV head count.
            num_kv_heads = getattr(config, "num_key_value_heads",
                                   config.num_attention_heads)
            loaded_weight = permute(loaded_weight, num_kv_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    config.num_attention_heads)

        # Walk the components pairwise so two-component keys such as
        # "kv_fake_quantizer.qscale_act" take precedence over
        # single-component ones; the last component has no successor.
        for item, next_item in zip(modules, modules[1:] + [None]):
            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight