llama.py 25.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
35
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
36
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.activation import SiluAndMul
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
52
53
    ParallelLMHead,
    VocabParallelEmbedding,
)
54
from vllm.model_executor.model_loader.weight_utils import (
55
56
57
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
58
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
59

60
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
61
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
70

Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74

class LlamaMLP(nn.Module):
    """Feed-forward block of a Llama layer.

    A merged gate/up projection feeds ``SiluAndMul``, whose result is
    projected back to ``hidden_size`` by ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Create the two projections and the activation.

        Args:
            hidden_size: Input size of the gate/up projection and output
                size of the down projection.
            intermediate_size: Size of each of the two merged gate/up
                outputs and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both projections.
            bias: Whether both projections carry a bias.
            prefix: Dotted module prefix used to name the projections.
            reduce_results: Forwarded to the down projection.
            disable_tp: Forwarded to both projections.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Options that both linear layers receive unchanged.
        common = dict(bias=bias, quant_config=quant_config, disable_tp=disable_tp)

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **common,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **common,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run gate/up projection, activation and down projection on ``x``."""
        # Both projections return a pair; only the first element is used here.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Llama self-attention block.

    Combines a fused QKV projection, rotary position embedding, an attention
    layer (``Attention`` or ``EncoderOnlyAttention``) and an output
    projection.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: Model config; read for the optional ``head_dim``,
                ``llama_4_scaling``, ``layer_types``, ``target_layer_count``,
                ``sliding_window``, ``rope_parameters`` and ``model_type``.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP sharding).
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            quant_config: Quantization config forwarded to the sub-layers.
            bias: Whether the QKV projection carries a bias.
            bias_o_proj: Whether the output projection carries a bias.
            cache_config: Cache config forwarded to the attention layer.
            prefix: Dotted module prefix; the layer index is extracted from it.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; any other value selects ``Attention``.

        Raises:
            AssertionError: If the head counts are incompatible with the
                tensor-parallel world size, or if the layer index falls
                outside ``config.layer_types``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Optional position-dependent query scaling, enabled only when the
        # config provides a `llama_4_scaling` mapping.
        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Check the lower bound as well: a negative index would silently
            # wrap around and pick a layer type from the end of the list.
            assert 0 <= effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return ``1 + beta * log(1 + floor(positions / original_max_pos))``.

        The result gets a trailing singleton dimension so it can be
        multiplied with the query tensor.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            positions: Positions passed to the rotary embedding (and to the
                Llama-4 scaling when enabled).
            hidden_states: Input to the QKV projection; the last dimension
                is presumably ``hidden_size`` -- TODO confirm with caller.

        Returns:
            The output of the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's query/key/value parts.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back so the multiplication does not change q's dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotary embedding is used unless the quantization config
        is named ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
268
269

class LlamaDecoderLayer(nn.Module):
    """One Llama layer: RMSNorm, self-attention, RMSNorm, MLP."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the attention, MLP and the two layer norms.

        Args:
            vllm_config: Source of the cache config, the quantization config
                and (when ``config`` is not given) the HF model config.
            prefix: Dotted module prefix for the sub-modules.
            config: Optional model config overriding the one in
                ``vllm_config``.
        """
        super().__init__()

        if not config:
            config = vllm_config.model_config.hf_config
        quant_config = self.get_quant_config(vllm_config)
        cache_config = vllm_config.cache_config

        self.hidden_size = config.hidden_size

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # The output projection keeps the value computed above, even when
        # `qkv_bias` overrides the QKV bias below.
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        attention_bias = getattr(config, "qkv_bias", attention_bias)

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (
            AttentionType.DECODER
            if getattr(config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        total_heads = config.num_attention_heads
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=total_heads,
            num_kv_heads=getattr(config, "num_key_value_heads", total_heads),
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` becomes
        the residual; otherwise the residual is passed to the layer norm
        together with ``hidden_states``.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
354

355
356
357
358
359
360
361
362
363
364
365
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, decoder layers and final norm.

    Modules that do not belong to the current pipeline-parallel rank are
    replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, the layer stack and the final norm.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also forwarded to every layer.
            prefix: Dotted module prefix for the layers.
            layer_type: Class instantiated for each decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding lives on the first rank, and also on the last rank
        # when the word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices (as enumerated in `forward`) whose inputs are captured as
        # auxiliary hidden states; empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the layer stack.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every layer.
            intermediate_tensors: Output of the previous rank; required on
                every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` on non-last ranks. On the last rank the
            normalized hidden states, or a ``(hidden_states,
            aux_hidden_states)`` pair when auxiliary states were captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                if residual is None:
                    # Before the first layer on the first rank there is no
                    # residual yet, so the hidden states are the full state.
                    # Clone so that later in-place updates (if any) cannot
                    # alter the captured tensor.
                    aux_hidden_states.append(hidden_states.clone())
                else:
                    aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` and ``gate_up_proj`` parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A non-scalar scale tensor is reduced to its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked shard was loaded for this tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
513

Woosuk Kwon's avatar
Woosuk Kwon committed
514

515
516
517
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama backbone plus LM head and logits processor."""

    # Fused parameter name -> names of the separate checkpoint shards.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also forwarded to the backbone.
            prefix: Dotted module prefix for the sub-modules.
            layer_type: Decoder layer class forwarded to the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the backbone layers whose inputs are captured as aux states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Create the backbone as a ``LlamaModel``."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, remapping Mistral-format names first.

        ``lm_head.`` tensors are skipped when the word embeddings are tied.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format tensor name (and layout) to HF format.

        ``wq``/``wk`` weights, and their non-scalar ``qscale_weight``
        tensors, are permuted; every tensor name is rewritten through
        ``mistral_mapping``.

        Args:
            name: Dotted checkpoint tensor name.
            loaded_weight: The checkpoint tensor.

        Returns:
            The remapped name and the (possibly permuted) tensor.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # `head_dim` is optional on the config (see LlamaAttention);
            # fall back to the value derived from the hidden size.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = self.config.hidden_size // self.config.num_attention_heads
            attn_in = head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        # Walk each name component together with its successor (None for the
        # last one) so two-component keys such as
        # "kv_fake_quantizer.qscale_act" are matched before single ones.
        for item, next_item in zip(modules, [*modules[1:], None]):
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                # Skip when the target text is already present in the name.
                name = name.replace(item, mapping[item])

        return name, loaded_weight