# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
35
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
36
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.activation import SiluAndMul
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
52
53
    ParallelLMHead,
    VocabParallelEmbedding,
)
54
from vllm.model_executor.model_loader.weight_utils import (
55
56
57
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
58
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
59

60
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
61
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
70

Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a Llama decoder layer.

    The input goes through a merged gate/up projection (two outputs of
    ``intermediate_size`` each), the ``SiluAndMul`` activation, and a
    down projection back to ``hidden_size``.

    Args:
        hidden_size: Input size of ``gate_up_proj`` and output size of
            ``down_proj``.
        intermediate_size: Size of each of the two merged projections and
            input size of ``down_proj``.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both
            linear layers.
        bias: Whether both linear layers carry a bias.
        prefix: Dotted name prefix used to name the sub-layers.
        reduce_results: Forwarded to ``down_proj``.
        disable_tp: Forwarded to both linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Validate up front so that an unsupported activation fails before
        # any projection layer (and its weights) is constructed.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``.

        Both linear layers return a 2-tuple; only the first element is used.
        """
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention, and output projection.

    Args:
        config: HF config; read for the optional ``head_dim``,
            ``partial_rotary_factor``, ``llama_4_scaling``, ``layer_types``,
            ``target_layer_count``, ``sliding_window``, ``rope_parameters``
            and ``model_type`` attributes.
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before TP partitioning).
        num_kv_heads: Total number of key/value heads (before TP
            partitioning).
        max_position_embeddings: Maximum position passed to the rotary
            embedding.
        quant_config: Optional quantization config.
        bias: Whether ``qkv_proj`` has a bias.
        bias_o_proj: Whether ``o_proj`` has a bias.
        cache_config: Optional cache config forwarded to the attention layer.
        prefix: Dotted module prefix; the layer index is extracted from it.
        attn_type: Attention type; ``ENCODER_ONLY`` selects
            ``EncoderOnlyAttention``, anything else selects ``Attention``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Optional position-dependent query scaling, enabled only when the
        # config carries a `llama_4_scaling` mapping.
        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Check the lower bound too: a negative index would otherwise
            # wrap around silently and pick the wrong layer type.
            assert 0 <= effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the Llama4 query scale for ``positions``.

        Computes ``1 + beta * log(1 + floor(positions / original_max))`` and
        appends a trailing singleton dimension to the result.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q / k / v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back so the scaled query keeps q's original dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotation is used unless the quantization config is named
        ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
271
272

class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

    Args:
        vllm_config: Engine config; supplies the cache config, the
            quantization config (via ``get_quant_config``) and, when
            ``config`` is not given, the HF model config.
        prefix: Dotted module prefix for the sub-layers.
        config: Optional HF config overriding the one in ``vllm_config``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__()

        hf_config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = hf_config.hidden_size
        max_positions = getattr(hf_config, "max_position_embeddings", 8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        o_proj_bias = getattr(hf_config, "attention_bias", False) or getattr(
            hf_config, "bias", False
        )
        # support internlm/internlm3-8b with qkv_bias: when present it
        # decides the QKV bias, while o_proj keeps the value computed above.
        qkv_bias = hf_config.qkv_bias if hasattr(hf_config, "qkv_bias") else o_proj_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (
            AttentionType.DECODER
            if getattr(hf_config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        self.self_attn = LlamaAttention(
            config=hf_config,
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=getattr(
                hf_config, "num_key_value_heads", hf_config.num_attention_heads
            ),
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            hidden_act=hf_config.hidden_act,
            quant_config=quant_config,
            bias=getattr(hf_config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` becomes
        the residual; otherwise the input norm is called with both tensors
        and returns the updated pair.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
357

358
359
360
361
362
363
364
365
366
367
368
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
    """Llama transformer stack: token embedding, decoder layers, final norm.

    Sub-modules that this pipeline-parallel rank does not own are replaced
    with ``PPMissingLayer`` placeholders.

    Args:
        vllm_config: Engine config; supplies the HF model config and the
            quantization config.
        prefix: Dotted module prefix for the sub-layers.
        layer_type: Decoder-layer class instantiated for every layer.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is built on the first PP rank, and also on the last
        # rank when word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose input should be captured in `forward`; empty
        # by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            * ``IntermediateTensors`` with ``hidden_states`` and ``residual``
              when this is not the last PP rank;
            * ``(hidden_states, aux_hidden_states)`` on the last rank when
              any auxiliary hidden state was captured;
            * the normed ``hidden_states`` otherwise.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                # `residual` is still None before the first layer on the
                # first PP rank; in that case `hidden_states` alone is the
                # value to capture (adding None would raise TypeError).
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint entries are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors are reduced to their first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping claimed this weight (the loop did not
                # `break`): load it as a regular parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
516

Woosuk Kwon's avatar
Woosuk Kwon committed
517

518
519
520
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama backbone plus LM head and logits processor.

    The LM head and logits processor are only built on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` head.
    """

    # Fused parameter name -> names of the checkpoint projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                # Tied embeddings: the head is tied to the input embedding.
                lm_head = lm_head.tie_weights(self.model.embed_tokens)
            self.lm_head = lm_head

            self.logits_processor = LogitsProcessor(
                hf_config.vocab_size,
                scale=getattr(hf_config, "logit_scale", 1.0),
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set the layer indices whose hidden states the backbone captures."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        layer_count = len(self.model.layers)
        return (2, layer_count // 2, layer_count - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone; kept as a separate method so it can be overridden."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits from ``hidden_states`` via the logits processor and LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping Mistral-format names on the fly.

        ``lm_head.`` entries are skipped when word embeddings are tied.
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        remapped = (
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )
        return loader.load_weights(remapped)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight name (and layout) to this model's.

        ``wq``/``wk`` weights, and their multi-element ``qscale_weight``
        tensors, are permuted; the name is then rewritten component by
        component through ``mistral_mapping``.

        Returns:
            ``(remapped_name, possibly_permuted_weight)``.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            attn_in = self.config.head_dim * n_heads
            per_head = w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
            return per_head.transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")
        leaf = modules[-1]

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        # `wk` is checked before `wq`; the head-count attribute is read
        # lazily, only for the branch that is actually taken.
        for proj, heads_attr in (
            ("wk", "num_key_value_heads"),
            ("wq", "num_attention_heads"),
        ):
            if proj not in modules:
                continue
            if leaf == "weight":
                loaded_weight = permute(
                    loaded_weight, getattr(self.config, heads_attr), self.config.hidden_size
                )
                break
            if leaf == "qscale_weight" and loaded_weight.numel() > 1:
                loaded_weight = permute(
                    loaded_weight, getattr(self.config, heads_attr), 1
                )
                break

        # Walk each component together with its successor (None for the
        # last one). A two-component key takes precedence over a
        # single-component key.
        for item, next_item in zip(modules, [*modules[1:], None]):
            combined_item = None if next_item is None else f"{item}.{next_item}"

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight