llama.py 27 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
from typing import Any
Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
32
33
34

import torch
from torch import nn
from transformers import LlamaConfig

35
from vllm.attention import Attention, AttentionType
36
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.activation import SiluAndMul
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
52
53
54
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
55
from vllm.model_executor.model_loader.weight_utils import (
56
57
58
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
59
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
60

61
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
62
63
64
65
66
67
68
69
70
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
71

Woosuk Kwon's avatar
Woosuk Kwon committed
72
73
74
75

class LlamaMLP(nn.Module):
    """Feed-forward block built from a merged gate/up projection and a down projection.

    The gate and up projections share one ``MergedColumnParallelLinear`` whose
    two outputs are each ``intermediate_size`` wide; the result is passed
    through ``SiluAndMul`` and then projected back to ``hidden_size`` by a
    ``RowParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both linear layers.
            bias: Whether both linear layers carry a bias term.
            prefix: Dotted module path used to name the sub-layers.
            reduce_results: Forwarded to the down projection.
            disable_tp: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # Options that both projections receive unchanged.
        shared_kwargs = dict(
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # A single name is rebound at every step so that each intermediate
        # tensor can be released as soon as the next one has been produced.
        h, _ = self.gate_up_proj(x)
        h = self.act_fn(h)
        h, _ = self.down_proj(h)
        return h


class LlamaAttention(nn.Module):
    """Self-attention block with a fused QKV projection and rotary embeddings.

    Query heads are split evenly across the tensor-parallel group. KV heads
    are either partitioned (when there are at least as many KV heads as
    tensor-parallel ranks) or replicated (when there are fewer), so every
    rank always owns at least one KV head.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF config; read for the optional ``head_dim``,
                ``partial_rotary_factor``, ``llama_4_scaling``, ``layer_types``,
                ``target_layer_count`` and ``sliding_window`` attributes.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of KV heads (before TP sharding).
            rope_theta: Base passed to the rotary embedding.
            rope_scaling: Optional rope-scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            quant_config: Quantization config for the linear/attention layers.
            bias: Whether the QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: KV-cache config forwarded to the attention layer.
            prefix: Dotted module path; also used to extract the layer index.
            attn_type: ``AttentionType`` value; ``ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``, anything else ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Optional position-dependent query scaling; enabled only when the
        # config carries a ``llama_4_scaling`` entry.
        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(
            config, rope_scaling=rope_scaling, quant_config=quant_config
        )

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Implicit string concatenation keeps the message on one logical
            # line; a backslash continuation inside the literal would embed
            # the source indentation in the message.
            assert effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-position query scale used by Llama-4 style scaling.

        Computes ``1 + beta * log(1 + floor(positions / original_max_pos))``
        and appends a trailing singleton dimension so the result broadcasts
        against the query tensor.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is laid out as [Q | K | V] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back so the multiplication cannot change q's dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        rope_scaling: dict[str, Any] | None,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotation is used unless the model is quantized with
        ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
279
280

class LlamaDecoderLayer(nn.Module):
    """One transformer layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is carried as a separate tensor alongside the hidden
    states and is passed into the norm layers together with them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the attention, MLP and norm sub-layers.

        Args:
            vllm_config: Engine config; supplies the cache config, the
                quantization config and (when ``config`` is not given) the
                HF model config.
            prefix: Dotted module path used to name the sub-layers.
            config: Optional HF config overriding the one in ``vllm_config``.
        """
        super().__init__()

        hf_config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = hf_config.hidden_size

        # Rotary-embedding settings; the original training context length,
        # when the config provides one, is copied into the rope_scaling dict.
        rope_theta = getattr(hf_config, "rope_theta", 10000)
        rope_scaling = getattr(hf_config, "rope_scaling", None)
        original_max_pos = getattr(hf_config, "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max_pos:
            rope_scaling["original_max_position_embeddings"] = original_max_pos
        max_position_embeddings = getattr(hf_config, "max_position_embeddings", 8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        o_proj_bias = getattr(hf_config, "attention_bias", False) or getattr(
            hf_config, "bias", False
        )
        # support internlm/internlm3-8b with qkv_bias: when present it decides
        # the QKV bias only; the output projection keeps the value above.
        qkv_bias = hf_config.qkv_bias if hasattr(hf_config, "qkv_bias") else o_proj_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (
            AttentionType.DECODER
            if getattr(hf_config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        kv_head_count = getattr(
            hf_config, "num_key_value_heads", hf_config.num_attention_heads
        )
        self.self_attn = LlamaAttention(
            config=hf_config,
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=kv_head_count,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            hidden_act=hf_config.hidden_act,
            quant_config=quant_config,
            bias=getattr(hf_config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` when no residual stream exists yet; in that
        case the incoming hidden states become the residual.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Return the quantization config for this layer; subclasses may override."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
375

376
@support_torch_compile
class LlamaModel(nn.Module):
    """Token embedding, stack of decoder layers and final norm.

    Under pipeline parallelism only the pieces owned by the local rank are
    instantiated; the others are ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, the local slice of layers and the final norm.

        Args:
            vllm_config: Engine config (HF config, quantization, LoRA).
            prefix: Dotted module path used to name the sub-layers.
            layer_type: Decoder-layer class; instantiated with
                ``vllm_config`` and ``prefix`` keyword arguments.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters (0 without LoRA).
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding lives on the first rank, and also on the last rank
        # when the word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices of layers whose input is additionally returned by forward();
        # empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the locally owned layers.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: Hidden states and residual received from
                the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` on non-last ranks. On the last rank, the
            normalized hidden states, or a ``(hidden_states,
            aux_hidden_states)`` tuple when any auxiliary states were
            collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # NOTE(review): idx counts from 0 within this rank's slice of layers.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                if residual is None:
                    # No residual stream exists yet (first layer on the first
                    # rank), so the layer input is just hidden_states. Clone
                    # it so the captured tensor cannot alias a buffer that a
                    # later layer might update in place.
                    aux_hidden_states.append(hidden_states.clone())
                else:
                    aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were processed.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A non-scalar scale tensor is reduced to its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached whenever the loop above finished without `break`,
                # i.e. no shard was loaded into a fused parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
530

Woosuk Kwon's avatar
Woosuk Kwon committed
531

532
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Llama backbone plus LM head and logits processor.

    The LM head and logits processor are created on the last pipeline rank
    only; other ranks hold a ``PPMissingLayer`` in place of the head.
    """

    # Fused parameter name -> the separate checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last pipeline rank, the LM head.

        Args:
            vllm_config: Engine config (HF config, quantization, LoRA).
            prefix: Dotted module path used to name the sub-modules.
            layer_type: Decoder-layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size, logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set the layer indices whose inputs the backbone returns as aux states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Create the backbone; a hook so subclasses can swap the model class."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids`` via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged (no logits here)."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping Mistral-format names on the fly.

        ``lm_head.*`` tensors are skipped when the word embeddings are tied.

        Returns:
            The set of parameter names reported by ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format checkpoint entry to this model's naming.

        This remaps the mistral format as used by Mistral and Llama <=2.
        ``wq``/``wk`` weights (and their non-scalar ``qscale_weight`` tensors)
        are additionally permuted, and every dotted component found in
        ``mistral_mapping`` is renamed.

        Args:
            name: Checkpoint tensor name.
            loaded_weight: Checkpoint tensor.

        Returns:
            The ``(name, tensor)`` pair after remapping.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # head_dim is optional on some configs; fall back the same way
            # LlamaAttention does instead of failing on a missing/None value.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = self.config.hidden_size // self.config.num_attention_heads
            attn_in = head_dim * n_heads

            # Swap the (head_dim // 2, 2) axes within each head.
            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i, item in enumerate(modules):
            next_item = modules[i + 1] if i + 1 < num_modules else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-component keys.
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight