llama.py 26.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
from vllm.attention import Attention, AttentionType
35
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
from vllm.model_executor.layers.rotary_embedding import get_rope
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
51
52
53
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
54
from vllm.model_executor.model_loader.weight_utils import (
55
56
57
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
58
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
59

60
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
61
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
70

Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74

class LlamaMLP(nn.Module):
    """Gated SiLU feed-forward block of a Llama decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` that produces two ``intermediate_size``
    outputs. ``SiluAndMul`` combines them, and ``down_proj`` maps the result
    back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Build the fused gate/up projection, the activation and down_proj.

        Args:
            hidden_size: Model (input and output) width.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            bias: Whether both linear layers carry a bias.
            prefix: Parameter-name prefix for the sub-modules.
            reduce_results: Forwarded to ``down_proj``.
            disable_tp: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        # NOTE(review): SiluAndMul presumably computes silu(first half) *
        # second half of the fused gate/up output — confirm in its definition.
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run gate/up projection, gated activation, then down projection."""
        # The parallel linear layers return a 2-tuple; only the first
        # element (the projected activations) is used here.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block of a Llama decoder layer.

    Q, K and V come from one fused ``QKVParallelLinear``. Rotary embeddings
    are applied to Q and K, and optional Llama 4 position-dependent scaling is
    applied to Q. The attention output is projected back to ``hidden_size``
    by a ``RowParallelLinear``. Query heads are split evenly across
    tensor-parallel ranks. KV heads are split when there are at least as many
    as ranks and replicated otherwise.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF model config. Read for ``head_dim``,
                ``partial_rotary_factor``, ``llama_4_scaling``,
                ``layer_types``, ``target_layer_count``, ``sliding_window``
                and the rotary settings.
            hidden_size: Model width.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            max_position_embeddings: Maximum position given to the rotary
                embedding.
            quant_config: Optional quantization config.
            bias: Bias for the fused QKV projection.
            bias_o_proj: Bias for the output projection.
            cache_config: KV-cache config forwarded to the attention layer.
            prefix: Module prefix. Its layer index selects the per-layer
                attention type from ``config.layer_types``.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; any other value uses ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Optional Llama 4 position-dependent query scaling; the config entry
        # is read as a mapping with "original_max_position_embeddings" and
        # "beta" keys.
        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Implicit f-string concatenation keeps the message on one line.
            # A backslash continuation inside the literal would embed the
            # next line's indentation in the message.
            assert effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the Llama 4 query scale for each position.

        Computes ``1 + beta * log(1 + floor(positions / original_max))``.
        A trailing singleton dimension is added so the result broadcasts
        against the last dimension of ``q``.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` at ``positions``.

        Args:
            positions: Token positions used by the rotary embedding and, if
                enabled, by Llama 4 query scaling.
            hidden_states: Input activations.

        Returns:
            The output-projected attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back so the scaled q keeps its original dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` from ``config.rope_parameters``.

        NeoX-style rotation is used unless the model is a GGUF-quantized
        ``llama`` model, which uses the non-NeoX style.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
271
272

class LlamaDecoderLayer(nn.Module):
    """One transformer block: pre-norm self-attention, then pre-norm MLP.

    Residual bookkeeping is delegated to ``RMSNorm``. When called with
    ``(hidden_states, residual)`` it returns the normalized activations and
    the updated residual (presumably ``hidden_states + residual`` — see
    ``RMSNorm``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the attention, MLP and the two RMSNorms of this layer.

        Args:
            vllm_config: Global vLLM config. The HF config, cache config and
                (via ``get_quant_config``) quantization config are read from it.
            prefix: Parameter-name prefix for this layer.
            config: Optional HF config overriding
                ``vllm_config.model_config.hf_config``.
        """
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # o_proj bias follows the generic flag above, before any qkv override.
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sub-blocks.

        Args:
            positions: Token positions passed to the attention block.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual, or ``None`` for the first layer.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            # First layer: the input itself becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
357

358
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama decoder stack: token embedding, decoder layers and final norm.

    Under pipeline parallelism each rank owns the layers in
    ``[start_layer, end_layer)``. ``embed_tokens`` is materialized on the
    first rank, or on the last rank when word embeddings are tied. ``norm``
    is materialized only on the last rank. Other ranks hold
    ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Global vLLM config (HF, quantization and LoRA configs).
            prefix: Parameter-name prefix for this module.
            layer_type: Decoder-layer class used for every layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra embedding rows reserved for LoRA-added vocabulary.
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Global layer indices whose inputs are also returned (Eagle3);
        # empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids. Used on the first rank when
                ``inputs_embeds`` is None.
            positions: Token positions.
            intermediate_tensors: Activations from the previous PP rank.
                Required on non-first ranks.
            inputs_embeds: Precomputed embeddings that replace the lookup
                on the first rank.

        Returns:
            ``IntermediateTensors`` on non-last ranks. On the last rank,
            returns the final-normed hidden states, or
            ``(hidden_states, aux_hidden_states)`` when any auxiliary layer
            inputs were captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # Enumerate with global layer ids so they match
        # ``aux_hidden_state_layers``. With a single pipeline stage
        # start_layer is 0, so this equals local numbering.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if idx in self.aux_hidden_state_layers:
                # Capture the input of layer ``idx``. Before the first layer of
                # the first rank there is no residual yet (``hidden_states +
                # None`` would raise), so the input is ``hidden_states``
                # alone. It is cloned because the same tensor becomes the
                # residual and may later be updated in place by a fused norm
                # (NOTE(review): confirm RMSNorm's in-place behaviour).
                aux_hidden_states.append(
                    hidden_states.clone()
                    if residual is None
                    else hidden_states + residual
                )
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        q/k/v and gate/up shards are loaded into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters. Rotary caches stored in checkpoints are
        skipped. KV-cache quantization scales are routed through the quant
        config.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: ``continue`` here ends this mapping iteration without
                # ``break``. The (already renamed) name then falls through to
                # the ``else`` clause, which repeats the same skip checks and
                # skips the weight.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
512

Woosuk Kwon's avatar
Woosuk Kwon committed
513

514
515
516
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama decoder stack plus LM head and logits processor.

    The LM head and logits processor exist only on the last pipeline rank.
    Weights can be loaded from HF checkpoints or from Mistral-format
    (``consolidated.safetensors``) checkpoints via ``maybe_remap_mistral``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Global vLLM config.
            prefix: Parameter-name prefix for this module.
            layer_type: Decoder-layer class forwarded to the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size, logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set which layer indices' inputs the backbone should also return."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Construct the backbone; subclasses may override to customize it."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits (last PP rank only)."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF or Mistral-format weights; returns loaded parameter names.

        ``lm_head.*`` tensors are skipped when word embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight name (and tensor) to HF layout.

        Q/K projection weights, and their per-row ``qscale_weight`` when it
        has more than one element, are row-permuted within each head.
        Names are then rewritten via ``mistral_mapping``.

        Args:
            name: Checkpoint tensor name.
            loaded_weight: Checkpoint tensor.

        Returns:
            ``(new_name, tensor)``; the tensor is unchanged unless permuted.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Same head_dim fallback as LlamaAttention: some configs omit
            # head_dim or set it to None.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = self.config.hidden_size // self.config.num_attention_heads
            attn_in = head_dim * n_heads

            # Within each head, reorder rows from interleaved pairs
            # (r0, r1), (r2, r3), ... into two contiguous halves.
            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # Q/K weights must be permuted into the rotary layout used here.
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be permuted
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            # Two-component keys (e.g. "k_fake_quantizer.qscale_act") take
            # precedence over single-component ones.
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight