# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""

from collections.abc import Iterable
from itertools import islice

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors

from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import (
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
from .interfaces_base import attn_type
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a Llama layer.

    ``gate_proj`` and ``up_proj`` are fused into one
    ``MergedColumnParallelLinear`` (``gate_up_proj``) with two output shards of
    ``intermediate_size`` each; ``SiluAndMul`` is applied to that fused output
    and ``down_proj`` maps the result back to ``hidden_size``.

    Args:
        hidden_size: Input and output width of the block.
        intermediate_size: Width of each of the gate / up shards.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config passed to both linears.
        bias: Whether the projections have a bias term.
        prefix: Parameter-name prefix for this module.
        reduce_results: Forwarded to ``down_proj``.
        disable_tp: Forwarded to both projections.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Gate and up projections share one matmul: two shards of
        # `intermediate_size` concatenated along the output dimension.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, ``SiluAndMul``, then ``down_proj``."""
        # The linear layers return (output, bias); only the output is used.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention used by each Llama decoder layer.

    Q/K/V come from a single fused ``QKVParallelLinear``; rotary embeddings
    are applied to Q and K; the attention output is projected back to
    ``hidden_size`` by ``o_proj``. Query heads are split evenly across the
    tensor-parallel group, and KV heads are either split (when there are at
    least ``tp_size`` of them) or replicated (when there are fewer).
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Per-rank widths of the Q and K/V slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Optional position-dependent query scaling (see
        # _get_llama_4_attn_scale); enabled only when the config carries a
        # `llama_4_scaling` mapping.
        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Implicit concatenation keeps the source indentation out of the
            # message (a backslash-newline inside the literal would embed it).
            assert effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the Llama4 query scale for each position.

        scale = 1 + beta * log(1 + floor(pos / original_max_position_embeddings))

        A trailing unit dimension is added so the result broadcasts against
        the last dimension of ``q``. NOTE(review): assumes ``positions`` has
        one entry per token row of ``q`` — confirm for multi-dim positions.
        """
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            # Cast back so the scale (computed in floating point from
            # positions) does not change q's dtype.
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb``; subclasses may override."""
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        # GGUF "llama" checkpoints use the non-neox rotary layout
        # (presumably because their Q/K rows are not permuted to the HF
        # layout — confirm against the GGUF loader).
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
        )

class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block: pre-norm attention, then pre-norm MLP.

    The residual stream is passed alongside ``hidden_states``. Each
    ``RMSNorm`` is called with ``(hidden_states, residual)`` and returns the
    new pair (presumably a fused residual-add + normalize — see ``RMSNorm``).

    Args:
        vllm_config: Engine config; supplies the cache and quantization
            configs and, unless ``config`` is given, the HF model config.
        prefix: Parameter-name prefix for this layer.
        config: Optional HF config overriding
            ``vllm_config.model_config.hf_config``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP sub-blocks.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            stream to feed into the next layer's input norm.
        """
        # Self Attention
        if residual is None:
            # First layer: the incoming hidden states start the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
360

361
362
363
364
365
366
367
368
369
370
371
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
    """Llama decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware: ``embed_tokens`` and ``norm`` are
    ``PPMissingLayer`` on ranks that do not own them, and only layers in
    ``[start_layer, end_layer)`` (as returned by ``make_layers``) are run on
    this rank.

    Args:
        vllm_config: Engine config; the HF config and quantization config are
            read from it.
        prefix: Parameter-name prefix for this module.
        layer_type: Decoder-layer class to instantiate for every layer.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is needed on the first PP rank, and also on the last
        # rank when word embeddings are tied to the LM head.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices of layers whose input residual stream is also returned from
        # forward() (used for Eagle3); empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks; otherwise the
            normalized hidden states, paired with the list of auxiliary
            hidden states when any were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # NOTE(review): `idx` counts from 0 on every PP rank (it is relative
        # to start_layer), while aux_hidden_state_layers look like absolute
        # layer indices — confirm intended behaviour under PP.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                # The aux state is the residual stream entering this layer.
                # Before the first layer on the first rank `residual` is still
                # None, so the stream is just `hidden_states`; adding None
                # would raise a TypeError.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are loaded into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters shard by shard.
        Rotary caches are skipped, KV-cache scales are routed via the
        quantization config, and parameters not owned by this PP rank are
        skipped.

        Returns:
            Names (in this module's namespace) of the parameters loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                if is_pp_missing_parameter(scale_name, self):
                    # Same PP guard as the other paths below: the scale's
                    # layer lives on another rank, so params_dict lacks it.
                    continue
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors: only the first element is used.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # (This `continue` resumes the mapping loop; no other entry
                # matches the rewritten name, so control reaches the `else`
                # branch, which skips it again.)
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked (fused) parameter: load it directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama causal-LM: ``LlamaModel`` backbone plus LM head and logits.

    Args:
        vllm_config: Engine config.
        prefix: Parameter-name prefix; the backbone lives under ``model`` and
            the head under ``lm_head``.
        layer_type: Decoder-layer class forwarded to the backbone.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        # The LM head and logits processor exist only on the last PP rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which layers' input residual streams forward() returns."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone; subclasses may override."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, remapping mistral-format names on the fly.

        ``lm_head.*`` tensors are skipped when word embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate one mistral-format tensor to this model's naming/layout.

        ``wq`` / ``wk`` weights (and non-scalar ``qscale_weight`` scales) have
        their rows permuted; the name is rewritten through
        ``mistral_mapping``. Names without a mapping entry pass through.

        Returns:
            ``(new_name, possibly_permuted_weight)``.
        """

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Within each head, reorder rows from interleaved pairs
            # (0,1),(2,3),... to [even rows..., odd rows...].
            # Same head_dim fallback as LlamaAttention, for configs that do
            # not define head_dim.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = self.config.hidden_size // self.config.num_attention_heads
            attn_in = head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # Q/K rows must be permuted (not merely renamed) for rotary
        # embeddings to line up.
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be permuted
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            # Two-component keys (e.g. "k_fake_quantizer.qscale_act") take
            # precedence over single-component ones.
            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight


@attn_type("encoder_only")
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
    """Bidirectional Llama wrapped for sequence classification.

    No extra code is needed: the attention type and pooling type are set
    through LlamaBidirectionalConfig, and the class is tagged with
    ``attn_type("encoder_only")``.
    """


@attn_type("encoder_only")
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
    """Bidirectional Llama wrapped as an embedding model.

    No extra code is needed: the attention type and pooling type are set
    through LlamaBidirectionalConfig, and the class is tagged with
    ``attn_type("encoder_only")``.
    """