llama.py 30.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
25
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26
27
from collections.abc import Iterable
from typing import Any, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
28
29
30
31
32

import torch
from torch import nn
from transformers import LlamaConfig

zhuwenwen's avatar
zhuwenwen committed
33
import os
gaoqiong's avatar
gaoqiong committed
34
import re
zhuwenwen's avatar
zhuwenwen committed
35

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

class MultiModalConfigProxy:
    """Attribute proxy over a flat or nested (multimodal) HF config.

    Attribute reads are served by the wrapped config first (this covers flat
    configs such as ``LlamaConfig``). When the wrapped config lacks the
    attribute and exposes a ``text_config`` sub-config (e.g. ``Glm4vConfig``),
    the read falls back to ``text_config``. Attribute writes are always
    forwarded to the wrapped top-level config.
    """

    def __init__(self, config):
        # Wrapping an existing proxy would only stack delegation layers, so
        # reuse the config it already wraps.
        if isinstance(config, MultiModalConfigProxy):
            config = object.__getattribute__(config, '_config')
        # Bypass our own __setattr__, which forwards to the wrapped config.
        object.__setattr__(self, '_config', config)

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapped config, then on its text_config.

        Raises:
            AttributeError: if neither the wrapped config nor its
                ``text_config`` (when present) provides ``name``.
        """
        # __getattr__ only runs after normal lookup failed, so being asked
        # for '_config' means the proxy was never initialised (copy/pickle
        # create instances without calling __init__). Raise AttributeError
        # instead of recursing through ``self._config`` until RecursionError.
        if name == '_config':
            raise AttributeError(name)
        config = self._config
        # First try the original config (works for flat configs).
        try:
            return getattr(config, name)
        except AttributeError:
            # Not found at the top level: try text_config when it exists.
            # A miss there propagates text_config's own AttributeError.
            if hasattr(config, 'text_config'):
                return getattr(config.text_config, name)
            raise AttributeError(
                f"'{type(config).__name__}' object has no attribute '{name}'"
            )

    def __setattr__(self, name, value):
        # Only '_config' lives on the proxy itself; everything else is
        # written through to the wrapped config.
        if name == '_config':
            object.__setattr__(self, name, value)
        else:
            setattr(self._config, name, value)

    def __hasattr__(self, name):
        """Return True if ``name`` exists on the config or its text_config.

        NOTE: Python has no ``__hasattr__`` protocol; the builtin
        ``hasattr()`` goes through ``__getattr__`` above. This method is
        only reachable by calling it explicitly and is kept for callers
        that may do so.
        """
        return hasattr(self._config, name) or (
            hasattr(self._config, 'text_config') and
            hasattr(self._config.text_config, name)
        )

76
from vllm.attention import Attention, AttentionType
77
from vllm.compilation.decorators import support_torch_compile
78
from vllm.config import CacheConfig, VllmConfig
79
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
80
from vllm.model_executor.layers.activation import SiluAndMul
81
from vllm.model_executor.layers.layernorm import RMSNorm
82
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
83
84
                                               QKVParallelLinear,
                                               RowParallelLinear)
85
from vllm.model_executor.layers.logits_processor import LogitsProcessor
86
from vllm.model_executor.layers.quantization import QuantizationConfig
87
from vllm.model_executor.layers.rotary_embedding import get_rope
88
from vllm.model_executor.layers.vocab_parallel_embedding import (
89
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
90
from vllm.model_executor.model_loader.weight_utils import (
91
    default_weight_loader, maybe_remap_kv_scale_name)
92
from vllm.model_executor.sampling_metadata import SamplingMetadata
93
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
94

95
from .interfaces import SupportsLoRA, SupportsPP
96
97
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
98
99
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
100

gaoqiong's avatar
gaoqiong committed
101
from vllm import _custom_ops as ops
102
103
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
104
105

class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` with two outputs of ``intermediate_size``
    each; the result goes through ``SiluAndMul`` and then ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: size of each of the two fused outputs and
                input size of ``down_proj``.
            hidden_act: activation name; only ``"silu"`` is accepted.
            quant_config: forwarded to both linear layers.
            bias: whether both linear layers carry a bias.
            prefix: parameter-name prefix of this module.
            reduce_results: forwarded to ``down_proj``.

        Raises:
            ValueError: if ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before building the projections so that an unsupported
        # config fails without first constructing the linear layers.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        # Both linear layers return a 2-tuple; only the first item is used.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block with fused QKV projection and rotary embedding.

    Query heads are divided evenly across tensor-parallel ranks. KV heads
    are divided when there are at least as many as ranks; otherwise each
    rank keeps one KV head (``num_kv_heads`` is clamped to 1).
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: model config; read for the optional ``head_dim``,
                ``partial_rotary_factor``, ``interleaved_sliding_window``
                and (for GGUF) ``model_type`` attributes.
            hidden_size: model hidden size.
            num_heads: total number of query heads (before TP split).
            num_kv_heads: total number of KV heads (before TP split).
            rope_theta: rotary base.
            rope_scaling: optional rope-scaling dict passed to ``get_rope``.
            max_position_embeddings: maximum position for the rotary table.
            quant_config: forwarded to the linear layers and ``Attention``.
            bias: whether ``qkv_proj`` carries a bias.
            bias_o_proj: whether ``o_proj`` carries a bias.
            cache_config: forwarded to ``Attention``.
            prefix: parameter-name prefix; also the source of the layer
                index used for interleaved sliding windows.
            attn_type: attention type forwarded to ``Attention``.

        Raises:
            ValueError: if ``config.interleaved_sliding_window`` is neither
                an int nor a list.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                # A list is cycled over the layers by layer index.
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so that reading ``quant_config`` on
        # an unquantized model yields None instead of raising AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary embedding, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotation is used except for GGUF-quantized checkpoints
        whose ``config.model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
278
279
280

class LlamaDecoderLayer(nn.Module):
    """One decoder layer: norm -> self-attention -> norm -> MLP.

    The residual stream is carried alongside the hidden states; both
    layer norms are called with ``(hidden_states, residual)`` and return the
    normalised states together with the updated residual.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Read layer hyper-parameters from ``config`` and build sub-modules.

        Args:
            config: flat or nested (multimodal) HF config.
            cache_config: forwarded to the attention block.
            quant_config: forwarded to the attention block and the MLP.
            prefix: parameter-name prefix of this layer.
        """
        super().__init__()
        # Accept both flat configs and multimodal configs that keep the
        # language-model attributes under ``text_config``.
        config = MultiModalConfigProxy(config)

        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_pos = getattr(config, "original_max_position_embeddings",
                                   None)
        if rope_scaling is not None and original_max_pos:
            # Written into the config's own rope_scaling dict (in place).
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        o_proj_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # support internlm/internlm3-8b with qkv_bias: when present it
        # overrides the QKV bias only; o_proj keeps the value above.
        qkv_bias = (config.qkv_bias
                    if hasattr(config, 'qkv_bias') else o_proj_bias)

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        attn_type = (AttentionType.DECODER if getattr(
            config, "is_causal", True) else AttentionType.ENCODER_ONLY)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config, "max_position_embeddings",
                                            8192),
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is None the incoming hidden states become the
        residual and are normalised on their own.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions,
                                  hidden_states=hidden_states)

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
369
370


371
@support_torch_compile
class LlamaModel(nn.Module):
    """Embedding, decoder-layer stack and final norm of a LLaMA model.

    Under pipeline parallelism, sub-modules that do not live on the current
    rank are replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: source of the HF, cache, quantization and LoRA
                configs.
            prefix: parameter-name prefix of this module.
            layer_type: class instantiated for every decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        # Wrap config to handle both flat and nested multimodal configs
        config = MultiModalConfigProxy(config)

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters, if any.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding lives on the first rank, and also on the last rank
        # when the config ties word embeddings.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices (relative to this rank's layer slice) whose inputs are
        # additionally returned by forward().
        self.aux_hidden_state_layers: tuple[int, ...] = tuple()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        # Feature switches read once from the environment at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-last pipeline ranks; otherwise the normalised hidden
            states, paired with the collected auxiliary hidden states when
            any were recorded.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer on the first rank there is no
                # residual yet; adding None to a tensor would raise.
                aux_hidden_states.append(hidden_states if residual is None
                                         else hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. When ``LLAMA_NN``
        was enabled at construction, the model is unquantized and the last
        tensor seen reports ``current_count == total_count``, selected
        weights are converted by ``_convert_weights_to_nn_layout``;
        otherwise ``LM_NN`` and ``LLAMA_NN`` are set to ``'0'`` in the
        process environment.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Progress counters carried on the loaded tensors; they stay None
        # when LLAMA_NN is off or the weight iterator is empty, so the check
        # after the loop cannot hit an unbound local.
        current_count = None
        total_count = None
        for name, loaded_weight in weights:
            if self.use_llama_nn:
                # NOTE(review): these attributes are presumably attached by
                # the fork's weight iterator — confirm against the loader.
                current_count = loaded_weight.current_count
                total_count = loaded_weight.total_count
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        all_loaded = (current_count is not None
                      and current_count == total_count)
        if self.use_llama_nn and self.quant_method is None and all_loaded:
            self._convert_weights_to_nn_layout(params_dict)
        else:
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        return loaded_params

    def _convert_weights_to_nn_layout(
            self, params_dict: dict[str, torch.Tensor]) -> None:
        """Rewrite selected projection weights in place for ``LLAMA_NN``.

        Each selected weight is passed through ``ops.trans_w16_gemm`` and
        then reshaped to ``(orig_cols, -1)``. ``LM_NN`` is set to ``'1'``
        when an ``lm_head.weight`` with at least 4096 columns was found,
        else ``'0'``.
        """
        layout_keys = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        lm_head_selected = False
        for layername, weight in params_dict.items():
            if "lm_head.weight" in layername and weight.shape[1] >= 4096:
                if not lm_head_selected:
                    layout_keys.append("lm_head.weight")
                lm_head_selected = True

            # Plain substring test: the keys are literal name fragments,
            # not regular expressions.
            if not any(key in layername for key in layout_keys):
                continue

            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape
            # NOTE(review): presumably writes a re-laid-out copy of
            # ``weight.data`` into ``_weight`` — semantics of this custom
            # op are not visible here.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)

        # Set the flag once, after the scan: writing it per parameter let
        # any parameter following lm_head reset it to '0' even though
        # lm_head had already been converted.
        os.environ['LM_NN'] = '1' if lm_head_selected else '0'
589

Woosuk Kwon's avatar
Woosuk Kwon committed
590

591
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
Terry's avatar
Terry committed
592
    # Maps each packed module name to the list of per-projection names it
    # groups (presumably the checkpoint shards fused into one layer --
    # confirm against the weight loader and LoRA code that read this table).
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    # NOTE(review): the values look like role labels for the embedding
    # layers; their consumer is not visible in this part of the file.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    #
    # Rename table used by maybe_remap_mistral(): each dotted-name component
    # of a checkpoint key (or a pair of adjacent components joined by ".",
    # e.g. "kv_fake_quantizer.qscale_act") found here is replaced by the
    # mapped value.
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }
626

627
628
629
630
    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Assemble the backbone model and, on the last PP rank, the LM head.

        Args:
            vllm_config: Engine configuration; the HF model config, the
                quantization config and the LoRA config are read from it.
            prefix: Parameter-name prefix; "model" and "lm_head" are
                appended to it for the respective sub-modules.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()

        raw_hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        lora_cfg = vllm_config.lora_config

        # Wrap config to handle both flat and nested multimodal configs
        hf_config = MultiModalConfigProxy(raw_hf_config)

        self.config = hf_config
        self.lora_config = lora_cfg
        self.model = self._init_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )

        if not get_pp_group().is_last_rank:
            # Ranks other than the last one get a placeholder head.
            self.lm_head = PPMissingLayer()
        else:
            self.unpadded_vocab_size = hf_config.vocab_size
            if lora_cfg:
                self.unpadded_vocab_size += lora_cfg.lora_extra_vocab_size

            if lora_cfg:
                # We need bigger padding if using lora for kernel
                # compatibility
                vocab_padding = lora_cfg.lora_vocab_padding_size
            else:
                vocab_padding = DEFAULT_VOCAB_PADDING_SIZE

            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=vocab_padding,
                quant_config=quant_cfg,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                tied_head = self.lm_head.tie_weights(self.model.embed_tokens)
                self.lm_head = tied_head

            # Configs without a "logit_scale" attribute fall back to 1.0.
            scale = getattr(hf_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                scale,
            )

        # Re-export the backbone's factory under the same name on this
        # wrapper.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
Woosuk Kwon's avatar
Woosuk Kwon committed
676

677
678
679
680
681
682
683
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` on the wrapped model as ``aux_hidden_state_layers``.

        Args:
            layers: Layer indices to record. Annotated as a variable-length
                tuple because ``tuple[int]`` would describe a tuple of
                exactly one element.
        """
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default auxiliary layer indices for this model.

        Returns:
            ``(2, num_layers // 2, num_layers - 3)`` where ``num_layers`` is
            ``len(self.model.layers)``. The annotation is a variable-length
            tuple because three indices are returned, which ``tuple[int]``
            (a 1-tuple) does not describe.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

684
685
686
    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Construct the ``LlamaModel`` backbone from the given arguments.

        All three arguments are forwarded unchanged as keyword arguments.
        """
        backbone = LlamaModel(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=layer_type,
        )
        return backbone
691

692
693
694
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup for ``input_ids`` to ``self.model``."""
        embeddings = self.model.get_input_embeddings(input_ids)
        return embeddings

Woosuk Kwon's avatar
Woosuk Kwon committed
695
696
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped model and return its output unchanged.

        The four arguments are passed positionally, in declaration order,
        to ``self.model``.
        """
        return self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )
705

706
707
708
709
710
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply ``self.logits_processor`` to ``hidden_states``.

        The processor is called with the LM head, the hidden states and the
        sampling metadata; its result is returned as-is.
        """
        return self.logits_processor(
            self.lm_head,
            hidden_states,
            sampling_metadata,
        )

715
716
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``. When the config ties word embeddings,
        entries under the ``"lm_head."`` prefix are skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        if self.config.tie_word_embeddings:
            prefixes_to_skip = ["lm_head."]
        else:
            prefixes_to_skip = None

        weights_loader = AutoWeightsLoader(
            self,
            skip_prefixes=prefixes_to_skip,
        )
        # Kept as a lazy generator so tensors are remapped one at a time
        # while the loader consumes them.
        remapped = (self.maybe_remap_mistral(key, tensor)
                    for key, tensor in weights)
        return weights_loader.load_weights(remapped)
725

726
727
728
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Remap a checkpoint entry from the mistral format.

        This is used for the mistral format as used by Mistral and
        Llama <=2. Names are rewritten with ``self.mistral_mapping``; the
        ``wq``/``wk`` weight tensors are additionally re-ordered.

        Args:
            name: Dotted parameter name from the checkpoint.
            loaded_weight: Tensor stored under ``name``.

        Returns:
            The (possibly renamed) name and the (possibly re-ordered) tensor.
        """

        def reorder_heads(w: torch.Tensor, n_heads: int):
            rows = self.config.head_dim * n_heads
            cols = self.config.hidden_size

            # View as (heads, head_dim / 2, 2, hidden), swap the two middle
            # axes, then flatten back to (rows, cols).
            grouped = w.view(n_heads, rows // n_heads // 2, 2, cols)
            return grouped.transpose(1, 2).reshape(rows, cols)

        table = self.mistral_mapping
        parts = name.split(".")

        # rotary embeds should be sliced
        if parts[-1] == "weight":
            if "wk" in parts:
                loaded_weight = reorder_heads(loaded_weight,
                                              self.config.num_key_value_heads)
            elif "wq" in parts:
                loaded_weight = reorder_heads(loaded_weight,
                                              self.config.num_attention_heads)

        # Pair each component with its successor; the last one has none.
        successors = parts[1:] + [None]
        for part, successor in zip(parts, successors):
            if successor is None:
                pair = None
            else:
                pair = f"{part}.{successor}"

            if pair in table:
                # A two-component key takes precedence over a single one.
                name = name.replace(pair, table[pair])
            elif part in table and table[part] not in name:
                # Skip when the target already appears in the name, so
                # already-remapped names are not rewritten again.
                name = name.replace(part, table[part])

        return name, loaded_weight