"vllm/vscode:/vscode.git/clone" did not exist on "8510c10c99e68ba142e4331d9a9a8777921ad910"
llama.py 25.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
35
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
36
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.activation import SiluAndMul
41
from vllm.model_executor.layers.layernorm import RMSNorm
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
52
53
    ParallelLMHead,
    VocabParallelEmbedding,
)
54
from vllm.model_executor.model_loader.weight_utils import (
55
56
57
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
58
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
59

60
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
61
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
70

Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74

class LlamaMLP(nn.Module):
    """Feed-forward block of a Llama layer.

    The input goes through a merged gate/up column-parallel projection,
    then ``SiluAndMul``, then a row-parallel down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input size of the gate/up projection and output
                size of the down projection.
            intermediate_size: Size of each of the two merged gate/up
                outputs and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            bias: Forwarded to both linear layers.
            prefix: Parameter-name prefix; ``.gate_up_proj`` and
                ``.down_proj`` are appended for the sub-layers.
            reduce_results: Forwarded to the down projection.
            disable_tp: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations up front so a bad config fails
        # before any projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        # Both linear layers return a 2-tuple; only the first item is used.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block of a Llama layer.

    Runs a fused QKV projection, applies rotary embeddings to the query
    and key, optionally scales the query (``llama_4_scaling``), calls the
    attention layer and finishes with the output projection.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build projections, rotary embedding and the attention layer.

        Args:
            config: HF config; read for ``head_dim``, ``llama_4_scaling``,
                ``layer_types``, ``target_layer_count``, ``sliding_window``
                and rope settings (all optional except where used).
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            max_position_embeddings: Passed to the rotary embedding.
            quant_config: Forwarded to the linear and attention layers.
            bias: Bias flag for the QKV projection.
            bias_o_proj: Bias flag for the output projection.
            cache_config: Forwarded to the attention layer.
            prefix: Parameter-name prefix; also parsed for the layer index.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; anything else uses ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # The lower bound matters: a negative index would otherwise
            # wrap around and silently pick a layer type from the end.
            assert 0 <= effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-position query scale used by Llama4 scaling.

        Computes ``1 + beta * log(1 + floor(positions / original_max))``
        and appends a trailing dimension so it broadcasts over head_dim.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back: the multiplication may have changed the dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` from the config's rope parameters.

        Neox-style rotation is used unless the quantization config reports
        the name ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
267
268

class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: attention and MLP, each preceded by RMSNorm.

    The layer carries the residual stream separately from the hidden
    states; both are returned from ``forward``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the attention, MLP and norm sub-modules.

        Args:
            vllm_config: Source of the cache config, the quantization
                config, and (when ``config`` is not given) the HF config.
            prefix: Parameter-name prefix for this layer.
            config: Optional HF config overriding
                ``vllm_config.model_config.hf_config``.
        """
        super().__init__()

        # Explicit None check: only a missing override should fall back,
        # independent of how the config object evaluates as a boolean.
        if config is None:
            config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # The output projection keeps this value even if qkv_bias below
        # overrides the QKV bias.
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become
        the residual and are normalized on their own; otherwise the norm
        is called with both tensors and returns the updated pair.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
353

354
355
356
357
358
359
360
361
362
363
364
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, decoder layers and final norm.

    Under pipeline parallelism each rank builds only its own pieces;
    the rest are ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, the layer stack and the final norm.

        Args:
            vllm_config: Source of the HF config and quantization config.
            prefix: Parameter-name prefix for this module.
            layer_type: Class instantiated for each decoder layer; called
                with ``vllm_config`` and ``prefix`` keyword arguments.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is built on the first rank, and also on the last
        # rank when word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose inputs are captured as auxiliary hidden
        # states in forward(); empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the model.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` on every rank but the last. On the last rank,
            the normalized hidden states, or a
            ``(hidden_states, aux_hidden_states)`` tuple when any
            auxiliary hidden states were captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # NOTE(review): idx counts from 0 within this rank's slice of
        # layers, not from the global layer index - confirm this is what
        # callers of aux_hidden_state_layers expect under PP.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                if residual is None:
                    # Before the first layer on the first rank there is
                    # no residual yet. Capture a copy so that later use
                    # of this tensor as the residual cannot alter the
                    # captured state.
                    aux_hidden_states.append(hidden_states.clone())
                else:
                    aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` checkpoint tensors are routed into the
        fused ``qkv_proj`` and ``gate_up_proj`` parameters.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors are reduced to their first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                # `break` marks the weight as consumed by a fused
                # parameter and skips the `else` branch below.
                break
            else:
                # Runs only when no stacked mapping consumed the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
512

Woosuk Kwon's avatar
Woosuk Kwon committed
513

514
515
516
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama backbone plus LM head and logits processor.

    The LM head and logits processor are built only on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer``.
    """

    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Source of the HF config and quantization config.
            prefix: Parameter-name prefix for this module.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set which layer indices the backbone captures as aux hidden states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Construct the backbone; a hook subclasses can override."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids`` via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping Mistral-format names first.

        ``lm_head.`` weights are skipped when word embeddings are tied.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Remap a Mistral-format weight to this model's naming and layout.

        This function is used to remap the mistral format as used by
        Mistral and Llama <=2. ``wq``/``wk`` weights (and their
        non-scalar ``qscale_weight`` tensors) are permuted; names are
        rewritten using ``mistral_mapping``.

        Returns:
            The (possibly renamed) name and the (possibly permuted) tensor.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Fall back to hidden_size // num_attention_heads when the
            # config has no head_dim, matching LlamaAttention.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = self.config.hidden_size // self.config.num_attention_heads
            attn_in = head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        # Pair each path component with its successor (None for the last)
        # so two-component keys such as "kv_fake_quantizer.qscale_act"
        # are tried before single-component keys.
        for item, next_item in zip(modules, [*modules[1:], None]):
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                # Skip when the target is already present in the name, so
                # an already-mapped component is not rewritten again.
                name = name.replace(item, mapping[item])

        return name, loaded_weight