# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

34
from vllm.attention.layer import Attention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.activation import SiluAndMul
39
40
41
from vllm.model_executor.layers.attention.encoder_only_attention import (
    EncoderOnlyAttention,
)
42
from vllm.model_executor.layers.layernorm import RMSNorm
43
44
45
46
47
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
48
from vllm.model_executor.layers.logits_processor import LogitsProcessor
49
from vllm.model_executor.layers.quantization import QuantizationConfig
50
from vllm.model_executor.layers.rotary_embedding import get_rope
51
from vllm.model_executor.layers.vocab_parallel_embedding import (
52
53
54
    ParallelLMHead,
    VocabParallelEmbedding,
)
55
from vllm.model_executor.model_loader.weight_utils import (
56
57
58
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
59
from vllm.sequence import IntermediateTensors
60
from vllm.v1.attention.backend import AttentionType
Woosuk Kwon's avatar
Woosuk Kwon committed
61

62
63
64
65
66
67
68
from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import (
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
69
70
71
72
73
74
75
76
77
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
78

Woosuk Kwon's avatar
Woosuk Kwon committed
79
80
81
82

class LlamaMLP(nn.Module):
    """Feed-forward sub-layer of a LLaMA decoder layer.

    The input goes through ``gate_up_proj`` (a ``MergedColumnParallelLinear``
    with two outputs of ``intermediate_size`` each), then ``SiluAndMul``,
    then ``down_proj`` (a ``RowParallelLinear`` back to ``hidden_size``).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        """Build the two projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged outputs of
                ``gate_up_proj`` and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            bias: Forwarded to both linear layers.
            prefix: Module-name prefix; ``.gate_up_proj`` / ``.down_proj``
                are appended for the sub-layers.
            reduce_results: Forwarded to ``down_proj`` only.
            disable_tp: Forwarded to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Keyword arguments that both projections receive unchanged.
        common_kwargs = {
            "bias": bias,
            "quant_config": quant_config,
            "disable_tp": disable_tp,
        }
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **common_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **common_kwargs,
        )

        activation_supported = hidden_act == "silu"
        if not activation_supported:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run ``gate_up_proj`` -> ``act_fn`` -> ``down_proj`` on ``x``.

        Both linear layers return a pair; only the first element is used.
        """
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention sub-layer of a LLaMA decoder layer.

    Holds a fused QKV projection, the rotary embedding, the attention
    operator (``Attention`` or ``EncoderOnlyAttention`` depending on
    ``attn_type``) and the output projection.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Create the projections, rotary embedding and attention operator.

        Args:
            config: HF config; read for ``head_dim``, ``layer_types``,
                ``target_layer_count``, ``sliding_window``, ``model_type``
                and ``rope_parameters`` (several via ``getattr``).
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            max_position_embeddings: Forwarded to ``get_rope``.
            quant_config: Forwarded to the linear and attention layers.
            bias: Bias flag for ``qkv_proj``.
            bias_o_proj: Bias flag for ``o_proj``.
            cache_config: Forwarded to the attention operator.
            prefix: Module-name prefix; the layer index is extracted from it.
            attn_type: ``AttentionType`` value; ``ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``, anything else ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # An explicit (truthy) config.head_dim wins over the derived value.
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = head_dim or self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Check both bounds: a negative index would not raise but would
            # silently wrap around and pick another layer's attention type.
            assert 0 <= effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out.

        Args:
            positions: Passed to the rotary embedding together with Q and K.
            hidden_states: Input to ``qkv_proj``.

        Returns:
            The first element returned by ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [Q | K | V] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotary is used unless the quantization config is named
        ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
251
252

class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: norm -> self-attention -> norm -> MLP.

    The residual stream is carried as a separate tensor and threaded through
    both ``RMSNorm`` calls rather than being added inline here.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
        attn_layer_type: type[nn.Module] = LlamaAttention,
    ) -> None:
        """Build the attention, MLP and the two layer norms.

        Args:
            vllm_config: Source of the cache config, the quantization config
                (through ``get_quant_config``) and, when ``config`` is not
                given, the HF model config.
            prefix: Module-name prefix; ``.self_attn`` / ``.mlp`` are appended.
            config: Optional HF config overriding the one in ``vllm_config``.
            attn_layer_type: Class instantiated as ``self.self_attn``.
        """
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        bias_o_proj = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # support internlm/internlm3-8b with qkv_bias: when present it
        # controls the QKV bias only; the output projection keeps the value
        # computed above.
        attention_bias = config.qkv_bias if hasattr(config, "qkv_bias") else bias_o_proj

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (
            AttentionType.DECODER
            if getattr(config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        self.self_attn = attn_layer_type(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention sub-layer.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Residual tensor from the previous layer, or ``None``
                when there is none yet; in that case ``hidden_states`` itself
                becomes the residual.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)
        return mlp_out, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Return the quantization config for this layer.

        Subclasses may override this hook; the default is
        ``vllm_config.quant_config``.
        """
        return vllm_config.quant_config

Woosuk Kwon's avatar
Woosuk Kwon committed
338

339
340
341
342
343
344
345
346
347
348
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


349
350
351
352
353
@support_torch_compile(
    # TODO[#32068]: Investigate recompilation
    # mark_unbacked_dims={"input_ids": 0},
    shape_invariants=llama_model_invariants
)
class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, decoder layers and final norm.

    Under pipeline parallelism, parts that do not live on the current rank
    are replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Source of the HF model config and quantization config.
            prefix: Module-name prefix; ``.layers`` is appended for the
                decoder stack.
            layer_type: Class instantiated for every decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size

        # The embedding is needed on the first rank, and also on the last
        # rank when the word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices of layers whose input is captured in forward(); empty by
        # default.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is ``None``.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: Required on non-first ranks; must provide
                ``"hidden_states"`` and ``"residual"``.
            inputs_embeds: Optional embeddings used instead of ``input_ids``
                on the first rank.

        Returns:
            On non-last ranks, ``IntermediateTensors`` with the hidden states
            and residual. On the last rank, the normalized hidden states, or
            a ``(hidden_states, aux_hidden_states)`` tuple when any auxiliary
            hidden state was captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        # NOTE: idx counts from 0 at this rank's first layer (start_layer),
        # so it is relative to this rank's slice of the layer stack.
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer on the first rank there is no
                # residual yet; adding None would raise TypeError, so the
                # hidden states alone are captured.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint entries are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors contribute their first element only.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                # Remapping the name of FP8 kv-scale or zero point.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # (`continue` here resumes the mapping loop; the renamed
                # entry then reaches the `else` branch, which skips it.)
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
501

Woosuk Kwon's avatar
Woosuk Kwon committed
502

503
504
505
class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """LLaMA backbone plus LM head and logits processor.

    The LM head and logits processor are only created on the last pipeline
    rank; other ranks hold a ``PPMissingLayer`` in place of the head.
    """

    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the backbone and, on the last rank, the output head.

        Args:
            vllm_config: Source of the HF model config and quantization config.
            prefix: Module-name prefix for ``model`` and ``lm_head``.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            # Optional logit scaling read from the config; defaults to 1.0.
            logit_scale = getattr(hf_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                hf_config.vocab_size, scale=logit_scale
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` on the backbone as ``aux_hidden_state_layers``."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        layer_count = len(self.model.layers)
        early, middle, late = 2, layer_count // 2, layer_count - 3
        return (early, middle, late)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Construct the ``LlamaModel`` backbone from the given arguments."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone's ``embed_input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Return ``self.logits_processor(self.lm_head, hidden_states)``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``.

        When the word embeddings are tied, checkpoint entries whose names
        start with ``lm_head.`` are skipped.
        """
        skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skip_prefixes).load_weights(
            weights
        )
604
605
606
607
608
609
610
611
612
613
614
615


class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
    """``LlamaForCausalLM`` wrapped by ``as_seq_cls_model``.

    This class sets the correct attention type and pooling type
    through LlamaBidirectionalConfig, so its body adds nothing.
    """


class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
    """``LlamaForCausalLM`` wrapped by ``as_embedding_model``.

    This class sets the correct attention type and pooling type
    through LlamaBidirectionalConfig, so its body adds nothing.
    """