# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
26
from collections.abc import Iterable
27
from itertools import islice
28
from typing import Any, Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

zhuwenwen's avatar
zhuwenwen committed
34
import os
gaoqiong's avatar
gaoqiong committed
35
import re
zhuwenwen's avatar
zhuwenwen committed
36

37
from vllm.attention import Attention, AttentionType
38
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
39
from vllm.compilation.decorators import support_torch_compile
40
from vllm.config import CacheConfig, VllmConfig
41
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.model_executor.layers.activation import SiluAndMul
43
from vllm.model_executor.layers.layernorm import RMSNorm
44
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
45
46
                                               QKVParallelLinear,
                                               RowParallelLinear)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
52
from vllm.model_executor.model_loader.weight_utils import (
53
    default_weight_loader, maybe_remap_kv_scale_name)
54
from vllm.model_executor.sampling_metadata import SamplingMetadata
55
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
56

57
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
58
59
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
60
61
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
62

gaoqiong's avatar
gaoqiong committed
63
from vllm import _custom_ops as ops
64
65
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
66
67

class LlamaMLP(nn.Module):
    """Feed-forward block of a Llama decoder layer.

    The input goes through a merged gate/up projection, the ``SiluAndMul``
    activation, and finally a row-parallel down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Create the two projections and the activation.

        Args:
            hidden_size: Input size of the gate/up projection and output
                size of the down projection.
            intermediate_size: Size of each of the two merged gate/up
                outputs and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            bias: Forwarded to both linear layers.
            prefix: Name prefix; ``.gate_up_proj`` / ``.down_proj`` are
                appended for the respective sub-layers.
            reduce_results: Forwarded to the down projection.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Gate and up projections live in one merged layer with two
        # equally sized output partitions.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=merged_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )

        activation_supported = hidden_act == "silu"
        if not activation_supported:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Each linear layer returns a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block of a Llama decoder layer.

    Runs a fused QKV projection, applies rotary embeddings to the query and
    key, calls the attention layer and projects the result back with
    ``o_proj``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: HF config; read for the optional ``head_dim``,
                ``partial_rotary_factor``, ``layer_types`` and
                ``sliding_window`` attributes.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            rope_theta: Base passed to the rotary embedding.
            rope_scaling: Optional scaling dict passed to the rotary
                embedding.
            max_position_embeddings: Max position passed to the rotary
                embedding.
            quant_config: Forwarded to the linear and attention layers.
            bias: Bias flag for the QKV projection.
            bias_o_proj: Bias flag for the output projection.
            cache_config: Forwarded to the attention layer.
            prefix: Name prefix; also parsed for the layer index.
            attn_type: ``AttentionType`` value selecting the attention
                implementation.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size

        # Query heads are split evenly over the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo;
        # fall back to the usual hidden_size / num_heads otherwise.
        configured_head_dim = getattr(config, "head_dim", None)
        if configured_head_dim is None:
            self.head_dim = self.hidden_size // self.total_num_heads
        else:
            self.head_dim = configured_head_dim

        # Phi models introduced a partial_rotary_factor parameter in the
        # config.
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        # A per-layer sliding window is used only when the config lists
        # this layer as "sliding_attention".
        sliding_window = None
        layer_types = getattr(config, "layer_types", None)
        if layer_types and layer_types[layer_idx] == "sliding_attention":
            sliding_window = config.sliding_window

        if attn_type == AttentionType.ENCODER_ONLY:
            attn_cls = EncoderOnlyAttention
        else:
            attn_cls = Attention

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

        # quant_method holds the quantization name, or None when the layer
        # is unquantized. quant_config is only stored when one was given.
        has_quant = quant_config is not None
        self.quant_method = quant_config.get_name() if has_quant else None
        if has_quant:
            self.quant_config = quant_config

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute self-attention over ``hidden_states``.

        Args:
            positions: Positions handed to the rotary embedding.
            hidden_states: Input to the QKV projection.

        Returns:
            The first element returned by the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Create ``self.rotary_emb``.

        Neox-style rotary embedding is used unless the quantization method
        is ``"gguf"`` and ``config.model_type`` is ``"llama"``.
        """
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        is_neox_style = not (is_gguf and config.model_type == "llama")

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
236
237
238

class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: RMSNorm, self-attention, RMSNorm, MLP.

    The residual stream is carried separately from the hidden states and is
    handed to the RMSNorm layers together with them.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and normalization sub-layers.

        Args:
            config: HF config the layer hyper-parameters are read from.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Name prefix; ``.self_attn`` / ``.mlp`` are appended.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Copy a non-empty original_max_position_embeddings from the config
        # into the rope scaling dict (the dict is modified in place).
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        bias_o_proj = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # support internlm/internlm3-8b with qkv_bias: when present it
        # overrides the bias flag of the QKV projection only.
        if hasattr(config, 'qkv_bias'):
            qkv_bias = config.qkv_bias
        else:
            qkv_bias = bias_o_proj

        # By default, Llama uses causal attention as it is a decoder-only
        # model. You can override the HF config with `is_causal=False` to
        # enable bidirectional attention, which is used in some embedding
        # models (e.g. parasail-ai/GritLM-7B-vllm)
        is_causal = getattr(config, "is_causal", True)
        attn_type = (AttentionType.DECODER
                     if is_causal else AttentionType.ENCODER_ONLY)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual from the previous layer, or ``None`` when
                there is none yet.

        Returns:
            ``(hidden_states, residual)`` after the MLP.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(positions=positions,
                                     hidden_states=hidden_states)

        # Fully Connected
        normed, residual = self.post_attention_layernorm(
            attn_output, residual)
        mlp_output = self.mlp(normed)
        return mlp_output, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
324
325


326
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, decoder layers and final RMSNorm.

    Under pipeline parallelism, the embedding and final norm are replaced
    by ``PPMissingLayer`` on ranks that do not need them.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the backbone.

        Args:
            vllm_config: Source of the HF, cache, quantization and LoRA
                configs.
            prefix: Name prefix; decoder layers are created under
                ``{prefix}.layers``.
            layer_type: Class used to construct each decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Name of the quantization method, or None for unquantized models.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # LoRA adapters may add extra vocabulary entries.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # The embedding is built on the first PP rank, and also on the
        # last rank when word embeddings are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Indices (relative to this rank's layer slice) whose inputs are
        # collected as auxiliary hidden states in forward(); empty by
        # default.
        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Feature switches read once from the environment at construction.
        # Only use_llama_nn is consulted inside this class.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; used on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every decoder layer.
            intermediate_tensors: Output of the previous PP rank; required
                on every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first PP rank.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-final PP ranks. On the last rank, the normalized hidden
            states, or ``(hidden_states, aux_hidden_states)`` when any
            auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        local_layers = islice(self.layers, self.start_layer, self.end_layer)
        for idx, layer in enumerate(local_layers):
            if idx in self.aux_hidden_state_layers:
                # Capture the input of the selected layer, with the
                # residual added in.
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. When the
        ``LLAMA_NN`` switch was on at construction, the model is
        unquantized and the last tensor seen reports
        ``current_count == total_count``, the projection weights are then
        converted in place via ``_convert_weights_to_nn_layout``.
        Otherwise ``LM_NN`` and ``LLAMA_NN`` are set to ``'0'`` in
        ``os.environ``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs. With
                ``LLAMA_NN`` on, each tensor must carry ``current_count``
                and ``total_count`` attributes (presumably progress
                markers set by the weight iterator; TODO confirm).

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Counters from the most recently seen tensor. They stay None when
        # `weights` is empty, so the check after the loop cannot hit an
        # unbound local.
        current_count = None
        total_count = None

        for name, loaded_weight in weights:
            if self.use_llama_nn:
                current_count = loaded_weight.current_count
                total_count = loaded_weight.total_count
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the `continue` statements below advance the mapping
                # loop, not the outer loop over checkpoint tensors.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        checkpoint_complete = (current_count is not None
                               and current_count == total_count)
        if (self.use_llama_nn and self.quant_method is None
                and checkpoint_complete):
            self._convert_weights_to_nn_layout(params_dict)
        else:
            # Weights were left as loaded, so both switches are turned off
            # for whatever reads them later.
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        return loaded_params

    def _convert_weights_to_nn_layout(
            self, params_dict: dict[str, nn.Parameter]) -> None:
        """Rewrite projection weights in place with ``ops.trans_w16_gemm``.

        Every parameter whose name contains one of the attention/MLP
        projection weight names is converted. An ``lm_head.weight``
        parameter is converted too once one with ``shape[1] >= 4096`` has
        been seen. ``os.environ['LM_NN']`` is set to ``'1'`` if such an
        ``lm_head`` was seen and to ``'0'`` otherwise, so the flag matches
        what was actually converted.
        """
        # Plain substring matching: parameter names are literal dotted
        # paths, so no regular expression is needed.
        projection_keys = (
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        )
        lm_head_key = "lm_head.weight"
        convert_lm_head = False

        for layername, weight in params_dict.items():
            is_lm_head = lm_head_key in layername
            if is_lm_head and weight.shape[1] >= 4096:
                convert_lm_head = True

            selected = (any(key in layername for key in projection_keys)
                        or (convert_lm_head and is_lm_head))
            if not selected:
                continue

            # trans_w16_gemm fills the scratch tensor from the current
            # weight (presumably a transposed/re-tiled copy; TODO confirm
            # against the custom op). The result is copied back and the
            # parameter is then viewed as (original dim 1, -1).
            scratch = torch.zeros_like(weight.data)
            ori_shape = scratch.shape
            ops.trans_w16_gemm(scratch, weight.data, scratch.shape[0],
                               scratch.shape[1])
            weight.data.copy_(scratch)
            weight.data = weight.data.reshape(ori_shape[1], -1)

        # Set once after the loop so a later, unrelated parameter cannot
        # flip the flag back to '0'.
        os.environ['LM_NN'] = '1' if convert_lm_head else '0'
541

Woosuk Kwon's avatar
Woosuk Kwon committed
542

543
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Llama backbone plus LM head and logits processor.

    The LM head and logits processor exist only on the last pipeline
    parallel rank; other ranks hold a ``PPMissingLayer`` instead.
    """

    # Fused parameter name -> names of the checkpoint tensors packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the backbone and, on the last PP rank, the LM head.

        Args:
            vllm_config: Source of the HF, quantization and LoRA configs.
            prefix: Name prefix for the ``model`` and ``lm_head``
                sub-modules.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the backbone layers whose inputs are returned as
        auxiliary hidden states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default auxiliary layer indices: ``2``, the middle
        layer, and the third from last."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Construct the backbone as a ``LlamaModel``."""
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Each ``(name, tensor)`` pair first goes through
        ``maybe_remap_mistral``. ``lm_head.`` tensors are skipped when word
        embeddings are tied.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight name (and layout) to HF format.

        ``wq`` / ``wk`` weight tensors are permuted; every dotted component
        of ``name`` (or pair of adjacent components) found in
        ``mistral_mapping`` is replaced by its mapped name.

        Args:
            name: Dotted weight name from the checkpoint.
            loaded_weight: Tensor stored under ``name``.

        Returns:
            ``(remapped_name, tensor)``; the tensor is a permuted copy for
            ``wq`` / ``wk`` weights and the input tensor otherwise.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # head_dim is optional on the config (LlamaAttention treats it
            # the same way); fall back to hidden_size / num_attention_heads
            # when it is missing or None.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = (self.config.hidden_size //
                            self.config.num_attention_heads)
            attn_in = head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        num_modules = len(modules)
        for i, item in enumerate(modules):
            next_item = modules[i + 1] if i < num_modules - 1 else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-component ones.
            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                # Skip when the target is already in the name so it is not
                # rewritten a second time.
                name = name.replace(item, mapping[item])

        return name, loaded_weight