# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
from vllm.attention import Attention, AttentionType
35
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
from vllm.model_executor.layers.rotary_embedding import get_rope
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
51
52
    ParallelLMHead,
    VocabParallelEmbedding,
)
53
from vllm.model_executor.model_loader.weight_utils import (
54
55
56
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
57
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
58

59
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
60
61
62
63
64
65
66
67
68
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
69

Woosuk Kwon's avatar
Woosuk Kwon committed
70
71
72
73

class LlamaMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul
    activation, then the down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Create the two linear layers and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two fused outputs of
                ``gate_up_proj`` and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            bias: Forwarded to both linear layers.
            prefix: Name prefix; ``.gate_up_proj`` / ``.down_proj`` are
                appended for the respective layers.
            reduce_results: Forwarded to ``down_proj`` only.
            disable_tp: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Keyword arguments that both projections receive unchanged.
        shared_kwargs = dict(
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
        )

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_kwargs,
        )

        # The activation is validated only after both projections exist,
        # matching the original construction order.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply ``gate_up_proj`` -> ``act_fn`` -> ``down_proj`` to ``x``.

        Both linear layers return a 2-tuple; only the first element is used.
        """
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Attention sub-layer used by the LLaMA decoder layer.

    Owns a fused QKV projection (``qkv_proj``), a rotary embedding
    (``rotary_emb``), the attention layer (``attn``) and the output
    projection (``o_proj``), and chains them together in ``forward``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: Model config. Optional fields read here: ``head_dim``,
                ``partial_rotary_factor``, ``llama_4_scaling``,
                ``layer_types``, ``target_layer_count``; ``sliding_window``
                is read when the layer is a sliding-attention layer.
            hidden_size: Input size of ``qkv_proj`` / output size of
                ``o_proj``.
            num_heads: Total number of query heads (before tensor-parallel
                partitioning).
            num_kv_heads: Total number of key/value heads (before
                tensor-parallel partitioning).
            max_position_embeddings: Passed to the rotary embedding as
                ``max_position``.
            quant_config: Forwarded to the linear layers and ``attn``.
            bias: Bias flag for ``qkv_proj``.
            bias_o_proj: Bias flag for ``o_proj``.
            cache_config: Forwarded to the attention layer.
            prefix: Name prefix of this module; also parsed with
                ``extract_layer_index`` to obtain the layer index.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; anything else selects
                ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Optional position-dependent query scaling; enabled only when the
        # config carries a ``llama_4_scaling`` entry.
        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Adjacent f-strings are used instead of a backslash
            # continuation so the message does not embed the source
            # indentation.
            assert effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-position query scale for Llama4 scaling.

        Computes ``1 + beta * log(1 + floor(positions / original_max))``
        and appends a trailing singleton dimension to the result.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and ``o_proj``.

        Args:
            positions: Positions passed to the rotary embedding (and to the
                Llama4 scaling when enabled).
            hidden_states: Input to ``qkv_proj``.

        Returns:
            The first element of the ``o_proj`` output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back so the scaled query keeps q's original dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX style is used unless the quantization config is named
        ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
270
271

class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP,
    with the residual stream threaded through the norm calls."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the attention, MLP and the two layer norms.

        Args:
            vllm_config: Supplies the cache config, the quantization config
                (through ``get_quant_config``) and, when ``config`` is
                falsy, the HF model config.
            prefix: Name prefix; ``.self_attn`` and ``.mlp`` are appended
                for the sub-modules.
            config: Optional model config overriding the one in
                ``vllm_config``.
        """
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_positions = getattr(config, "max_position_embeddings", 8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        o_proj_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # support internlm/internlm3-8b with qkv_bias: when present it
        # replaces the QKV bias flag only; o_proj keeps the value above.
        qkv_bias = getattr(config, "qkv_bias", o_proj_bias)

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (
            AttentionType.DECODER
            if getattr(config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        # Falls back to the number of attention heads when the config has
        # no ``num_key_value_heads``.
        total_kv_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=total_kv_heads,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` becomes
        the residual and the input norm is called with one argument;
        otherwise the norm is called with both and returns the updated pair.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        else:
            normed, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)
        return mlp_out, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Return the quantization config for this layer.

        The base implementation returns ``vllm_config.quant_config``;
        subclasses may override it.
        """
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
356

357
@support_torch_compile
class LlamaModel(nn.Module):
    """Token embedding, stack of decoder layers and final norm.

    Under pipeline parallelism, components that do not belong to the
    current rank are replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, decoder layers and final norm.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also passed to every decoder layer.
            prefix: Name prefix; layers are created under
                ``{prefix}.layers``.
            layer_type: Class instantiated for each decoder layer with
                ``vllm_config`` and ``prefix`` keyword arguments.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is built on the first rank, and also on the last
        # rank when word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices (as enumerated in ``forward``) whose input hidden state
        # is additionally returned; empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is ``None``.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: ``hidden_states`` / ``residual`` handed
                over from the previous rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings used instead of
                embedding ``input_ids`` on the first rank.

        Returns:
            On a non-last rank, an ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual``. On the last rank, the normed
            hidden states, or ``(hidden_states, aux_hidden_states)`` when
            any auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # ``idx`` counts from 0 over the layers iterated here
        # (start_layer .. end_layer).
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                if residual is None:
                    # No residual exists yet (first layer on the first
                    # rank), so ``hidden_states + residual`` would raise
                    # TypeError. Capture a copy so the stored tensor is
                    # independent of the one passed on to the layer.
                    aux_hidden_states.append(hidden_states.clone())
                else:
                    aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors contribute only their first entry.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            # NOTE: inside this inner loop ``continue`` moves on to the next
            # mapping entry, not the next weight. If no entry reaches
            # ``break``, the ``else`` branch runs with ``name`` as it stands
            # at that point (possibly already rewritten).
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
505

Woosuk Kwon's avatar
Woosuk Kwon committed
506

507
508
509
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """LLaMA backbone (``self.model``) plus LM head and logits processor."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last pipeline rank, the LM head
        and logits processor.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also forwarded to ``_init_model``.
            prefix: Name prefix for the sub-modules.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            # ``logit_scale`` is optional in the config; 1.0 when absent.
            self.logits_processor = LogitsProcessor(
                hf_config.vocab_size,
                scale=getattr(hf_config, "logit_scale", 1.0),
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` on the backbone as ``aux_hidden_state_layers``."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        layer_count = len(self.model.layers)
        middle = layer_count // 2
        return (2, middle, layer_count - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Create the ``LlamaModel`` backbone from the given arguments."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone's ``embed_input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Call the backbone with the given arguments and return its output
        unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Return ``self.logits_processor(self.lm_head, hidden_states)``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``.

        Each ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``. Names starting with ``lm_head.`` are
        skipped when word embeddings are tied.
        """
        skip = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skip)
        remapped = (
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )
        return loader.load_weights(remapped)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Rename a mistral-format weight and permute q/k tensors.

        Args:
            name: Dotted weight name.
            loaded_weight: Tensor stored under ``name``.

        Returns:
            ``(name, loaded_weight)`` after applying ``mistral_mapping`` to
            the name and, for ``wq`` / ``wk`` weights or non-scalar weight
            scales, the permutation below.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            attn_in = self.config.head_dim * n_heads
            split = w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
            return split.transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")
        leaf = modules[-1]

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        # "wk" takes precedence over "wq"; the config attribute is looked
        # up only when a permutation actually happens.
        if "wk" in modules:
            heads_attr = "num_key_value_heads"
        elif "wq" in modules:
            heads_attr = "num_attention_heads"
        else:
            heads_attr = None

        if heads_attr is not None:
            if leaf == "weight":
                loaded_weight = permute(
                    loaded_weight,
                    getattr(self.config, heads_attr),
                    self.config.hidden_size,
                )
            elif leaf == "qscale_weight" and loaded_weight.numel() > 1:
                loaded_weight = permute(
                    loaded_weight, getattr(self.config, heads_attr), 1
                )

        # Walk the original components pairwise; a two-component key in the
        # mapping wins over a single-component one. ``name`` is rewritten
        # in place as matches are found.
        for item, next_item in zip(modules, modules[1:] + [None]):
            pair_key = None if next_item is None else f"{item}.{next_item}"
            if pair_key is not None and pair_key in mapping:
                name = name.replace(pair_key, mapping[pair_key])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight