# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
import re
from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import LlamaConfig

from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` with two outputs of size
    ``intermediate_size``; ``SiluAndMul`` is applied to the fused output
    before it is projected back to ``hidden_size`` by a
    ``RowParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Build the projection layers.

        Args:
            hidden_size: Input and output feature size of the block.
            intermediate_size: Size of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            bias: Whether the linear layers carry a bias term.
            prefix: Module-name prefix used to name the sub-layers.
            reduce_results: Forwarded to ``down_proj``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported configuration fails before any
        # projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a 2-tuple; only the first element is used.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention layer with a fused QKV projection and rotary embedding.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    either split (when there are at least as many KV heads as ranks) or
    replicated (otherwise).
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF config; read for the optional ``head_dim``,
                ``partial_rotary_factor``, ``interleaved_sliding_window``
                and ``model_type`` attributes.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            rope_theta: Base passed to the rotary embedding.
            rope_scaling: Optional rope-scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            quant_config: Optional quantization config.
            bias: Whether ``qkv_proj`` has a bias.
            bias_o_proj: Whether ``o_proj`` has a bias.
            cache_config: Optional KV-cache config forwarded to ``Attention``.
            prefix: Module-name prefix; also used to derive the layer index.
            attn_type: Attention type forwarded to ``Attention``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        # Resolve the per-layer sliding window: a single int applies to every
        # layer, a list is indexed cyclically by layer index.
        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so they can be read on unquantized
        # models as well (they are None in that case).
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output proj."""
        qkv, _ = self.qkv_proj(hidden_states)
        # NOTE: the FA_PAD path that trimmed the fused QKV output is
        # currently disabled:
        # if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
        #     qkv = qkv[...,:-32]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Create ``self.rotary_emb`` via ``get_rope``.

        NeoX-style rotary embedding is used unless the model is a GGUF
        quantized checkpoint whose ``model_type`` is ``"llama"``.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

    The residual stream is threaded through the two norm layers, which are
    called with ``(hidden_states, residual)`` and return an updated pair.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and norm sub-modules from ``config``.

        Args:
            config: HF model config.
            cache_config: Optional KV-cache config forwarded to attention.
            quant_config: Optional quantization config forwarded to the
                attention and MLP sub-modules.
            prefix: Module-name prefix for the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # NOTE: this writes into the dict object obtained from ``config``.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # o_proj keeps the attention_bias/bias setting even when qkv_bias
        # overrides the QKV bias below.
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Position ids forwarded to the attention module.
            hidden_states: Layer input.
            residual: Residual from the previous layer, or ``None`` when
                there is none yet; in that case the input itself becomes
                the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):
    """Token embedding, decoder-layer stack and final norm.

    Supports pipeline parallelism: the embedding is only materialised on the
    first rank (or on the last rank when word embeddings are tied), the
    final norm only on the last rank; missing pieces are ``PPMissingLayer``.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Source of the HF, cache, quantization and LoRA
                configs.
            prefix: Module-name prefix for the sub-modules.
            layer_type: Class instantiated for every decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Name of the quantization method, or None for unquantized models.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Layer indices whose input is additionally returned by forward().
        self.aux_hidden_state_layers: tuple[int, ...] = tuple()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Feature switches read once from the environment at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # self.use_lm_nn = os.environ.get('LM_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            ``IntermediateTensors`` (hidden states and residual) on every
            pipeline rank but the last. On the last rank, the normalised
            hidden states, or a ``(hidden_states, aux_hidden_states)`` tuple
            when any auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                # There is no residual yet before the first layer on the
                # first rank; record the hidden states alone in that case.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states +
                    residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. When ``LLAMA_NN``
        was enabled at construction, the model is unquantized and the last
        tensor seen reports ``current_count == total_count``, the linear
        weights are converted in place afterwards; otherwise the ``LM_NN``
        and ``LLAMA_NN`` environment variables are set to ``'0'``.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Progress counters carried as attributes on the incoming tensors.
        # Initialised here so an empty ``weights`` iterable cannot leave
        # them unbound.
        # NOTE(review): presumably attached by the weight iterator - confirm.
        current_count = None
        total_count = None
        for name, loaded_weight in weights:
            if self.use_llama_nn:
                current_count = loaded_weight.current_count
                total_count = loaded_weight.total_count
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: load it under its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        all_weights_seen = (current_count is not None
                            and current_count == total_count)
        if (self.use_llama_nn and self.quant_method is None
                and all_weights_seen):
            self._convert_weights_to_nn_layout(params_dict)
        else:
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        return loaded_params

    def _convert_weights_to_nn_layout(
            self, params_dict: dict[str, nn.Parameter]) -> None:
        """Rewrite the linear weights in place via ``ops.trans_w16_gemm``.

        Applies to the attention and MLP projection weights, plus any
        ``lm_head.weight`` whose second dimension is at least 4096. Sets
        the ``LM_NN`` environment variable to ``'1'`` when such an lm_head
        weight exists and to ``'0'`` otherwise.
        """
        layer_keys = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        # Decide the lm_head handling once, instead of letting the last
        # parameter visited overwrite the flag.
        has_large_lm_head = any(
            "lm_head.weight" in layername and weight.shape[1] >= 4096
            for layername, weight in params_dict.items())
        if has_large_lm_head:
            layer_keys.append("lm_head.weight")
        os.environ['LM_NN'] = '1' if has_large_lm_head else '0'

        # NOTE: the optional GEMM_PAD / FA_PAD weight padding that used to
        # precede the conversion is currently disabled.
        for layername, weight in params_dict.items():
            # Plain substring match: the keys contain '.', which would act
            # as a wildcard if they were used as a regular expression.
            if not any(key in layername for key in layer_keys):
                continue
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape
            # NOTE(review): presumably writes a transposed copy of the
            # weight into ``_weight`` - confirm against the custom op.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)
            # Expose the same storage with the two dimensions swapped.
            weight.data = weight.data.reshape(ori_shape[1], -1)


class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Llama decoder with a language-model head.

    The LM head and the logits processor are only created on the last
    pipeline-parallel rank; other ranks hold a ``PPMissingLayer`` instead
    of the head.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the decoder and, on the last PP rank, the LM head.

        Args:
            vllm_config: Source of the HF, quantization and LoRA configs.
            prefix: Module-name prefix for the sub-modules.
            layer_type: Decoder-layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set the layer indices whose inputs the decoder also returns."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the layer indices ``(2, n // 2, n - 3)`` for ``n`` layers."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Create the underlying ``LlamaModel``."""
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return its output unchanged."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping Mistral-format names first.

        ``lm_head.`` tensors are skipped when word embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format tensor name (and q/k layout).

        ``wq`` / ``wk`` weight tensors are permuted; every dotted component
        of ``name`` (or pair of adjacent components) found in
        ``mistral_mapping`` is replaced by its mapped name.

        Returns:
            The possibly renamed name and possibly permuted tensor.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Fall back to hidden_size // num_attention_heads when the
            # config has no head_dim, as LlamaAttention does.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = (self.config.hidden_size //
                            self.config.num_attention_heads)
            attn_in = head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-component keys.
            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight