# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
33
import os
gaoqiong's avatar
gaoqiong committed
34
import re
zhuwenwen's avatar
zhuwenwen committed
35

36
from vllm.attention.layer import Attention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.activation import SiluAndMul
41
42
43
from vllm.model_executor.layers.attention.encoder_only_attention import (
    EncoderOnlyAttention,
)
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
46
47
48
49
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
50
from vllm.model_executor.layers.logits_processor import LogitsProcessor
51
from vllm.model_executor.layers.quantization import QuantizationConfig
52
from vllm.model_executor.layers.rotary_embedding import get_rope
53
from vllm.model_executor.layers.vocab_parallel_embedding import (
54
55
56
    ParallelLMHead,
    VocabParallelEmbedding,
)
57
from vllm.model_executor.model_loader.weight_utils import (
58
59
60
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
61
from vllm.sequence import IntermediateTensors
62
from vllm.v1.attention.backend import AttentionType
Woosuk Kwon's avatar
Woosuk Kwon committed
63

64
65
66
67
68
69
70
from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import (
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
71
72
73
74
75
76
77
78
79
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
80

gaoqiong's avatar
gaoqiong committed
81
from vllm import _custom_ops as ops
82
83
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
84
85
86
87

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a Llama decoder layer.

    The gate and up projections live in one merged column-parallel linear
    (``gate_up_proj``) whose two equal-sized halves are combined by
    ``SiluAndMul``; a row-parallel ``down_proj`` maps the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Options shared verbatim by both projections.
        shared_opts = dict(
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_opts,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_opts,
        )
        # Only the fused SiLU-and-multiply activation is implemented here.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Project up (gate and up halves), gate with SiLU, project down."""
        gate_up, _ = self.gate_up_proj(x)
        gated = self.act_fn(gate_up)
        out, _ = self.down_proj(gated)
        return out


class LlamaAttention(nn.Module):
    """Tensor-parallel self-attention for Llama-style models.

    Q, K and V come from one fused ``QKVParallelLinear``; rotary position
    embeddings are applied to Q and K; the attention output is projected back
    to ``hidden_size`` by a row-parallel ``o_proj``. KV heads are partitioned
    across TP ranks when there are at least as many as ranks, and replicated
    otherwise.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build projections, rotary embedding and the attention backend.

        Args:
            config: HF config; ``head_dim``, ``layer_types``,
                ``target_layer_count`` and ``sliding_window`` are read from it
                when present.
            hidden_size: Model hidden size.
            num_heads: Total (all-rank) number of query heads.
            num_kv_heads: Total (all-rank) number of key/value heads.
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            quant_config: Optional quantization config for the projections
                and attention layer.
            bias: Whether ``qkv_proj`` has a bias.
            bias_o_proj: Whether ``o_proj`` has a bias.
            cache_config: KV-cache config forwarded to the attention layer.
            prefix: Module name prefix; the layer index is extracted from it.
            attn_type: ``AttentionType.ENCODER_ONLY`` selects
                ``EncoderOnlyAttention``; anything else uses ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # An explicit head_dim in the config takes precedence over
        # hidden_size // num_heads.
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = head_dim or self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config, quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, "target_layer_count"):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            # Reject negative indices too: they would silently wrap around to
            # the end of layer_types instead of failing.
            assert 0 <= effective_layer_idx < len(layer_types), (
                f"effective_layer_idx: {effective_layer_idx} "
                f"is out of bounds for layer_types: {layer_types}"
            )

            is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        attn_cls = (
            EncoderOnlyAttention
            if attn_type == AttentionType.ENCODER_ONLY
            else Attention
        )

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

        # Fork-specific bookkeeping. Both attributes are always defined so
        # that unquantized models do not raise AttributeError on access.
        self.quant_config = quant_config
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` with rotary embeddings at ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # NOTE: a disabled fork-specific FA_PAD path (FA_PAD=1, unquantized)
        # used to drop the last 32 columns of ``qkv`` here.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: LlamaConfig,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb``.

        GGUF-quantized checkpoints of ``model_type == "llama"`` use the
        non-NeoX rotary style; everything else uses NeoX style.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
        )


class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block: norm -> self-attention -> norm -> MLP.

    The residual stream is threaded through the two RMSNorm modules, which
    are called with ``(hidden_states, residual)`` and return an updated pair.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
        attn_layer_type: type[nn.Module] = LlamaAttention,
    ) -> None:
        super().__init__()

        hf_config = config or vllm_config.model_config.hf_config
        kv_cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = hf_config.hidden_size
        max_positions = getattr(hf_config, "max_position_embeddings", 8192)

        # `attention_bias` covers abacusai/Smaug-72B-v0.1 and `bias` covers
        # internlm/internlm-7b; this flag always drives the o_proj bias.
        o_proj_bias = getattr(hf_config, "attention_bias", False) or getattr(
            hf_config, "bias", False
        )
        # internlm/internlm3-8b carries a separate `qkv_bias` that, when
        # present, overrides the QKV projection bias only.
        qkv_bias = getattr(hf_config, "qkv_bias", o_proj_bias)

        # Causal (decoder) attention is the default; an HF config with
        # `is_causal=False` switches to bidirectional encoder-only attention,
        # as used by some embedding models (e.g. parasail-ai/GritLM-7B-vllm).
        attn_type = (
            AttentionType.DECODER
            if getattr(hf_config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        self.self_attn = attn_layer_type(
            config=hf_config,
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=getattr(
                hf_config, "num_key_value_heads", hf_config.num_attention_heads
            ),
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=kv_cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            hidden_act=hf_config.hidden_act,
            quant_config=quant_config,
            bias=getattr(hf_config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        eps = hf_config.rms_norm_eps
        self.input_layernorm = RMSNorm(hf_config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(hf_config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention then MLP; return ``(hidden_states, residual)``.

        When ``residual`` is None the incoming ``hidden_states`` becomes the
        residual and the input norm is called with one argument; otherwise
        the two-argument norm form returns the updated pair (presumably
        adding into the residual first — see RMSNorm).
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Return the quantization config for this layer (subclass hook)."""
        return vllm_config.quant_config


def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


361
362
363
364
365
@support_torch_compile(
    # TODO[#32068]: Investigate recompilation
    # mark_unbacked_dims={"input_ids": 0},
    shape_invariants=llama_model_invariants
)
class LlamaModel(nn.Module):
    """Llama decoder stack: token embedding, decoder layers, final RMSNorm.

    Under pipeline parallelism each rank builds only its slice of layers, and
    modules it does not own are ``PPMissingLayer`` placeholders. ``forward``
    can also return hidden states captured before the layers listed in
    ``aux_hidden_state_layers``.
    """

    # Parameter-name fragments rewritten by ``ops.trans_w16_gemm`` when the
    # fork-specific LLAMA_NN weight layout is enabled.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attn.qkv_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.down_proj.weight",
    )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        # The embedding lives on the first PP rank, and also on the last rank
        # when the LM head ties its weight to it.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices (relative to this rank's layer slice) before which hidden
        # states are captured in forward; empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Quantization method name, or None when unquantized; the LLAMA_NN
        # layout in load_weights is only applied when this is None.
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

        # Fork-specific feature flags, read once from the environment.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.use_gemm_pad = os.environ.get("GEMM_PAD") == "1"
        self.use_fa_pad = os.environ.get("FA_PAD") == "1"
        self.use_awq_pad = os.environ.get("AWQ_PAD") == "1"

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of decoder layers.

        The first PP rank starts from ``inputs_embeds`` (or embeds
        ``input_ids``); later ranks resume from ``intermediate_tensors``.
        Non-last ranks return ``IntermediateTensors``. The last rank returns
        the normed hidden states, paired with the captured auxiliary hidden
        states when any were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(
                positions, hidden_states, residual, **extra_layer_kwargs
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. KV-cache
        quantization scales and remapped scale / zero-point names are
        handled. Rotary caches, extra GPTQ biases and parameters owned by
        other PP ranks are skipped.

        When ``LLAMA_NN=1``, each incoming tensor must carry
        ``current_count`` and ``total_count`` attributes (presumably attached
        by the fork's weight iterator — confirm). If the last tensor seen has
        ``current_count == total_count`` and the model is unquantized, the
        LLAMA_NN weight layout is applied. Otherwise the ``LM_NN`` and
        ``LLAMA_NN`` environment variables are set to ``"0"``.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Bound before the loop so an empty ``weights`` iterable cannot leave
        # them undefined when LLAMA_NN is enabled.
        current_count = total_count = None
        for name, loaded_weight in weights:
            if self.use_llama_nn:
                current_count = loaded_weight.current_count
                total_count = loaded_weight.total_count
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                # Remapping the name of FP8 kv-scale or zero point.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        reached_last_chunk = current_count is not None and current_count == total_count
        if self.use_llama_nn and self.quant_method is None and reached_last_chunk:
            self._apply_llama_nn_weight_layout(params_dict)
        else:
            # NOTE(review): this branch also runs when LLAMA_NN is enabled but
            # the last tensor seen had current_count != total_count, which sets
            # the LLAMA_NN environment variable to "0" — confirm intended.
            os.environ["LM_NN"] = "0"
            os.environ["LLAMA_NN"] = "0"

        return loaded_params

    def _apply_llama_nn_weight_layout(
        self, params_dict: dict[str, nn.Parameter]
    ) -> None:
        """Rewrite LLAMA_NN target weights in place and set ``LM_NN``.

        Each parameter whose name contains one of ``_LLAMA_NN_WEIGHT_KEYS``
        is passed through ``ops.trans_w16_gemm`` (presumably a transpose for
        the fork's GEMM kernels — confirm) and reshaped to
        ``(original_dim1, -1)``. ``lm_head.weight`` is included, and
        ``LM_NN`` set to ``"1"``, only when such a parameter has at least
        4096 columns.
        """
        target_keys = list(self._LLAMA_NN_WEIGHT_KEYS)
        # Decide once over all parameters, so the flag does not depend on
        # which parameter happens to be iterated last.
        has_wide_lm_head = any(
            "lm_head.weight" in layername and weight.shape[1] >= 4096
            for layername, weight in params_dict.items()
        )
        if has_wide_lm_head:
            target_keys.append("lm_head.weight")
        os.environ["LM_NN"] = "1" if has_wide_lm_head else "0"

        # NOTE: disabled GEMM_PAD / FA_PAD experiments padded these weights
        # with pad_weight(..., 32), gated on gemm_bank_conf, at this point.
        for layername, weight in params_dict.items():
            # Plain substring match; "." is literal, unlike in a regex.
            if not any(key in layername for key in target_keys):
                continue
            transformed = torch.zeros_like(weight.data)
            ori_shape = transformed.shape
            ops.trans_w16_gemm(transformed, weight.data, ori_shape[0], ori_shape[1])
            weight.data.copy_(transformed)
            weight.data = weight.data.reshape(ori_shape[1], -1)


class LlamaForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Llama decoder stack plus an LM head (optionally tied to the embedding)."""

    # Checkpoint sub-modules packed into each fused vLLM module.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if not get_pp_group().is_last_rank:
            # The LM head only exists on the last pipeline stage.
            self.lm_head = PPMissingLayer()
        else:
            head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                head = head.tie_weights(self.model.embed_tokens)
            self.lm_head = head
            self.logits_processor = LogitsProcessor(
                hf_config.vocab_size,
                scale=getattr(hf_config, "logit_scale", 1.0),
            )

        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layers whose input hidden states the model captures."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Override to return default layers for Llama

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        depth = len(self.model.layers)
        return (2, depth // 2, depth - 3)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = LlamaDecoderLayer,
    ):
        """Build the inner decoder stack (subclass hook)."""
        return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via AutoWeightsLoader, skipping ``lm_head.`` when tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)


class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
    """Sequence-classification adapter over ``LlamaForCausalLM``.

    The correct attention type and pooling type come from
    ``LlamaBidirectionalConfig``, so no extra code is needed here.
    """


class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
    """Embedding-model adapter over ``LlamaForCausalLM``.

    The correct attention type and pooling type come from
    ``LlamaBidirectionalConfig``, so no extra code is needed here.
    """