llama.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
from vllm.attention.layer import Attention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.activation import SiluAndMul
39
40
41
from vllm.model_executor.layers.attention.encoder_only_attention import (
    EncoderOnlyAttention,
)
42
from vllm.model_executor.layers.layernorm import RMSNorm
43
44
45
46
47
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
48
from vllm.model_executor.layers.logits_processor import LogitsProcessor
49
from vllm.model_executor.layers.quantization import QuantizationConfig
50
from vllm.model_executor.layers.rotary_embedding import get_rope
51
from vllm.model_executor.layers.vocab_parallel_embedding import (
52
53
54
    ParallelLMHead,
    VocabParallelEmbedding,
)
55
from vllm.model_executor.model_loader.weight_utils import (
56
57
58
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
59
from vllm.sequence import IntermediateTensors
60
from vllm.v1.attention.backend import AttentionType
Woosuk Kwon's avatar
Woosuk Kwon committed
61

62
63
64
65
66
67
68
from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import (
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
69
70
71
72
73
74
75
76
77
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
78

Woosuk Kwon's avatar
Woosuk Kwon committed
79
80
81
82

class LlamaMLP(nn.Module):
    """Gated feed-forward block used inside each Llama decoder layer.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear`` producing two ``intermediate_size`` outputs.
    That fused output is fed to ``SiluAndMul``. The result is projected back
    to ``hidden_size`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Create the fused gate/up projection, activation and down projection.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config for both linear layers.
            bias: Whether the linear layers carry a bias term.
            prefix: Module name prefix used to name the sub-layers.
            reduce_results: Forwarded to the down projection.
            disable_tp: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, then down projection."""
        # The parallel linear layers return (output, bias); only the output
        # is used here.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention sub-layer of a Llama decoder layer.

    Q, K and V come from a fused ``QKVParallelLinear``. Rotary embeddings are
    applied to Q and K, attention is computed by ``Attention`` (or
    ``EncoderOnlyAttention``), and the result is projected back to
    ``hidden_size`` by a ``RowParallelLinear``.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many as TP ranks, and replicated
    otherwise.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF model config. ``head_dim``, ``layer_types``,
                ``target_layer_count`` and ``sliding_window`` are read from it
                when present. It is also passed to rotary-embedding setup.
            hidden_size: Input/output width of the sub-layer.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
            quant_config: Optional quantization config for the linear layers
                and the attention layer.
            bias: Whether the fused QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: Forwarded to the attention layer.
            prefix: Module prefix; the layer index is extracted from it.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; any other value uses ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by "
            f"tensor parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_kv_heads ({self.total_num_kv_heads}) must be divisible "
                f"by tensor parallel size ({tp_size})"
            )
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor parallel size ({tp_size}) must be divisible by "
                f"num_kv_heads ({self.total_num_kv_heads})"
            )
        # With replication the integer division can reach 0; every rank still
        # holds at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Prefer an explicit head_dim from the config; otherwise derive it.
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = head_dim or self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Must run after head_dim / max_position_embeddings are set; it
        # reads both.
        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # NOTE(review): only the upper bound is checked; a negative
            # effective_layer_idx would index layer_types from the end.
            assert effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into this rank's
        # Q, K and V slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotation is used unless the model is a GGUF-quantized
        ``llama`` model, in which case ``is_neox_style=False`` is passed.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
251
252

class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block: pre-norm self-attention followed by a
    pre-norm gated MLP, with the residual stream threaded between calls."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
        attn_layer_type: type[nn.Module] = LlamaAttention,
    ) -> None:
        """Build attention, MLP and the two RMSNorm layers.

        Args:
            vllm_config: Global vLLM config. Supplies the HF config (unless
                ``config`` is given), the cache config and, through
                ``get_quant_config``, the quantization config.
            prefix: Module prefix for this layer.
            config: Optional HF config overriding
                ``vllm_config.model_config.hf_config``.
            attn_layer_type: Attention class to instantiate; it must accept the
                same keyword arguments as ``LlamaAttention``.
        """
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # Captured before the qkv_bias override below, so the output
        # projection bias follows attention_bias / bias only.
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = attn_layer_type(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer on a rank that starts
        from embeddings. In that case the incoming hidden states become the
        residual. Otherwise the two-argument ``RMSNorm`` call is used, which
        returns both the normalized states and an updated residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
338

339
340
341
342
343
344
345
346
347
348
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


349
350
351
352
353
@support_torch_compile(
    # TODO[#32068]: Investigate recompilation
    # mark_unbacked_dims={"input_ids": 0},
    shape_invariants=llama_model_invariants,
)
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, a stack of decoder layers, final norm.

    Under pipeline parallelism, each rank materializes only its own slice of
    layers (``start_layer``..``end_layer``). The embedding exists on the first
    rank (and on the last rank when word embeddings are tied). The final norm
    exists only on the last rank; other ranks hold ``PPMissingLayer``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Global vLLM config (HF config and quant config).
            prefix: Module prefix; layers are named ``{prefix}.layers.N``.
            layer_type: Decoder layer class, constructed as
                ``layer_type(vllm_config=..., prefix=...)``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The last rank also needs the embedding when lm_head is tied to it.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose input residual stream is returned alongside the
        # final hidden states (set via set_aux_hidden_state_layers).
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of decoder layers.

        Returns:
            On a non-last PP rank, ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual``. On the last rank, the normed
            hidden states, or ``(hidden_states, aux_hidden_states)`` when any
            aux layer was hit.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # NOTE(review): idx counts from self.start_layer (local to this PP
        # rank); confirm that matches the indices passed to
        # set_aux_hidden_state_layers when pipeline parallelism is enabled.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer on the first rank no residual exists
                # yet; the stream is just the embeddings, and `+ None` would
                # raise a TypeError.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(
                positions, hidden_states, residual, **extra_layer_kwargs
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF-format checkpoint tensors into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id.
        Rotary caches stored in some checkpoints are skipped. KV-cache scales
        are handled through the quant config.

        Returns:
            The set of (possibly remapped) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                # Remapping the name of FP8 kv-scale or zero point.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            # NOTE: a `continue` inside this inner loop skips to the next
            # mapping entry, not the next weight; if no `break` happens the
            # `else:` branch below still runs on the (possibly renamed) name,
            # where the same skip checks apply again.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
504

Woosuk Kwon's avatar
Woosuk Kwon committed
505

506
507
508
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama causal language model: ``LlamaModel`` backbone plus LM head.

    The LM head and logits processor exist only on the last pipeline-parallel
    rank. When ``tie_word_embeddings`` is set, the LM head shares the
    backbone's token embedding.
    """

    # Fused module name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Global vLLM config (HF config and quant config).
            prefix: Module prefix; the backbone lives under ``{prefix}.model``.
            layer_type: Decoder layer class forwarded to the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.vocab_size, scale=logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which backbone layers' residual streams are also returned."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Construct the backbone; subclasses may override to swap it out."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the backbone and return its output unchanged.

        The result may be a ``(hidden_states, aux_hidden_states)`` tuple when
        aux hidden-state layers are configured.
        """
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits (last PP rank only)."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors; ``lm_head.*`` is skipped when tied."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
607
608
609
610
611
612
613
614
615
616
617
618


class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
    """Sequence-classification adapter over ``LlamaForCausalLM``.

    No behavior is added here. The attention type and pooling type are
    expected to be configured through ``LlamaBidirectionalConfig``.
    """


class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
    """Embedding-model adapter over ``LlamaForCausalLM``.

    No behavior is added here. The attention type and pooling type are
    expected to be configured through ``LlamaBidirectionalConfig``.
    """