llama.py 25.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
from vllm.attention import Attention, AttentionType
35
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
from vllm.model_executor.layers.rotary_embedding import get_rope
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
51
52
    ParallelLMHead,
    VocabParallelEmbedding,
)
53
from vllm.model_executor.model_loader.weight_utils import (
54
55
56
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
57
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
58

59
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
60
61
62
63
64
65
66
67
68
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
69

Woosuk Kwon's avatar
Woosuk Kwon committed
70
71
72
73

class LlamaMLP(nn.Module):
    """Gated feed-forward block used by the Llama decoder layers.

    The gate and up projections are fused into a single column-parallel
    linear layer (``gate_up_proj``) whose output is fed to ``SiluAndMul``;
    a row-parallel ``down_proj`` maps the result back to ``hidden_size``.
    Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Options common to both projections.
        shared_kwargs = {
            "bias": bias,
            "quant_config": quant_config,
            "disable_tp": disable_tp,
        }
        # One output shard of `intermediate_size` each for gate and up.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_kwargs,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run gate/up projection, SiLU-and-multiply, then down projection."""
        gate_up, _gate_up_bias = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _down_bias = self.down_proj(activated)
        return out


class LlamaAttention(nn.Module):
    """Self-attention block used by Llama-family decoder layers.

    Projects hidden states to Q/K/V with a fused tensor-parallel linear layer,
    applies rotary position embeddings (plus a position-dependent query scale
    when ``config.llama_4_scaling`` is set), runs attention and projects the
    result back with a row-parallel output layer. Query heads are split
    across tensor-parallel ranks; KV heads are split or replicated depending
    on whether there are at least as many of them as ranks.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """
        Args:
            config: HF model config. Optional fields read here: ``head_dim``,
                ``partial_rotary_factor``, ``llama_4_scaling``,
                ``layer_types`` / ``sliding_window`` and
                ``target_layer_count``.
            hidden_size: Model hidden size (input and output width).
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            max_position_embeddings: Max position given to the rotary embedding.
            quant_config: Optional quantization config for linear layers and
                attention.
            bias: Whether the fused QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: KV-cache config forwarded to the attention layer.
            prefix: Module name prefix; the layer index is extracted from it.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; any other value uses ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Also reject negative indices: they would silently wrap around
            # and select the wrong entry of `layer_types`.
            assert 0 <= effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the Llama-4 query scale for each position.

        scale = 1 + beta * log(1 + floor(position / original_max_positions)),
        with a trailing unit dimension appended so it broadcasts against the
        last dimension of ``q``.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Trailing unit dim so the per-position scale broadcasts over q's
        # last dimension.
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` at ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # The scale may not share q's dtype; cast the product back.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb``; subclasses may override this hook."""
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            # GGUF llama checkpoints use the non-NeoX rotary style here.
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
270
271

class LlamaDecoderLayer(nn.Module):
    """One pre-norm Llama transformer block: norm -> attention -> norm -> MLP.

    The RMSNorm layers are called with the running residual and return a
    ``(hidden_states, residual)`` pair, so the residual stream is threaded
    through ``forward`` rather than added explicitly.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        # `attention_bias` (e.g. abacusai/Smaug-72B-v0.1) or `bias`
        # (e.g. internlm/internlm-7b) enables bias on both the qkv and o
        # projections; `qkv_bias` (e.g. internlm/internlm3-8b), when present,
        # overrides it for the qkv projection only.
        o_proj_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        qkv_bias = getattr(config, "qkv_bias", o_proj_bias)

        # Llama is decoder-only, so attention is causal unless the HF config
        # sets `is_causal=False`, which switches to bidirectional attention
        # for embedding models (e.g. parasail-ai/GritLM-7B-vllm).
        attn_type = (
            AttentionType.DECODER
            if getattr(config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        rms_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=rms_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block; returns ``(hidden_states, residual)``.

        A ``None`` residual marks the first layer: the incoming hidden states
        become the residual and are normalized on their own.
        """
        # Self attention (pre-norm).
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        # Fully connected (pre-norm).
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Return the quantization config for this layer.

        Defaults to ``vllm_config.quant_config``; subclasses may override.
        """
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
356

357
358
359
360
361
362
363
364
365
366
367
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
    """Llama transformer trunk: token embedding, decoder layers, final norm.

    Under pipeline parallelism each rank runs only layers
    ``[start_layer, end_layer)`` of ``self.layers``. Presumably the other
    entries are placeholders from ``make_layers``; verify against
    ``make_layers``. Non-first ranks take their input from
    ``IntermediateTensors``, and non-last ranks return one instead of the
    normalized output.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is needed on the first PP rank, and also on the last
        # one when it is tied to the LM head.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Global indices of layers whose input hidden state is also returned
        # from `forward` (set through `set_aux_hidden_state_layers` on the
        # causal-LM wrapper, e.g. for EAGLE-3 drafting).
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's decoder layers.

        Returns:
            - ``IntermediateTensors`` with ``hidden_states`` and ``residual``
              on non-last PP ranks.
            - Otherwise the final normalized hidden states.
            - If any auxiliary layers were captured, the tuple
              ``(hidden_states, aux_hidden_states)`` instead.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # `aux_hidden_state_layers` holds global layer indices, so number the
        # layers from this rank's first layer rather than from zero (the two
        # only coincide without pipeline parallelism).
        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if layer_idx in self.aux_hidden_state_layers:
                # Capture the layer's input residual stream. Before the very
                # first layer there is no residual yet, so the stream is just
                # the incoming hidden states.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        How entries are handled:
        - Separate q/k/v and gate/up checkpoint tensors are routed into the
          fused ``qkv_proj`` / ``gate_up_proj`` parameters with their shard
          ids.
        - Rotary caches found in some checkpoints are skipped.
        - KV-cache quantization scales recognized by the quant config are
          loaded directly.
        - Parameters not held by this PP rank are ignored.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A 0-d scale is used as-is; otherwise only its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the two `continue`s below advance this mapping loop,
                # not the outer one. In practice no later entry matches the
                # rewritten name, so control reaches the `else` clause, which
                # repeats the same checks and skips the weight there.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
515

Woosuk Kwon's avatar
Woosuk Kwon committed
516

517
518
519
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama causal language model: a ``LlamaModel`` trunk plus an LM head.

    Declares LoRA, pipeline-parallel and EAGLE / EAGLE-3 support. It loads
    HF-format checkpoints and, through ``maybe_remap_mistral``,
    Mistral-format (consolidated) checkpoints.
    """

    # Fused parameter -> checkpoint sub-modules it is assembled from.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        # The LM head and logits processor only exist on the last PP rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layers whose inputs ``forward`` also returns."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the transformer trunk; subclasses may override this hook."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the trunk's embedding layer."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the trunk; returns whatever ``LlamaModel.forward`` returns."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights after applying any Mistral-format renaming.

        ``lm_head.*`` entries are skipped when embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate one Mistral-format checkpoint entry to this model's naming.

        ``wq`` / ``wk`` weights, and their ``qscale_weight`` scales when they
        have more than one element, are row-permuted per head. Name components
        are then renamed via ``mistral_mapping``.

        Returns:
            The possibly renamed ``(name, loaded_weight)`` pair.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Reorder each head's rows from interleaved pairs (2i, 2i + 1)
            # into two contiguous halves (presumably to match the rotary
            # layout used by LlamaAttention — confirm against get_rope).
            # Same head_dim fallback as LlamaAttention for configs without
            # an explicit (or with a None) head_dim.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = self.config.hidden_size // self.config.num_attention_heads
            attn_in = head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        # Pair each dotted component with its successor (None for the last)
        # so two-part keys such as "kv_fake_quantizer.qscale_act" take
        # precedence over single-component keys.
        for item, next_item in zip(modules, modules[1:] + [None]):
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight